Using Hydra to configure a custom torch model with RLlib's PPO

Dear all!

My goal is to use Hydra to configure the network I use for training. Basically, I use RLlib's PPO together with a custom torch model to train on the continuous mountain car. However, the custom model, which is derived from `TorchModelV2`, makes it difficult for me to configure the network with Hydra. Since I follow the directory structure recommended on the Hydra homepage, my Python scripts and YAML files have the following structure:

  • config
    • model
      • torch_model_v01.yaml
    • config.yaml
  • src
    • models
      • torch_model.py
      • components
        • critic.py

I wrote the following Python script to train on the continuous mountain car:

import os
from datetime import date
import numpy as np
import tempfile
from omegaconf import DictConfig, OmegaConf
from typing import Optional
import hydra
import ray
from ray.rllib.algorithms import ppo
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.logger import UnifiedLogger, DEFAULT_LOGGERS
from ray.tune.result import DEFAULT_RESULTS_DIR
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.misc import normc_initializer as normc_init_torch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from gymnasium.envs.classic_control.continuous_mountain_car import Continuous_MountainCarEnv
from src.models.torch_model import TorchModelV01

torch, nn = try_import_torch()


class MountainCar(Continuous_MountainCarEnv):
    """Continuous mountain car env that accepts the env-config argument RLlib passes.

    Args:
        config: Environment config handed over by RLlib; stored unchanged on
            ``self.config``.
    """

    def __init__(self, config=None):
        # The parent constructor creates the action/observation spaces and the
        # physics constants and state that ``reset``/``step`` rely on; skipping
        # it leaves the env unusable.
        super().__init__()
        self.config = config

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        """Reset the env, forwarding ``seed`` and ``options`` so seeding takes effect."""
        return super().reset(seed=seed, options=options)

    def step(self, action):
        """Advance the env by one step (delegates to the gymnasium implementation)."""
        return super().step(action)

def custom_logger_creator():
    """Return a logger-creator callback that logs each run to a fresh directory.

    The directory is created under ``DEFAULT_RESULTS_DIR``. Its name is
    ``PPO_MountainCar_<YYYY-mm-dd_HH-MM-SS>`` followed by the random suffix
    that ``tempfile.mkdtemp`` appends.

    Returns:
        A callable ``logger_creator(config)`` returning a ``UnifiedLogger``.
    """
    # ``datetime`` rather than ``date``: a ``date`` has no time of day, so the
    # %H-%M-%S fields would always read 00-00-00.
    from datetime import datetime

    timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
    logdir_prefix = "{}_{}_{}".format("PPO", "MountainCar", timestr)

    def logger_creator(config):
        # exist_ok avoids a failure if the directory already exists or is
        # created concurrently by another process.
        os.makedirs(DEFAULT_RESULTS_DIR, exist_ok=True)
        logdir = tempfile.mkdtemp(prefix=logdir_prefix, dir=DEFAULT_RESULTS_DIR)
        loggers = list(DEFAULT_LOGGERS)
        return UnifiedLogger(config, logdir, loggers=loggers)

    return logger_creator

@hydra.main(version_base=None, config_path="../configs", config_name="config")
def train(cfg: DictConfig) -> Optional[float]:
    """Train PPO with the Hydra-configured custom torch model on MountainCar.

    The ``model`` config group must define ``critic`` and ``actor`` sub-configs
    (each with a Hydra ``_target_``). They are handed to ``TorchModelV01``
    through RLlib's ``custom_model_config``.

    Args:
        cfg: Config composed by Hydra from ``config.yaml``. An optional
            ``num_iterations`` key overrides the number of training iterations.

    Returns:
        The ``episode_reward_mean`` reported by the last training iteration,
        or None if it is not reported.
    """
    ModelCatalog.register_custom_model("torch_model", TorchModelV01)

    # RLlib forwards ``custom_model_config`` to the model constructor as
    # keyword arguments and ships it to the rollout workers, so pass plain
    # dicts (resolved Hydra configs) rather than DictConfigs or built modules.
    custom_model_config = {
        net: OmegaConf.to_container(cfg.model[net], resolve=True)
        for net in ("critic", "actor")
    }

    ppo_config = PPOConfig()
    config = (
        ppo_config
        .environment(MountainCar)
        .framework("torch")
        .rollouts(num_rollout_workers=10, num_envs_per_worker=1, batch_mode="complete_episodes")
        .training(
            model={
                "custom_model": "torch_model",
                "custom_model_config": custom_model_config,
            }
        )
    )
    algo = config.build(logger_creator=custom_logger_creator())

    results: dict = {}
    for _ in range(cfg.get("num_iterations", 10000000)):
        results = algo.train()
        checkpoint_dir = algo.save()
        print(f"Checkpoint saved in directory {checkpoint_dir}")
    return results.get("episode_reward_mean")

if __name__ == "__main__":
    # The @hydra.main decorator parses the command line and supplies ``cfg``.
    train()

The torch model is derived from TorchModelV2:

from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import override

torch, nn = try_import_torch()

