Value function of recurrent state models

Hi, I’ve been working on an RNN-style model with PPO and have run into this error:

   File "/home/xl3942/anaconda3/envs/CommAgent/lib/python3.8/site-packages/ray/rllib/policy/", line 303, in postprocess_trajectory
     return postprocess_fn(self, sample_batch,
   File "/home/xl3942/anaconda3/envs/CommAgent/lib/python3.8/site-packages/ray/rllib/evaluation/", line 174, in compute_gae_for_sample_batch
     last_r = policy._value(**input_dict)
   File "/home/xl3942/anaconda3/envs/CommAgent/lib/python3.8/site-packages/ray/rllib/agents/ppo/", line 220, in value
     model_out, _ = self.model(input_dict)
   File "/home/xl3942/anaconda3/envs/CommAgent/lib/python3.8/site-packages/torch/nn/modules/", line 1051, in _call_impl
     return forward_call(*input, **kwargs)
TypeError: forward() missing 2 required positional arguments: 'hidden_states' and 'seq_lens'

It seems to be caused by the definition of the value() function. It passes only the observations from the sample batch to the model, not the hidden states, and it looks like only the observations are available in that scope.

I don’t think this is expected behavior? I could fall back to the initial state when hidden states are not provided, but I would prefer to use the actual states to get a better value estimate.
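For reference, the fallback mentioned above could look roughly like the sketch below. It is plain Python with a hypothetical stand-in class (RecurrentModelSketch is not an RLlib class, and the returned used_fallback flag exists only for illustration). It shows a forward() that substitutes the initial state when the caller supplies only observations; it does not recover the actual hidden states.

```python
class RecurrentModelSketch:
    """Hypothetical stand-in for a recurrent model; not an RLlib class."""

    def __init__(self, hiddens=32):
        self.hiddens = hiddens

    def get_initial_state(self):
        # One zero vector of size `hiddens`, mirroring a single recurrent state.
        return [[0.0] * self.hiddens]

    def forward(self, input_dict, hidden_states=None, seq_lens=None):
        # Fall back to the initial state when only observations are passed.
        used_fallback = not hidden_states
        if used_fallback:
            hidden_states = self.get_initial_state()
        batch_size = len(input_dict["obs"])
        values = [0.0] * batch_size  # dummy value estimates
        return values, hidden_states, used_fallback


model = RecurrentModelSketch(hiddens=4)
# Called with observations only, as in the failing code path.
values, state, used_fallback = model.forward({"obs": [[1.0], [2.0]]})
```

The drawback is the one already noted: the value estimate is then computed from a reset state rather than from the state the episode actually reached.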


What is the type of self.model? It should be TorchModelV2 but based on your stack trace it does not look like it is.

My model is indeed a subclass of TorchModelV2. Are you suggesting that it normally shouldn’t have reached this code?

Do you have a reproduction script available?

Here is a minimal script:

import ray
import gym
from ray import tune

import torch
import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2

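class Agent(nn.Module, TorchModelV2):
    def __init__(self, obs_space, action_space, num_outputs,
            model_config, name, **customized_model_kwargs):
        # Initialize both base classes before registering any submodules
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
            model_config, name)
        nn.Module.__init__(self)

        hiddens = 32
        self.num_outputs = num_outputs
        self.hiddens = hiddens
        # Dummy module so that we have parameters to learn
        self.lin = nn.Linear(5, 5)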

    def forward(self, input_dict, hidden_states, seq_lens):
        self._feat_size = input_dict['obs'].shape[0]
        return torch.zeros(self._feat_size, self.num_outputs), self.get_initial_state()

    def value_function(self):
        return torch.zeros(self._feat_size)

    def get_initial_state(self):
        return [torch.zeros(1, self.hiddens)]

if __name__ == '__main__':
    from ray.rllib.models import ModelCatalog

    run_config = {
        "env": "BreakoutNoFrameskip-v4",
        "model": {
            "custom_model": "agent",
        },
        "framework": "torch",
        "num_gpus": 1,
        "num_workers": 0,
    }

    ModelCatalog.register_custom_model("agent", Agent)
    results = tune.run("PPO", config=run_config, verbose=1)


Try swapping the order of inheritance, i.e. class Agent(TorchModelV2, nn.Module).
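The reason the order matters, as far as I understand it, is Python's method resolution order: whichever base class is listed first supplies __call__. The stack trace shows the call going through nn.Module's _call_impl, which passes its arguments straight to forward(). The sketch below reproduces the effect in plain Python. PlainModule and StateAwareModel are hypothetical stand-ins for nn.Module and TorchModelV2, and the state-defaulting behavior given to StateAwareModel is an assumption about what the RLlib wrapper does, not a copy of its code.

```python
class PlainModule:
    """Stand-in for nn.Module: __call__ forwards its arguments unchanged."""

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


class StateAwareModel:
    """Stand-in for TorchModelV2: __call__ fills in state and seq_lens."""

    def __call__(self, input_dict, state=None, seq_lens=None):
        return self.forward(input_dict, state or [], seq_lens)


class ModuleFirst(PlainModule, StateAwareModel):
    # PlainModule.__call__ wins, so forward() receives only input_dict.
    def forward(self, input_dict, hidden_states, seq_lens):
        return len(input_dict["obs"]), hidden_states


class WrapperFirst(StateAwareModel, PlainModule):
    # StateAwareModel.__call__ wins and supplies the missing arguments.
    def forward(self, input_dict, hidden_states, seq_lens):
        return len(input_dict["obs"]), hidden_states


batch = {"obs": [0.1, 0.2, 0.3]}

try:
    ModuleFirst()(batch)
    module_first_error = None
except TypeError as exc:
    # Same kind of failure as in the stack trace above.
    module_first_error = str(exc)

wrapper_first_result = WrapperFirst()(batch)
```

With the wrapper class listed first, a call that passes only the input dict still reaches forward() with all three arguments, which is why value() stops raising the TypeError.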



Now it works perfectly! Thank you!