'use_lstm' with centralized critic for PPO

How severely does this issue affect your experience of using Ray?

  • High: It blocks me from completing my task.

Hello, I am implementing a centralized value function for PPO, following the TwoStepGame example. I have managed to get training working without an LSTM, but as soon as I set `use_lstm`, I get a size-mismatch error:

(pid=14475)   File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
(pid=14475)     return forward_call(*input, **kwargs)
(pid=14475)   File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/rnn.py", line 677, in forward
(pid=14475)     self.check_forward_args(input, hx, batch_sizes)
(pid=14475)   File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/rnn.py", line 620, in check_forward_args
(pid=14475)     self.check_input(input, batch_sizes)
(pid=14475)   File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/rnn.py", line 207, in check_input
(pid=14475)     self.input_size, input.size(-1)))
(pid=14475) RuntimeError: input.size(-1) must be equal to input_size. Expected 63, got 276
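
For reference, this is roughly the config I am using; the env name, worker settings, and model sizes are placeholders for my actual setup (my model class is sketched at the end of this post), and removing the `use_lstm` line makes training run fine:

from ray.rllib.models import ModelCatalog

# Register the centralized-critic model, as in the example.
ModelCatalog.register_custom_model("cc_model", TorchCentralizedCriticModel)

config = {
    "env": "my_multiagent_env",      # placeholder for my actual environment
    "framework": "torch",
    "model": {
        "custom_model": "cc_model",
        "fcnet_hiddens": [64, 64],   # placeholder sizes
        "use_lstm": True,            # enabling this triggers the error above
    },
    # multiagent policy mapping etc. omitted
}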

If we override the PPO policy as in the example, i.e. if we have:

CCPPOTorchPolicy = PPOTorchPolicy.with_updates(
    name="CCPPOTorchPolicy",
    postprocess_fn=centralized_critic_postprocessing,
    loss_fn=loss_with_central_critic,
    before_init=setup_torch_mixins,
    mixins=[
        TorchLR, TorchEntropyCoeffSchedule, TorchKLCoeffMixin,
        CentralizedValueMixin
    ])
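
and the trainer is then built from this policy exactly as in the example; roughly (I only run the torch branch, so the TF policy is omitted):

from ray.rllib.agents.ppo import PPOTrainer

def get_policy_class(config):
    # I only use framework="torch", so always return the torch policy.
    return CCPPOTorchPolicy

CCTrainer = PPOTrainer.with_updates(
    name="CCPPOTrainer",
    default_policy=CCPPOTorchPolicy,
    get_policy_class=get_policy_class,
)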

Should we still expect `use_lstm` to work out of the box here, or would we need to accommodate the wrapped model ourselves? I have looked into the LSTM wrapper's code, but I am lost, since the model itself seems to be fine: training proceeds as expected without `use_lstm`.
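
For completeness, my model is essentially the `TorchCentralizedCriticModel` from the example, just with my own observation/action sizes; a rough sketch (I assume a Box observation space and Discrete actions here, and the sizes are placeholders):

import torch
import torch.nn as nn

from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2


class TorchCentralizedCriticModel(TorchModelV2, nn.Module):
    """Policy net plus a central value function over both agents' obs and actions."""

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        # Per-agent base model (this is the part that 'use_lstm' would wrap).
        self.model = TorchFC(obs_space, action_space, num_outputs, model_config,
                             name)

        # Central VF: own obs + opponent obs + one-hot opponent action -> value.
        central_vf_input = 2 * obs_space.shape[0] + action_space.n
        self.central_vf = nn.Sequential(
            SlimFC(central_vf_input, 16, activation_fn=nn.Tanh),
            SlimFC(16, 1),
        )

    def forward(self, input_dict, state, seq_lens):
        model_out, _ = self.model(input_dict, state, seq_lens)
        return model_out, []

    def central_value_function(self, obs, opponent_obs, opponent_actions):
        inputs = torch.cat([
            obs,
            opponent_obs,
            nn.functional.one_hot(opponent_actions.long(),
                                  self.action_space.n).float(),
        ], dim=1)
        return torch.reshape(self.central_vf(inputs), [-1])

    def value_function(self):
        # Not used; the centralized value function above is called via the mixin.
        return self.model.value_function()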