- High: It blocks me to complete my task.
I am trying to use a custom metrics to create a stop condition. Basically what I am trying to do is to stop a training a PPO agent if the evaluation reward mean has plateaued. I created a custom callback below
import numpy as np
from ray.tune import Callback
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.algorithms.algorithm import Algorithm
from utils import plateau_detection
class PlateauDetectionCallback(Callback):
def __init__(self, evaluation_interval, window_size=5, metric="episode_reward_mean", std=0.001, grace_period=10):
self.metric = metric
self.std = std
self.window_size = window_size
self.grace_period = grace_period
self.eval_interval = evaluation_interval
self.eval_rewards = []
self.iter = 0
def on_trial_result(self, iteration, trials, trial, result, **info):
#print("result: ", result)
if self.metric not in result:
return
if self.iter > 0 and self.iter % self.eval_interval == 0:
print("result: ", result)
eval_reward_mean = result['evaluation']['sampler_results']['episode_reward_mean']
if self.iter >= self.grace_period and eval_reward_mean is not None:
self.eval_rewards.append(eval_reward_mean)
# Check for plateau if we have enough data points
if len(self.eval_rewards) >= self.window_size:
is_plateau = plateau_detection(self.eval_rewards, self.window_size, self.std)
# Set the custom metric
result["eval_plateau"] = is_plateau
# if self.iter > 1:
# #result["eval_plateau"] = True
# print("result: ", result)
else:
print(("RESULT: ", result))
result["eval_plateau"] = False
self.iter += 1
print("result: ", result)
#print("result: ", result)
and created the stop condition using
stop = {
"training_iteration": args.stop_iters,
"timesteps_total": args.stop_timesteps,
"eval_plateau": True,
}
config.framework(args.framework)
# Log results using WandB.
tune_callbacks = []
tune_callbacks.append(
WandbLoggerCallback(
project=args.wandb_project,
upload_checkpoints=True,
log_config=True,
config=args,
job_type="train",
name=run_name,
group=args.wandb_run_group,
sync_tensorboard=True,
monitor_gym=True,
save_code=True
),
)
tune_callbacks.append(PlateauDetectionCallback(evaluation_interval=args.evaluation_interval))
# Force Tuner to use old progress output as the new one silently ignores our custom
# `CLIReporter`.
#os.environ["RAY_AIR_NEW_OUTPUT"] = "0"
# Run the actual experiment (using Tune).
results = Tuner(
config.algo_class,
param_space=config.to_dict(),
run_config=RunConfig(
stop=stop,
verbose=args.verbose,
callbacks=tune_callbacks,
checkpoint_config=CheckpointConfig(
num_to_keep=1,
checkpoint_score_attribute="episode_reward_mean",
checkpoint_score_order="max",
checkpoint_at_end=args.checkpoint_at_end,
),
),
).fit()
All the processes (driver and the roller out worker) are dying after the first evaluation without entering the callback