Yet another question on RNN sequencing

Hello,

I am running the following example with PDB:

import argparse
import os
import numpy as np
import gym
from gym.spaces import Box, Discrete, MultiDiscrete
from typing import Dict, List, Union
from gym.envs.classic_control import CartPoleEnv

import ray
from ray import tune
from ray.tune.registry import register_env
from ray.rllib.examples.env.repeat_after_me_env import RepeatAfterMeEnv
from ray.rllib.models import ModelCatalog
from ray.rllib.utils.test_utils import check_learning_achieved
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork as TorchRNN
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch

torch, nn = try_import_torch()

class StatelessCartPole(CartPoleEnv):
    """Partially observable variant of the CartPole gym environment.

    https://github.com/openai/gym/blob/master/gym/envs/classic_control/
    cartpole.py

    We delete the x- and angular velocity components of the state, so that it
    can only be solved by a memory enhanced model (policy).
    """

    def __init__(self, config=None):
        super().__init__()

        # Fix our observation-space (remove 2 velocity components).
        high = np.array(
            [
                self.x_threshold * 2,
                self.theta_threshold_radians * 2,
            ],
            dtype=np.float32)

        self.observation_space = Box(low=-high, high=high, dtype=np.float32)

    def step(self, action):
        next_obs, reward, done, info = super().step(action)
        # next_obs is [x-pos, x-veloc, angle, angle-veloc]
        return np.array([next_obs[0], next_obs[2]]), reward, done, info

    def reset(self):
        init_obs = super().reset()
        # init_obs is [x-pos, x-veloc, angle, angle-veloc]
        return np.array([init_obs[0], init_obs[2]])

class TorchRNNModel(TorchRNN, nn.Module):
    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 fc_size=64,
                 lstm_state_size=256):
        nn.Module.__init__(self)
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)

        self.obs_size = get_preprocessor(obs_space)(obs_space).size
        self.fc_size = fc_size
        self.lstm_state_size = lstm_state_size

        # Build the Module from fc + LSTM + 2xfc (action + value outs).
        self.fc1 = nn.Linear(self.obs_size, self.fc_size)
        self.lstm = nn.LSTM(
            self.fc_size, self.lstm_state_size, batch_first=True)
        self.action_branch = nn.Linear(self.lstm_state_size, num_outputs)
        self.value_branch = nn.Linear(self.lstm_state_size, 1)
        # Holds the current "base" output (before logits layer).
        self._features = None

    @override(ModelV2)
    def get_initial_state(self):
        # TODO: (sven): Get rid of `get_initial_state` once Trajectory
        #  View API is supported across all of RLlib.
        # Place hidden states on same device as model.
        h = [
            self.fc1.weight.new(1, self.lstm_state_size).zero_().squeeze(0),
            self.fc1.weight.new(1, self.lstm_state_size).zero_().squeeze(0)
        ]
        return h

    @override(ModelV2)
    def value_function(self):
        assert self._features is not None, "must call forward() first"
        return torch.reshape(self.value_branch(self._features), [-1])

    @override(TorchRNN)
    def forward_rnn(self, inputs, state, seq_lens):
        """Feeds `inputs` (B x T x ..) through the Gru Unit.

        Returns the resulting outputs as a sequence (B x T x ...).
        Values are stored in self._cur_value in simple (B) shape (where B
        contains both the B and T dims!).

        Returns:
            NN Outputs (B x T x ...) as sequence.
            The state batches as a List of two items (c- and h-states).
        """
        import pdb; pdb.set_trace()
        x = nn.functional.relu(self.fc1(inputs))
        self._features, [h, c] = self.lstm(
            x, [torch.unsqueeze(state[0], 0),
                torch.unsqueeze(state[1], 0)])
        action_out = self.action_branch(self._features)
        return action_out, [torch.squeeze(h, 0), torch.squeeze(c, 0)]

ray.init(local_mode=True)

ModelCatalog.register_custom_model("rnn", TorchRNNModel)
register_env("StatelessPendulum", lambda _: StatelessCartPole())

config = {
    "env": "StatelessPendulum",
    "env_config": {
        "repeat_delay": 2,
    },
    "gamma": 0.9,
    # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
    "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
    "num_workers": 0,
    "num_envs_per_worker": 20,
    "entropy_coeff": 0.001,
    "num_sgd_iter": 5,
    "vf_loss_coeff": 1e-5,
    "simple_optimizer": True,
    "sgd_minibatch_size": 128,
    "model": {
        "custom_model": "rnn",
        "max_seq_len": 20,
        "custom_model_config": {
            "cell_size": 32,
        },
    },
    "framework": "torch",
}

stop = {
    "training_iteration": 10,
    "timesteps_total": 1e6,
    "episode_reward_mean": 500.,
}

results = tune.run('PPO', config=config, stop=stop, verbose=1)
ray.shutdown()

The following is the result of the first cycles.

As you can see, at the same breakpoint I am getting:

torch.Size([32, 1, 2])
torch.Size([1, 1, 2])
torch.Size([4, 8, 2])
torch.Size([20, 1, 2])
torch.Size([20, 1, 2])

So what is the reason for these different sizes?
Also, since _time_major is False by default, I would expect the [batch, timesteps, features] format, which should be [32, 20, 2].
The version I am using is 2.0.0dev, installed today.

@mg64ve

Why are you expecting it to be 32,20,2?

@mannyv I have seen this 32 many times, and I suppose it is a batch size, but I am not sure. If instead the batch size is 1, since _time_major is False by default I would expect [1, 20, 2], given that max_seq_len is 20. Does that make sense?

Hi @mg64ve

There are 3 (sometimes 4) distinct phases that the model is called in.

Your debugging has revealed 2 of them.

These 3 are from the initialization phase. They are taken from a dummy batch of all zeros. I ignore these unless they are generating an error.

torch.Size([32, 1, 2])
torch.Size([1, 1, 2])
torch.Size([4, 8, 2])

These are from the rollout phase, when compute_actions is called to sample new trajectories from the environment. Your config has 20 envs per worker, each of which takes 1 step per call.

torch.Size([20, 1, 2])
torch.Size([20, 1, 2])

After you collect 4000 steps the training phase will run. You have not reported that phase, but when you hit it the input will have shape [num_episodes, max_seq, 2]. max_seq is dynamic by default, so if you did not have an episode that lasted 20 steps it will be shorter than that.
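
If you want to see all three phases without stepping through in pdb, a minimal sketch is to print the batch and time dimensions from inside the forward_rnn you posted (the print line is the only addition):

@override(TorchRNN)
def forward_rnn(self, inputs, state, seq_lens):
    # inputs is [B, T, obs_size]: during rollouts each call carries one
    # timestep per vectorized env (T == 1); during training RLlib passes
    # padded sequences of up to max_seq_len steps (T > 1).
    B, T = inputs.shape[0], inputs.shape[1]
    print(f"forward_rnn: B={B} T={T} seq_lens={seq_lens}")
    x = nn.functional.relu(self.fc1(inputs))
    self._features, [h, c] = self.lstm(
        x, [torch.unsqueeze(state[0], 0), torch.unsqueeze(state[1], 0)])
    action_out = self.action_branch(self._features)
    return action_out, [torch.squeeze(h, 0), torch.squeeze(c, 0)]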

Happy New Year


Thanks for your reply @mannyv, I hope you enjoyed some coffee :slight_smile:
I have a doubt about one of your sentences:

After you collect 4000 steps the training phase will run.

This is the forward_rnn method, and it contains model layers that are processed every time the method is called. So I would expect training to happen each time the method is called. Is my assumption wrong?
Secondly, today I changed the forward_rnn method to:

@override(TorchRNN)
def forward_rnn(self, inputs, state, seq_lens):
    """Feeds `inputs` (B x T x ..) through the Gru Unit.

    Returns the resulting outputs as a sequence (B x T x ...).
    Values are stored in self._cur_value in simple (B) shape (where B
    contains both the B and T dims!).

    Returns:
        NN Outputs (B x T x ...) as sequence.
        The state batches as a List of two items (c- and h-states).
    """
    B = inputs.shape[0]
    x = nn.functional.relu(self.fc1(inputs))
    self._features, [h, c] = self.lstm(
        x, [torch.unsqueeze(state[0], 0),
            torch.unsqueeze(state[1], 0)])
    action_out = self.action_branch(self._features)
    return action_out, [torch.squeeze(h, 0), torch.squeeze(c, 0)]

and I manually started PDB from the command line:


python3 -m pdb custom_rnn_model.py

then I set a conditional breakpoint on line 111, which is the line right after the definition of B:

(Pdb) b 111, B > 3999

PDB never stops; it runs to the end of the program. The same happens if I put:

(Pdb) b 111, B > 100

It seems it does not reach [4000,20,2] as you mentioned. What do you think?

Hi @mg64ve,

The batch size will not be 4000 (that is not a magic number; it is PPO's default value for train_batch_size). Once it has collected those timesteps it does not train on them all at the same time. It will make num_sgd_iter (default: 30) updates, each on a sub-batch of sgd_minibatch_size (default: 128) timesteps.

Your best bet to hit that breakpoint is to add the line: T=inputs.shape[1] and set a breakpoint when T>1.
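
For example, a sketch of that suggestion (only the first few lines of your forward_rnn change; the rest stays as you have it):

def forward_rnn(self, inputs, state, seq_lens):
    B = inputs.shape[0]  # number of (padded) sequences in this batch
    T = inputs.shape[1]  # time dimension: 1 while sampling, >1 while training
    if T > 1:
        # Rollout calls come in with T == 1, so this mostly triggers on the
        # training minibatches (the dummy init batch of shape [4, 8, 2] from
        # your log will also hit it once).
        import pdb; pdb.set_trace()
    # ... rest of the method unchanged (fc1 -> lstm -> action/value branches) ...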

Yes, you are right @mannyv. I have just put in a print and it works; I am attaching the first 400 lines.
The problem is that I want a single LSTM layer, and on top of that the logits and value MLPs.
According to the PyTorch documentation the shape of the inputs for an LSTM should be [batch, timesteps, features], and the LSTM should be defined as LSTM(timesteps, hidden_cells).
Now if I define LSTM(20, 128), this would fail when it gets a [20,1,2] input, because it would expect [something, 20, 2].
How can I solve this problem if I do not want any layer before the LSTM?

input_shape:32x1x2
input_shape:1x1x2
input_shape:4x8x2
input_shape:20x1x2
input_shape:20x1x2
[... input_shape:20x1x2 repeated for the rest of the rollout phase ...]
input_shape:1x1x2
input_shape:1x1x2
[... input_shape:1x1x2 repeated ...]
input_shape:10x20x2
input_shape:9x20x2
input_shape:10x20x2
input_shape:11x20x2
input_shape:7x20x2
input_shape:8x19x2
[... more NxTx2 training batches, with N between 7 and 11 and T of 19 or 20 ...]
input_shape:20x1x2
input_shape:20x1x2
[... input_shape:20x1x2 repeated as the next rollout phase begins ...]

Hi @mg64ve,

The LSTM constructor takes the input feature size and the hidden size. Your feature size is 2, so your LSTM module should be LSTM(2, 128).

Look at the example in the documentation: the input size (10) matches the feature dim of the input (the last dimension of torch.randn(5, 3, 10)), not the time dim.

>>> rnn = nn.LSTM(10, 20, 2)
>>> input = torch.randn(5, 3, 10)
>>> h0 = torch.randn(2, 3, 20)
>>> c0 = torch.randn(2, 3, 20)
>>> output, (hn, cn) = rnn(input, (h0, c0))
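
Applied to your case, a quick standalone check (outside RLlib, with shapes taken from your log; hidden size 128 is just the value you mentioned) would be:

import torch
import torch.nn as nn

# batch_first=True so inputs are [batch, time, features], as in forward_rnn.
lstm = nn.LSTM(input_size=2, hidden_size=128, batch_first=True)

# Rollout-style batch: 20 envs, 1 timestep, 2 features.
out, (h, c) = lstm(torch.randn(20, 1, 2))
print(out.shape)  # torch.Size([20, 1, 128])

# Training-style batch: 10 padded sequences of (up to) 20 timesteps.
out, (h, c) = lstm(torch.randn(10, 20, 2))
print(out.shape)  # torch.Size([10, 20, 128])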