I'm using LSTM + PPO in RLlib's new API stack, and I'm running into a problem inside _compute_embeddings_and_state_outs(self, batch). During rollout, everything works fine. But when the data is passed to the learner for training, the batch size of the LSTM input (obs) no longer matches that of state_in['h'] and state_in['c'].
From what I understand, RLlib automatically splits episodes into fixed-length sequences (e.g., using max_seq_len) and adjusts state_in, actions, rewards, etc., accordingly, but it does not split batch["obs"], which still contains the full uncut episode. This mismatch causes shape errors during training.
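As a quick sanity check on the splitting described above, here is a minimal sketch of how many sequences the chunking would produce. The max_seq_len of 20 is an assumption (the post does not state the configured value), and count_sequences is a hypothetical helper, not an RLlib function. With three 50-step episodes it gives 9, which would be consistent with the batch dimension of state_in reported further down.

```python
import math


def count_sequences(episode_lengths, max_seq_len):
    """Number of chunks when each episode is cut into pieces of at most max_seq_len steps."""
    return sum(math.ceil(length / max_seq_len) for length in episode_lengths)


# Assumption: three complete 50-step episodes fill the train batch,
# and max_seq_len is 20 (not stated in the post).
episode_lengths = [50, 50, 50]
max_seq_len = 20

num_sequences = count_sequences(episode_lengths, max_seq_len)
print(num_sequences)  # 9
```

If the real max_seq_len differs, the count changes accordingly, so this only suggests where the 9 might come from.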
Here is my _compute_embeddings_and_state_outs() function:
    def _compute_embeddings_and_state_outs(self, batch):
        # Split the original observation into state and context features.
        obs = batch["obs"]["_orig_obs"][..., :self.state_size]
        context = batch["obs"]["_orig_obs"][..., self.state_size:]
        # Previous actions: [B, T, 1, A] -> [B, T, A]; rewards keep a trailing feature dim.
        action = batch["obs"]["prev_n_actions"].squeeze(-2)
        reward = batch["obs"]["prev_n_rewards"].squeeze(-1).unsqueeze(-1)
        lstm_input = torch.cat([obs, context, action, reward], dim=-1)
        # nn.LSTM expects the initial states shaped [num_layers, B, H].
        h = batch["state_in"]["h"].unsqueeze(0)  # [1, B, H]
        c = batch["state_in"]["c"].unsqueeze(0)  # [1, B, H]
        lstm_output, (h, c) = self.lstm(lstm_input, (h, c))
        embeddings = self.hidden(torch.cat([lstm_output, obs], dim=-1))
        return embeddings, {'h': h.squeeze(0), 'c': c.squeeze(0)}
For example, I set:
    batch_mode = "complete_episodes"
    train_batch_size_per_learner = 120
    episode_length = 50
The features in batch then look like this:
    Key: obs
    This is a nested dict:
        _orig_obs: type=<class 'torch.Tensor'>, shape=torch.Size([3, 69, 13])
        prev_n_actions: type=<class 'torch.Tensor'>, shape=torch.Size([3, 69, 1, 6])
        prev_n_rewards: type=<class 'torch.Tensor'>, shape=torch.Size([3, 69, 1])
    Key: state_in
    This is a nested dict:
        h: type=<class 'torch.Tensor'>, shape=torch.Size([9, 64])
        c: type=<class 'torch.Tensor'>, shape=torch.Size([9, 64])
Training then fails with this error:
raise RuntimeError(msg.format(expected_hidden_size, list(hx.size())))
RuntimeError: Expected hidden[0] size (1, 3, 64), got [1, 9, 64]
Error: tcpip::Socket::recvAndCheck @ recv: peer shutdown49 ACT 1 BUF 0)
Quitting (on error).
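The shape error can be reproduced outside RLlib with plain PyTorch, which may help when testing a fix. This is a minimal sketch, not the actual module: it assumes the LSTM was built with batch_first=True and an input size of 20 (13 observation features + 6 action features + 1 reward, taken from the shapes listed above), and run_lstm is a hypothetical helper.

```python
import torch
import torch.nn as nn

# Assumed layer config: batch_first=True so inputs are [B, T, F]; hidden size 64.
lstm = nn.LSTM(input_size=20, hidden_size=64, batch_first=True)


def run_lstm(x, h0, c0):
    """Run the LSTM and return the error message, or None if the shapes agree."""
    try:
        lstm(x, (h0, c0))
        return None
    except RuntimeError as err:
        return str(err)


# Shapes from the learner batch: input batch dim 3, state batch dim 9.
lstm_input = torch.zeros(3, 69, 20)
mismatched = run_lstm(lstm_input, torch.zeros(1, 9, 64), torch.zeros(1, 9, 64))
print(mismatched)

# With matching batch dims the forward pass goes through.
matched = run_lstm(lstm_input, torch.zeros(1, 3, 64), torch.zeros(1, 3, 64))
print(matched)
```

The first call fails because the input's leading dimension (3) and the state's batch dimension (9) disagree; whichever side RLlib is expected to reshape, both must share the same B before they reach self.lstm.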
How can I solve this problem?