[RLlib] Workaround for incorrect initial state shape with custom RNN models?

Greetings everyone,

Back on June 21, issue 9071 was opened about incorrect initial state shapes when using a custom model with both TensorFlow and Torch. Could the more experienced users here explain how to correctly set the initial state shape so that it matches the batch size?

Thanks for any tips

This is discussed in depth in ray-project/ray/issues/12509, but the problem is still present with both 1.1.0 and the 1.2 nightly. I haven’t been able to reach anyone else about this, so if I find a solution I’ll share it here.

For anyone else struggling: I believe I have it working with a custom model by using LSTMWrapper(RecurrentNetwork) as a template. It will be interesting to find out why the problem still occurs in the use case described in the GitHub issue.
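The batch-size convention behind the original question can be illustrated without RLlib itself. In the RNN model examples shipped with Ray 1.x, `get_initial_state()` returns one array per state tensor (h and c for an LSTM) *without* a batch dimension. As far as I can tell, the 1.x sampler then stacks those per-sample states into a batch. The numpy sketch below mimics that behavior under those assumptions. `CELL_SIZE` and `batch_states` are hypothetical names for illustration and are not RLlib APIs.

```python
import numpy as np

CELL_SIZE = 64  # hypothetical LSTM cell size


def get_initial_state():
    # Per-sample initial state: one array per state tensor (h and c),
    # each of shape [CELL_SIZE], with NO leading batch dimension.
    return [np.zeros(CELL_SIZE, np.float32), np.zeros(CELL_SIZE, np.float32)]


def batch_states(per_sample_states):
    # Hypothetical helper: stack the i-th state tensor of every sample into
    # one [B, ...] array, roughly how a sampler could assemble batched states.
    num_tensors = len(per_sample_states[0])
    return [np.stack([s[i] for s in per_sample_states]) for i in range(num_tensors)]


batch = batch_states([get_initial_state() for _ in range(4)])
print([b.shape for b in batch])  # [(4, 64), (4, 64)]

# Illustrative pitfall: if the initial state already carries a batch dim of 1,
# stacking adds another axis and the shape no longer matches [B, CELL_SIZE].
bad = batch_states([[np.zeros((1, CELL_SIZE), np.float32)] for _ in range(4)])
print(bad[0].shape)  # (4, 1, 64)
```

If a custom model's state comes out as `(B, 1, cell_size)` instead of `(B, cell_size)`, this extra-axis pattern is one plausible cause worth checking. It is not necessarily the cause discussed in the GitHub issues.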