Stop criteria using a custom metric

  • High: It blocks me to complete my task.

I am trying to use a custom metrics to create a stop condition. Basically what I am trying to do is to stop a training a PPO agent if the evaluation reward mean has plateaued. I created a custom callback below


import numpy as np
from ray.tune import Callback


from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.algorithms.algorithm import Algorithm

from utils import plateau_detection

class PlateauDetectionCallback(Callback):
    def __init__(self, evaluation_interval, window_size=5, metric="episode_reward_mean", std=0.001, grace_period=10):
        self.metric = metric
        self.std = std
        self.window_size = window_size
        self.grace_period = grace_period
        self.eval_interval = evaluation_interval
        self.eval_rewards = []
        self.iter = 0

    def on_trial_result(self, iteration, trials, trial, result, **info):
        #print("result: ", result)
        if self.metric not in result:
            return
        if self.iter > 0 and self.iter % self.eval_interval == 0:
            print("result: ", result)
            eval_reward_mean = result['evaluation']['sampler_results']['episode_reward_mean']
            if self.iter >= self.grace_period and eval_reward_mean is not None:
                self.eval_rewards.append(eval_reward_mean)
                # Check for plateau if we have enough data points
                if len(self.eval_rewards) >= self.window_size:
                    is_plateau = plateau_detection(self.eval_rewards, self.window_size, self.std)
                    
                    # Set the custom metric
                    result["eval_plateau"] = is_plateau

            # if self.iter > 1:

            #     #result["eval_plateau"] = True
            #     print("result: ", result)
        else:
            print(("RESULT: ", result))
            result["eval_plateau"] = False

        self.iter += 1
        print("result: ", result)
        #print("result: ", result)            

and created the stop condition using

 stop = {
        "training_iteration": args.stop_iters,
        "timesteps_total": args.stop_timesteps,
        "eval_plateau": True,
    }

    config.framework(args.framework)

     # Log results using WandB.
    tune_callbacks = []
    
    tune_callbacks.append(
        WandbLoggerCallback(
                project=args.wandb_project,
                upload_checkpoints=True,
                log_config=True,
                config=args,
                job_type="train",
                name=run_name,
                group=args.wandb_run_group,
                sync_tensorboard=True,
                monitor_gym=True,
                save_code=True
                ),

        )
    tune_callbacks.append(PlateauDetectionCallback(evaluation_interval=args.evaluation_interval))

      # Force Tuner to use old progress output as the new one silently ignores our custom
    # `CLIReporter`.
    #os.environ["RAY_AIR_NEW_OUTPUT"] = "0"

    # Run the actual experiment (using Tune).
    results = Tuner(
        config.algo_class,
        param_space=config.to_dict(),
        run_config=RunConfig(
            stop=stop,
            verbose=args.verbose,
            callbacks=tune_callbacks,
            checkpoint_config=CheckpointConfig(
                num_to_keep=1,
                checkpoint_score_attribute="episode_reward_mean",
                checkpoint_score_order="max",
                checkpoint_at_end=args.checkpoint_at_end,
            ),
        ),
    ).fit()


All the processes (driver and the roller out worker) are dying after the first evaluation without entering the callback

if a comment out eval_plateau in stop.

 stop = {
        "training_iteration": args.stop_iters,
        "timesteps_total": args.stop_timesteps,
      #  "eval_plateau": True,
    }

Things work and I can see the added key eval_plateau in result keys. but I won’t be able to use it as a stop criteria.

Hi @smallsuper,

What does the results dictionary look like? What are the keys?