Plotting metrics at the end of each episode on Wandb

Hi Everyone,

I’m running ray.tune with PBT and DQN on a custom environment, which calculates many metrics at the end of each episode. Given that the default Wandb callback plots the metrics at the end of each training step for each trial, and each step consists of many episodes, I was wondering what the recommended way is to plot the metrics at the end of each episode.

More specifically, I have a custom callback based on RLlib’s DefaultCallbacks that calculates the metrics at the end of each episode. Since all the metrics on a step-by-step basis within an episode are available there, how can I plot those values, e.g. histograms/plots, on Wandb while running Ray Tune trials?
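For context, the reduction step such a callback could perform at on_episode_end can be kept as a plain helper on the per-step values collected during the episode (the helper name here is made up, not from my actual code):

```python
from statistics import mean, pstdev

# Hypothetical reducer an RLlib callback's on_episode_end could call
# on the per-step values collected during the episode (e.g. stashed
# in episode.user_data); returns episode-level metrics ready to log.
def summarize_episode(step_rewards):
    if not step_rewards:
        return {}
    return {
        "episode_reward_sum": sum(step_rewards),
        "episode_reward_mean": mean(step_rewards),
        "episode_reward_std": pstdev(step_rewards),
    }
```

For example, summarize_episode([1.0, 2.0, 3.0]) gives a sum of 6.0 and a mean of 2.0 for that episode.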

Thanks in advance!

P.s: if this question should be moved to Rllib instead, please let me know!

Hi @NumberChiffre, we don’t support this out of the box.

Could you try something like this:

from ray.tune.registry import get_trainable_cls
from ray.tune.integration.wandb import WandbTrainableMixin

trainable_cls = get_trainable_cls("PPO")

class _WrappedTrainable(WandbTrainableMixin, trainable_cls):
    _name = trainable_cls.__name__ if hasattr(trainable_cls, "__name__") \
        else "wrapped_trainable"
    # ...

Then you should be able to just use wandb.log() in your rllib callback.
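To sketch how that wrapped trainable might be passed to Tune (the exact keys under "wandb" are an assumption here, mirroring the api_key/project/group fields, not a confirmed schema):

```python
# Untested sketch: pass the wrapped trainable to tune.run with a
# "wandb" sub-dict for the mixin to consume. Env name and wandb
# fields below are placeholders.
from ray import tune

analysis = tune.run(
    _WrappedTrainable,
    config={
        "env": "CartPole-v0",  # placeholder env name
        "wandb": {
            "api_key": "YOUR_API_KEY",
            "project": "my_project",
            "group": "my_group",
        },
    },
)
```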

I haven’t tried this before, so let me know if this doesn’t work and I’m happy to look into it sometime.

Thanks for your reply @kai, I tried a few things:

Given that I still use WandbLogger just to display the metrics at every step of the trial, do I need to call something like wandb.init() in the Trainable class? Also, I’m not sure how to pass the wandb objects from the trainable onto the RLlib callback class.

Then I would do the same for every call to on_episode_start() from the RLlib callback. If I run multiple trials for tuning, I can just add a counter so that wandb.init() only runs at the start of the first episode of each trial. I just don’t know how to do it through WandbLoggerCallback: it seems log_trial_start() only runs after on_episode_start(), which is why I’m adding wandb.init() to the RLlib callback instead. Also, in this case there would be no need to wrap a Trainable class.

I definitely would like to find a way to call wandb.init() at the start of each new trial, so that I can start using wandb.log() in on_episode_start(); this would at least make sure the episode metrics and trial metrics all end up in the same place. What would you do in this case?

Thanks for taking the time, let me know what you think!

Using the Mixin would automatically run wandb.init() in the trainables. You can then use the top level wandb.log() method to log metrics directly to wandb.

In that case you probably wouldn’t want to use the WandbLoggerCallback anymore.

Does this make sense?

I just tried what you said like this:

class _WrappedTrainable(WandbTrainableMixin, trainable_cls):
    _name = trainable_cls.__name__ if hasattr(trainable_cls, "__name__") \
        else "wrapped_trainable"

    def __init__(self, config=None, logger_creator=None):
        super().__init__(config=config, logger_creator=logger_creator)

However, the RLlib trainer class will raise an exception since an unknown param is passed: ‘wandb’, a dictionary that holds the api_key/project/group info to be used for wandb.init(). If one doesn’t pass the ‘wandb’ dict, then wandb.init() would not work.
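For what it’s worth, one shape a workaround could take (an untested sketch; the helper name is hypothetical) is to split the ‘wandb’ dict out of the config before the RLlib trainer validates it:

```python
# Hypothetical helper: split the tune config into the "wandb" part
# (consumed by wandb.init()) and the remainder the RLlib trainer is
# allowed to see, so config validation doesn't reject the unknown
# "wandb" key.
def split_wandb_config(config):
    config = dict(config or {})
    wandb_config = config.pop("wandb", {})
    return wandb_config, config

# In _WrappedTrainable.__init__ one could then do roughly:
#   wandb_config, trainer_config = split_wandb_config(config)
#   super().__init__(config=trainer_config, logger_creator=logger_creator)
#   # ...and use wandb_config when calling wandb.init().
```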

If I were to use WandbLoggerCallback, I still need to call wandb.init() myself. I store/retrieve the trial_id information between workers for their RLlib callbacks to plot; it works on my side, but I’d definitely like to see how you would work it out using just _WrappedTrainable.


Not sure if this is the most efficient way, but I’m doing the following to plot episode data:

  • Wrap around WandbLoggerCallback so you can pass params such as trial.trial_id onto the RLlib callback, which then calls wandb.init()
  • To get the same log directory as the wandb.init() called in WandbLoggerCallback, the name argument follows the pattern WANDB_GROUP + ENV_NAME + TRIAL_ID, with an underscore between each variable
  • For each Worker passed to the RLlib callback, the logdir is the same as trial.logdir, so it can be used as a key to look up the params passed for each trial and initialize wandb.init(); you only need to call it once per trial
  • You can freely log at the end of each episode; you can also add a counter to the callback to keep track of the number of episodes to log, or you can log every step in the episode and still use custom charts on wandb
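The naming and init-once parts of the recipe above can be sketched as plain helpers (names are hypothetical; the real wandb.init(...) call is stood in by a callable):

```python
# Sketch of two pieces of the recipe above (helper names are made up).

# 1) The run name passed to wandb.init() inside the RLlib callback
#    must match the one built by the wrapped WandbLoggerCallback:
#    WANDB_GROUP + ENV_NAME + TRIAL_ID joined by underscores.
def wandb_run_name(group, env_name, trial_id):
    return "_".join([group, env_name, trial_id])


# 2) Per-worker init-once guard, keyed by the trial logdir, which is
#    identical across callback invocations within one trial.
_initialized_logdirs = set()

def ensure_wandb_init(logdir, init_fn):
    """Call init_fn (standing in for wandb.init(...)) once per logdir."""
    if logdir in _initialized_logdirs:
        return False
    _initialized_logdirs.add(logdir)
    init_fn()
    return True
```

For example, wandb_run_name("my_group", "CustomEnv-v0", "a1b2c3") yields "my_group_CustomEnv-v0_a1b2c3", and repeated ensure_wandb_init calls for the same logdir only run init_fn the first time.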

Don’t hesitate to share your way!