### Search before asking
- [X] I searched the [issues](https://github.com/ray-project/ray/issues) and found no similar issues.
### Ray Component
RLlib
### What happened + What you expected to happen
Given two sequences `A` and `B`, I would expect the flat batch
`[A, A, A, B, B]; seq_lens=[3, 2], obs.shape = [5, 1]`
would be padded to
`[A, A, A, B, B, *]; seq_lens=[3, 2], obs.shape = [2, 3, 1]`
This does not appear to be the case. For some reason RLlib zero-pads `obs` to a length other than `seq_lens.max()`. Even more worrisome, checking `input_dict` for zeros (e.g. with `torch.nonzero()`) shows that the observations are front-padded with zeros. For example, printing `input_dict['obs'].reshape(B, T, -1) == 0` results in:
```
(PPO pid=74137) [[ True, True],
(PPO pid=74137) [ True, True],
(PPO pid=74137) [ True, True],
(PPO pid=74137) [ True, True],
(PPO pid=74137) [ True, True],
(PPO pid=74137) [False, False],
(PPO pid=74137) [False, False],
(PPO pid=74137) [False, False],
(PPO pid=74137) [False, False],
(PPO pid=74137) [False, False],
(PPO pid=74137) [False, False],
(PPO pid=74137) [False, False],
(PPO pid=74137) [False, False],
(PPO pid=74137) [False, False],
(PPO pid=74137) [False, False],
(PPO pid=74137) [False, False],
(PPO pid=74137) [False, False],
(PPO pid=74137) [False, False],
(PPO pid=74137) [False, False],
(PPO pid=74137) [False, False],
(PPO pid=74137) [False, False],
(PPO pid=74137) [False, False],
(PPO pid=74137) [False, False],
(PPO pid=74137) [False, False],
(PPO pid=74137) [False, False],
(PPO pid=74137) [False, False],
(PPO pid=74137) [ True, True],
(PPO pid=74137) [ True, True],
(PPO pid=74137) [ True, True]],
```
The zero-padding is clearly wrong: the first five observations are zero padding, and the real observations that follow are offset by five.
### Versions / Dependencies
Linux Ray 1.7.0
### Reproduction script
Feel free to play with the `USE_CORRECT_SHAPE` flag
```python
import torch
import numpy as np
import gym
from typing import Union, Dict, List, Tuple, Any
import ray
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.typing import ModelConfigDict, TensorType
from ray.rllib.policy.rnn_sequencing import add_time_dimension
from ray.tune import register_env
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
# Selects how TestRNN.forward() derives the time dimension passed to
# add_time_dimension():
#   True  -> use seq_lens.max(), i.e. pad to the size the issue author
#            considers correct (reported by the author to crash)
#   False -> follow the rnn_sequencing / RLlib RNN code and use
#            rows // number_of_sequences (reported by the author not to crash)
USE_CORRECT_SHAPE = False
class TestRNN(TorchModelV2, torch.nn.Module):
    """Minimal custom torch model used to inspect how RLlib pads sequences.

    The model has no recurrent layer: it applies one linear policy head and
    one linear value head to every (padded) timestep. Its purpose is to
    compare the temporal dimension of the padded batch against
    ``seq_lens.max()`` and fail loudly when they disagree.
    """

    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str,
        **custom_model_kwargs,
    ):
        """Set up both parent classes and create the two linear heads.

        Extra ``custom_model_kwargs`` are accepted but not used.
        """
        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name
        )
        torch.nn.Module.__init__(self)

        self.num_outputs = num_outputs
        self.act_space = action_space
        # Flattened sizes of the observation and action spaces.
        self.input_dim = gym.spaces.utils.flatdim(obs_space)
        self.act_dim = gym.spaces.utils.flatdim(action_space)
        # Value estimates from the most recent forward() call.
        self.cur_val = None

        self.policy = torch.nn.Linear(self.input_dim, self.act_dim)
        self.vf = torch.nn.Linear(self.input_dim, 1)

    def get_initial_state(self):
        """Return the initial state: a list holding one empty tensor."""
        return [torch.zeros(0)]

    def value_function(self):
        """Return the value output cached by the last forward() call."""
        assert self.cur_val is not None, "must call forward() first"
        return self.cur_val

    def forward(
        self,
        input_dict: Dict[str, TensorType],
        state: List[TensorType],
        seq_lens: TensorType,
    ) -> Tuple[TensorType, List[TensorType]]:
        """Add a time dimension to the flat observations and run both heads.

        Raises:
            Exception: if the temporal dimension of the padded batch differs
                from ``seq_lens.max()``.

        Returns:
            The policy logits for every padded timestep and the unchanged
            ``state``.
        """
        obs = input_dict["obs_flat"]

        # With USE_CORRECT_SHAPE the time dimension is the longest sequence.
        # Otherwise it is derived the way the RLlib RNN code does it, see
        # https://github.com/ray-project/ray/blob/2d24ef0d3234867ac329b10ae3a11b9b7119d17b/rllib/models/torch/recurrent_net.py#L75
        # (the issue author argues that this should be seq_lens.max()).
        time_dim = (
            seq_lens.max()
            if USE_CORRECT_SHAPE
            else obs.shape[0] // seq_lens.shape[0]
        )
        with_time = add_time_dimension(
            obs,
            max_seq_len=time_dim,
            framework="torch",
            time_major=False,
        )
        B, T = with_time.shape[0], with_time.shape[1]

        # A mismatch means the batch carries "extra" padding: the time
        # dimension should not need to be longer than the longest sequence.
        longest = seq_lens.max()
        if longest != T:
            print(f'seq_lens.max() is {longest} but input temporal dim is {T}')
            print(obs.reshape(B, T, -1) == 0)
            raise Exception('seq_len mismatch')

        # Collapse batch and time again so the linear heads see 2-D input.
        per_step = with_time.reshape(-1, with_time.shape[-1])
        logits = self.policy(per_step)
        self.cur_val = self.vf(per_step).squeeze(1)
        return logits, state
# Register the environment under its class name. The config below passes
# the class itself as "env".
register_env(StatelessCartPole.__name__, StatelessCartPole)

# Used both as the model's max_seq_len and as the episode horizon.
MAX_SEQ_LEN = 200

_MODEL_CFG = {
    "custom_model": TestRNN,
    "max_seq_len": MAX_SEQ_LEN,
}

CFG = {
    "env_config": {},
    "framework": "torch",
    "model": _MODEL_CFG,
    "num_workers": 0,
    "num_gpus": 0,
    "env": StatelessCartPole,
    "horizon": MAX_SEQ_LEN,
}

ray.init(object_store_memory=3e10)

# Run PPO with the custom model.
analysis = ray.tune.run(PPOTrainer, config=CFG)
```
### Anything else
This happens on every train step.
### Are you willing to submit a PR?
- [ ] Yes I am willing to submit a PR!