Value function of recurrent state models

Hi, I’ve been working on an RNN-style model with PPO and I’ve run into this error:

   File "/home/xl3942/anaconda3/envs/CommAgent/lib/python3.8/site-packages/ray/rllib/policy/policy_template.py", line 303, in postprocess_trajectory
     return postprocess_fn(self, sample_batch,
   File "/home/xl3942/anaconda3/envs/CommAgent/lib/python3.8/site-packages/ray/rllib/evaluation/postprocessing.py", line 174, in compute_gae_for_sample_batch
     last_r = policy._value(**input_dict)
   File "/home/xl3942/anaconda3/envs/CommAgent/lib/python3.8/site-packages/ray/rllib/agents/ppo/ppo_torch_policy.py", line 220, in value
     model_out, _ = self.model(input_dict)
   File "/home/xl3942/anaconda3/envs/CommAgent/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
     return forward_call(*input, **kwargs)
TypeError: forward() missing 2 required positional arguments: 'hidden_states' and 'seq_lens'

It seems to be caused by the definition of the value() function here: it only passes the observations from a sample batch, not the hidden states. It looks like only the observations are available at that scope, not the hidden states.

I don’t think this is expected behavior. I could fall back to the initial state when hidden states are not provided, but I would prefer to use the actual states for a better value estimate.
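Just to illustrate, the fallback I have in mind would be to make the state arguments optional in my model's forward() and default to the initial state. This is only a sketch of the workaround (I'd still prefer to get the real states passed in):

def forward(self, input_dict, hidden_states=None, seq_lens=None):
    # Sketch of a fallback: if the caller (e.g. the bootstrap value call)
    # does not pass any recurrent state, start from the initial state.
    if not hidden_states:
        hidden_states = self.get_initial_state()
    obs = input_dict['obs']
    logits = torch.zeros(obs.shape[0], self.num_outputs)
    return logits, hidden_states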

@Aceticia,

What is the type of self.model? It should be TorchModelV2, but based on your stack trace it doesn't look like it is.
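For example, you could print the model class and its method resolution order as a quick check (here Agent stands in for your custom model class):

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2

print(Agent.__mro__)
print(issubclass(Agent, TorchModelV2))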

My model is indeed a subclass of TorchModelV2. Are you suggesting that it normally shouldn't have reached this point?

Do you have a reproduction script available?

Here is a minimal reproduction script:

import ray
import gym
from ray import tune

import torch
import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2


class Agent(nn.Module, TorchModelV2):
    def __init__(self, obs_space, action_space, num_outputs,
            model_config, name, **customized_model_kwargs):

        nn.Module.__init__(self)
        TorchModelV2.__init__(
            self,
            name=name,
            obs_space=obs_space,
            action_space=action_space,
            num_outputs=num_outputs,
            model_config=model_config
        )
        self.num_outputs = num_outputs
        self.hiddens = 32

        # Dummy module so that we have parameters to learn.
        self.lin = nn.Linear(5, 5)

    def forward(self, input_dict, hidden_states, seq_lens):
        # Record the batch size so value_function() can return a matching
        # dummy value tensor; return dummy logits and the initial state.
        self._batch_size = input_dict['obs'].shape[0]
        return torch.zeros(self._batch_size, self.num_outputs), self.get_initial_state()

    def value_function(self):
        # Dummy value estimate for the most recent forward pass.
        return torch.zeros(self._batch_size)

    def get_initial_state(self):
        return [torch.zeros(1, self.hiddens)]


if __name__ == '__main__':
    from ray.rllib.models import ModelCatalog

    run_config = {
        "env": "BreakoutNoFrameskip-v4",
        "model": {
            "custom_model": "agent",
        },
        "framework": "torch",
        "num_gpus": 1,
        "num_workers": 0,
    }

    ModelCatalog.register_custom_model("agent", Agent)
    ray.init()
    results = tune.run("PPO", config=run_config, verbose=1)
    ray.shutdown()

@Aceticia

Try swapping the order of inheritance:

class Agent(TorchModelV2, nn.Module):
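If I understand the method resolution order correctly, with nn.Module first, self.model(input_dict) goes through nn.Module.__call__, which calls forward(input_dict) without the state and seq_lens arguments; with TorchModelV2 first, it goes through ModelV2.__call__, which fills those in. Against your repro script only the class header should need to change, roughly:

class Agent(TorchModelV2, nn.Module):  # TorchModelV2 first in the MRO
    def __init__(self, obs_space, action_space, num_outputs,
            model_config, name, **customized_model_kwargs):

        nn.Module.__init__(self)
        TorchModelV2.__init__(
            self,
            name=name,
            obs_space=obs_space,
            action_space=action_space,
            num_outputs=num_outputs,
            model_config=model_config
        )
        self.num_outputs = num_outputs
        self.hiddens = 32
        # Dummy module so that we have parameters to learn.
        self.lin = nn.Linear(5, 5)

    # forward(), value_function() and get_initial_state() stay the same
    # as in your script above.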


Now it works perfectly! Thank you!