class TorchModelV01(TorchModelV2, nn.Module):
    """RLlib custom model built from a separate actor network and critic network.

    ``forward`` returns the actor output as the action-distribution inputs and
    caches the critic output, which ``value_function`` returns. Inheriting
    ``nn.Module`` (required by RLlib for torch models) also registers the
    actor and critic as sub-modules, so their parameters get trained.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config, name, critic: "torch.nn.Module | dict", actor: "torch.nn.Module | dict"):
        """Build the model.

        Args:
            obs_space, action_space, num_outputs, model_config, name: Standard
                arguments RLlib passes to every custom ``ModelV2``.
            critic: Value network, or a Hydra config mapping (with a
                ``_target_`` key) that ``hydra.utils.instantiate`` builds.
            actor: Policy network, or a Hydra config mapping that builds one.
                NOTE(review): its output size presumably must equal
                ``num_outputs`` — confirm against the action distribution.
        """
        # Initialise both bases explicitly (RLlib's documented pattern for
        # TorchModelV2 subclasses); nn.Module must be initialised before
        # sub-modules are assigned as attributes.
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        self.critic = self._as_module(critic)
        self.actor = self._as_module(actor)
        # Set by ``forward``; read by ``value_function`` and ``metrics``.
        self.v_values = None
        self.activations_layer_1 = None

    @staticmethod
    def _as_module(spec):
        """Return *spec* if it is already a module, otherwise build it with Hydra."""
        if isinstance(spec, nn.Module):
            return spec
        # Imported lazily so that passing ready-made modules does not need Hydra.
        from hydra.utils import instantiate

        return instantiate(spec)

    def forward(self, input_dict, state, seq_lens):
        """Run the critic and the actor on the observation batch.

        Returns:
            The actor output (named ``mean_std``, presumably the mean and
            log-std of a Gaussian action distribution — confirm) and the
            unchanged ``state``.
        """
        x = input_dict["obs"]
        self.v_values = self.critic(x)
        # NOTE(review): despite its name this is the critic's second parameter
        # tensor (e.g. the first Linear layer's bias), not a layer activation.
        self.activations_layer_1 = list(self.critic.parameters())[1]
        mean_std = self.actor(x)
        return mean_std, state

    def value_function(self):
        """Return the critic output cached by the last ``forward`` call, flattened to 1-D."""
        if self.v_values is None:
            raise ValueError("forward() must be called before value_function()")
        return torch.reshape(self.v_values, [-1])

    def metrics(self):
        """Report the tensor captured in ``forward`` as a custom metric (empty before ``forward``)."""
        if self.activations_layer_1 is None:
            return {}
        # .cpu() so the conversion also works when the model lives on a GPU.
        return {"activations_layer_1": list(self.activations_layer_1.detach().cpu().numpy())}

To avoid clutter, I only show the critic network here; the actor looks similar:

from ray.rllib.utils.framework import try_import_tf, try_import_torch
torch, nn = try_import_torch()

class Critic(nn.Module):
    """Fully connected value network mapping ``num_inputs`` features to ``num_outputs``.

    Three hidden Linear layers of width ``hidden_size`` are each followed by
    ``activation``; the output layer is linear.

    Args:
        hidden_size: Width of each hidden layer.
        num_inputs: Size of the flat input vector.
        num_outputs: Size of the output (1 for a scalar state value).
        activation: Nonlinearity class inserted after every hidden layer.
            Without one, the stacked Linear layers collapse into a single
            affine map.
    """

    def __init__(self, hidden_size: int = 64, num_inputs: int = 2, num_outputs: int = 1, activation: type = nn.Tanh):
        # nn.Module must be initialised before sub-modules can be assigned.
        super().__init__()
        # The constructor arguments are used as given, so sizes supplied by
        # Hydra (or any caller) take effect.
        self.critic = nn.Sequential(
            nn.Linear(num_inputs, hidden_size),
            activation(),
            nn.Linear(hidden_size, hidden_size),
            activation(),
            nn.Linear(hidden_size, hidden_size),
            activation(),
            nn.Linear(hidden_size, num_outputs),
        )

    def forward(self, x):
        """Return the network output for a batch of inputs *x*."""
        return self.critic(x)

The config.yaml file:

  - _self_
  - model: torch_model_v01.yaml

The torch_model_v01.yaml:

# Option of the "model" config group; selected in config.yaml.
_target_: src.models.torch_model.TorchModelV01

# Constructor arguments for TorchModelV01's critic (value) network.
critic:
  _target_: src.models.components.critic.Critic
  hidden_size: 64
  num_inputs: 2
  num_outputs: 1

# Constructor arguments for TorchModelV01's actor (policy) network.
actor:
  # TODO(review): the actor's _target_ line is missing from the post; module path assumed.
  _target_: src.models.components.actor.Actor
  hidden_size: 64
  num_inputs: 2
  # The model names the actor output "mean_std": one mean and one log-std for
  # the 1-D continuous action. Assumes RLlib's default diagonal-Gaussian
  # distribution — confirm.
  num_outputs: 2

I don't know how to proceed, or rather, I don't know the syntax for combining the custom model with the PPO configuration.

I would be very grateful for any hints or advice, because I am currently stuck here.

Kind regards,

Hi @MRMarlies ,

I cannot advise on how to use Hydra with RLlib, as I am not familiar with it.
You can find examples of custom models here.

If you feel adventurous, you can install a nightly version of Ray and try out our RLModule API.