Shape mismatch during training in LSTM PPO model (RLlib new API)

I'm using LSTM + PPO with RLlib's new API stack. I'm running into a problem inside my _compute_embeddings_and_state_outs(self, batch) method. During rollout everything works fine, but when the data is passed to the learner for training, the batch dimension of the LSTM input (built from obs) no longer matches the batch dimension of state_in['h'] and state_in['c'].

From what I understand, RLlib automatically splits episodes into fixed-length sequences (using max_seq_len) and adjusts state_in, actions, rewards, etc. accordingly, but it does not split batch["obs"], which still contains the full, un-split episodes. This mismatch causes shape errors during training.

Here is my _compute_embeddings_and_state_outs() function:

def _compute_embeddings_and_state_outs(self, batch):
    # Split the flat observation back into its state and context parts.
    obs = batch["obs"]["_orig_obs"][..., :self.state_size]
    context = batch["obs"]["_orig_obs"][..., self.state_size:]
    action = batch["obs"]["prev_n_actions"].squeeze(-2)
    reward = batch["obs"]["prev_n_rewards"].squeeze(-1).unsqueeze(-1)

    # Concatenate everything into the LSTM input.
    lstm_input = torch.cat([obs, context, action, reward], dim=-1)

    h = batch["state_in"]["h"].unsqueeze(0)  # [1, B, H]
    c = batch["state_in"]["c"].unsqueeze(0)  # [1, B, H]
    lstm_output, (h, c) = self.lstm(lstm_input, (h, c))

    embeddings = self.hidden(torch.cat([lstm_output, obs], dim=-1))
    return embeddings, {"h": h.squeeze(0), "c": c.squeeze(0)}
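A quick sanity check right before the LSTM call (a hypothetical debugging snippet, not part of the actual module) makes the mismatch obvious:

    # Hypothetical check inside _compute_embeddings_and_state_outs():
    # nn.LSTM(batch_first=True) requires hidden states of shape [num_layers, B, H]
    # where B equals the batch dimension of its input.
    assert lstm_input.shape[0] == batch["state_in"]["h"].shape[0], (
        f"LSTM input batch dim {lstm_input.shape[0]} != "
        f"state_in batch dim {batch['state_in']['h'].shape[0]}"
    )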

For example, I set:

batch_mode = "complete_episodes"
train_batch_size_per_learner = 120
episode_length = 50
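
For context, these settings end up in the config roughly like this (a sketch; the env name is a placeholder and episode_length is an env-specific parameter, not an RLlib option):

    from ray.rllib.algorithms.ppo import PPOConfig

    config = (
        PPOConfig()
        .environment("my_env")  # placeholder
        .env_runners(batch_mode="complete_episodes")
        .training(train_batch_size_per_learner=120)
    )
    # max_seq_len for the LSTM is set in the RLModule's model config;
    # the exact key/location depends on the Ray version.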

For the features in the training batch:

Key: obs
This is a nested dict:
    _orig_obs: type=<class 'torch.Tensor'>, shape=torch.Size([3, 69, 13])
    prev_n_actions: type=<class 'torch.Tensor'>, shape=torch.Size([3, 69, 1, 6])
    prev_n_rewards: type=<class 'torch.Tensor'>, shape=torch.Size([3, 69, 1])

Key: state_in
This is a nested dict:
    h: type=<class 'torch.Tensor'>, shape=torch.Size([9, 64])
    c: type=<class 'torch.Tensor'>, shape=torch.Size([9, 64])

and the error is:

    raise RuntimeError(msg.format(expected_hidden_size, list(hx.size())))
    RuntimeError: Expected hidden[0] size (1, 3, 64), got [1, 9, 64]
    Error: tcpip::Socket::recvAndCheck @ recv: peer shutdown49 ACT 1 BUF 0)
    Quitting (on error).
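
The mismatch can be reproduced outside of RLlib with just the shapes above (a minimal standalone sketch; the feature size 13 stands in for my actual concatenated LSTM input):

    import torch
    import torch.nn as nn

    lstm = nn.LSTM(input_size=13, hidden_size=64, batch_first=True)
    obs = torch.zeros(3, 69, 13)   # 3 un-split episodes, as in batch["obs"]
    h0 = torch.zeros(1, 9, 64)     # 9 per-sequence initial states, as in state_in
    c0 = torch.zeros(1, 9, 64)

    lstm(obs, (h0, c0))
    # RuntimeError: Expected hidden[0] size (1, 3, 64), got [1, 9, 64]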

How can I solve this problem?

Update:

I set up the env-to-module connector:

def _env_to_module(env):
    # Create the env-to-module connector pipeline.
    return PrevActionsPrevRewards(
        multi_agent=True,
        n_prev_rewards=1,
        n_prev_actions=1,
    )

When using this connector, the batch["obs"] passed into the RLModule is no longer a simple tensor but a nested dict containing the original observation plus the previous actions and rewards. For example:

batch["obs"] = {
    "_orig_obs": tensor(...),
    "prev_n_actions": tensor(...),
    "prev_n_rewards": tensor(...),
}

Meanwhile, other keys in the batch like actions, rewards, state_in, etc., are already split into fixed-length sequences using max_seq_len.

However, batch["obs"] is not split accordingly, which causes mismatch errors like:

    RuntimeError: Expected hidden[0] size (1, 10, 64), got [1, 3, 64]

How should I handle the sequence splitting of batch["obs"] manually, or is there an official way to make PrevActionsPrevRewards produce obs that can be aligned and split together with the rest of the batch?

It turns out I needed to add FlattenObservations(multi_agent=multi_agent) to the connector pipeline:

def _env_to_module(env, multi_agent=True, n_prev_rewards=1, n_prev_actions=1):
    return [
        PrevActionsPrevRewards(
            multi_agent=multi_agent,
            n_prev_rewards=n_prev_rewards,
            n_prev_actions=n_prev_actions,
        ),
        FlattenObservations(multi_agent=multi_agent),
    ]
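
For completeness, this is roughly how the pipeline is wired into the config (import paths and the exact connector-callback signature may differ slightly between Ray versions). My understanding is that FlattenObservations collapses the nested obs dict back into a single flat tensor, so the learner splits it into max_seq_len chunks together with the other batch columns:

    from ray.rllib.algorithms.ppo import PPOConfig
    from ray.rllib.connectors.env_to_module import (
        FlattenObservations,
        PrevActionsPrevRewards,
    )

    config = (
        PPOConfig()
        .env_runners(env_to_module_connector=_env_to_module)
    )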