Centralized Critic with separate-layers LSTM to access hidden states in `post_process_trajectories`

This post is related to the previous discussion linked here:
I don’t feel like I should continue to post replies on that thread because it is slightly off-topic from the original post.

In short, I found that sharing parameters between the actor and value networks degrades the model’s performance (at least in the project I have been working on). Using LSTM models is a hard requirement, so I need separate LSTM branches for the actor and the value network.

Currently, I work around the hidden states as @mannyv suggested in the previous post: I forward the value network along with the actor network. This has worked nicely. Eventually, though, I need to move on to MAPPO, which requires a centralized critic. In RLlib’s provided example, the value predictions are made in the post_process_trajectories step instead of during forward calls. That global critic consists of fully connected layers and does not have to deal with the “flow” of hidden states, whereas in my separate-branch LSTM implementation I let forward_rnn handle the “flow” of hidden states automatically.

I was wondering whether there is a way to access the agents’ hidden states in the SampleBatches during the post_process_trajectories step, so that I can forward my value networks after each rollout instead of during it.

I notice that the view_requirements have been set up automatically for my value networks, and I see the keys state_out_2 and state_out_3 in the SampleBatches. But I am not sure how to use them correctly. For example, should I use state_out_2[i-1] to compute value_function(obs[i], state_out_2[i-1])?

Hope the wording is not too confusing. Deeply appreciate any help.

At time i, state_in_0[i] is passed to forward to compute actions at time i. Forward then returns state_out_0[i], which will be identical to state_in_0[i+1] if the episode isn’t done.

So to compute the value at time i, you probably want to use state_in_0[i].
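To make the indexing concrete, here is a minimal sketch in plain NumPy. It does not use RLlib: `toy_value_cell` and its weights are hypothetical stand-ins for one step of a recurrent value branch, and the `state_in` / `state_out` arrays only mimic how the per-timestep states are laid out. It checks that values recomputed after the rollout from `obs[i]` and `state_in[i]` match the ones computed during the rollout, and that `state_in[i+1]` equals `state_out[i]` within an episode.

```python
import numpy as np


def toy_value_cell(obs, h, W_x, W_h, w_v):
    """Hypothetical stand-in for one step of a recurrent value network."""
    h_next = np.tanh(obs @ W_x + h @ W_h)  # new hidden state
    value = float(h_next @ w_v)            # scalar value prediction
    return value, h_next


rng = np.random.default_rng(0)
T, obs_dim, hidden_dim = 5, 3, 4
W_x = rng.normal(size=(obs_dim, hidden_dim))
W_h = rng.normal(size=(hidden_dim, hidden_dim))
w_v = rng.normal(size=hidden_dim)
obs = rng.normal(size=(T, obs_dim))

# "Rollout": record the state fed in and the state returned at every step.
state_in = np.zeros((T, hidden_dim))
state_out = np.zeros((T, hidden_dim))
rollout_values = np.zeros(T)
h = np.zeros(hidden_dim)  # initial state at the start of the episode
for i in range(T):
    state_in[i] = h
    rollout_values[i], h = toy_value_cell(obs[i], h, W_x, W_h, w_v)
    state_out[i] = h

# "Postprocessing": recompute each value from the stored obs[i] and state_in[i].
post_values = np.array(
    [toy_value_cell(obs[i], state_in[i], W_x, W_h, w_v)[0] for i in range(T)]
)

print(np.allclose(post_values, rollout_values))
print(np.allclose(state_in[1:], state_out[:-1]))
```

Under these assumptions, using `state_out[i-1]` gives the same result as `state_in[i]` for every step except the first one of an episode, where there is no previous output and the initial state has to be used instead. Indexing `state_in[i]` directly avoids that special case.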