Advanced evaluation with wandb, RLlib and Tune (weight, gradient, activation histograms)

Hey all,

I was just wondering if there is any way to get weight, gradient, and activation histograms working with the Weights and Biases callback for Tune and the bare-bones PPO algorithm (Torch)?

One of the main parts of training deep RL models is keeping the DNNs used for function approximation healthy, so I was surprised that there is currently no easy way to get this working.

I thought about two approaches:

  1. A custom callback:
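```python
import wandb
from ray.rllib.agents.callbacks import DefaultCallbacks

class MyCallback(DefaultCallbacks):
    def on_train_result(self, *, trainer, result: dict, **kwargs):
        # Attach one wandb.Histogram per weight tensor of the default policy
        for layer_name, value in trainer.get_policy().get_weights().items():
            result[layer_name] = wandb.Histogram(value)
```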
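  2. Accessing the Torch model directly and using the normal wandb callback:
```python
import wandb
# `trainer` is an existing RLlib Trainer; its default policy exposes the torch.nn.Module
pytorch_model = trainer.get_policy().model
wandb.watch(pytorch_model, log="all", log_freq=100)
```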

But neither works, due to limitations in Ray. In the first approach, wandb.Histogram is not an allowed data type, so it is ignored and never reported to the wandb dashboard.
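One possible workaround for the type filter, sketched under the assumption that Tune's wandb logger forwards plain Python floats (ordinary custom metrics do get through): reduce each weight tensor to a few scalar statistics instead of a wandb.Histogram. `weight_summary_metrics` is a hypothetical helper, not part of RLlib or wandb, and it assumes the weights dict maps names to numpy arrays, as a Torch policy's get_weights() returns.

```python
import numpy as np


def weight_summary_metrics(weights, prefix="weights"):
    """Reduce each weight tensor to a few scalar statistics.

    Hypothetical helper: `weights` is assumed to be a dict of
    name -> numpy array. Only plain floats are returned, so (by assumption)
    Tune's logger callbacks should not filter them out.
    """
    metrics = {}
    for name, value in weights.items():
        arr = np.asarray(value, dtype=np.float64)
        if arr.size == 0:
            continue  # nothing to summarize for empty tensors
        key = f"{prefix}/{name}"
        metrics[f"{key}/mean"] = float(arr.mean())
        metrics[f"{key}/std"] = float(arr.std())
        metrics[f"{key}/min"] = float(arr.min())
        metrics[f"{key}/max"] = float(arr.max())
        metrics[f"{key}/l2_norm"] = float(np.linalg.norm(arr.ravel()))
    return metrics


# Hypothetical usage inside the custom callback:
# class MyCallback(DefaultCallbacks):
#     def on_train_result(self, *, trainer, result, **kwargs):
#         result.update(weight_summary_metrics(trainer.get_policy().get_weights()))
```

This loses the full distribution, but mean, spread, extremes and norm per layer are often enough to spot exploding or dead weights.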

The second one doesn't work either, because (to my knowledge) calling wandb methods directly conflicts with Tune's WandbLoggerCallback.
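If the full distribution is needed, another hedged option is to compute the histogram with numpy inside the callback and emit every bin count as its own scalar metric. The worker then never calls the wandb API, which should sidestep the conflict described above, at the cost of wandb showing one line chart per bin rather than a native histogram panel. `histogram_scalar_metrics` and its `bins`/`prefix` parameters are hypothetical names chosen for this sketch, and the weights dict is again assumed to map names to numpy arrays.

```python
import numpy as np


def histogram_scalar_metrics(weights, bins=10, prefix="hist"):
    """Flatten per-tensor histograms into plain scalar metrics.

    Hypothetical helper: each bin count and the overall bin range become
    separate floats, so nothing wandb-specific is created in the worker.
    """
    metrics = {}
    for name, value in weights.items():
        arr = np.asarray(value, dtype=np.float64).ravel()
        if arr.size == 0:
            continue  # skip empty tensors
        counts, edges = np.histogram(arr, bins=bins)
        for i, count in enumerate(counts):
            metrics[f"{prefix}/{name}/bin_{i:02d}"] = float(count)
        # Record the range so the (equal-width) bins can be interpreted later
        metrics[f"{prefix}/{name}/lo"] = float(edges[0])
        metrics[f"{prefix}/{name}/hi"] = float(edges[-1])
    return metrics
```

It would be called the same way as any other custom metric, e.g. `result.update(histogram_scalar_metrics(trainer.get_policy().get_weights()))` inside on_train_result.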

I would highly appreciate any feedback on whether this is possible, and whether there are plans to make it work seamlessly.

PS: Other custom metrics are working like a charm, thanks Ray team.


An official guide on how to use wandb with RLlib models would be super cool!