Save RNN model's cell and hidden state

Hi all,

If a policy model contains an RNN, or more precisely an LSTM cell, trainer.save() stores the weights of all trainable variables. However, trainer.save() doesn’t store the most recent cell and hidden states of the LSTM cell.
Is there a way to store these tensors (at least the cell state) too? Or does this not make sense? I’m not sure about that. As I understand it, the cell state is a long-term memory, so it might be helpful.

Hi @klausk55,

That is correct: the cell state and hidden state are not stored by trainer.save(). One way to get your states stored is RLlib’s Offline API. Just set the output parameter in your trainer config to a file path, and RLlib writes all sample batches there. In these batches you get state_out_0 and state_out_1, which should be your hidden and cell states.
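A minimal sketch of that setup, assuming an older (Ray 1.x-style) dict config and assuming the JSON writer puts one serialized SampleBatch per line, with the state columns stored as plain nested lists (by default only the columns listed in `output_compress_columns`, i.e. `obs`/`new_obs`, should be compressed). The output path, the `lstm_cell_size` value and the `load_lstm_states` helper are hypothetical; for proper reading, RLlib also ships its own `JsonReader`.

```python
import glob
import json
import os

import numpy as np

# Sketch of a trainer config, showing only the keys relevant here (assumed Ray 1.x style).
config = {
    "model": {"use_lstm": True, "lstm_cell_size": 64},
    # Directory where RLlib writes every collected SampleBatch as JSON lines.
    "output": "/tmp/rllib-lstm-out",
}


def load_lstm_states(output_dir):
    """Hypothetical helper: collect state_out_0 / state_out_1 from JSON output files.

    Assumes one JSON-serialized SampleBatch per line and uncompressed state columns.
    """
    hidden, cell = [], []
    for path in sorted(glob.glob(os.path.join(output_dir, "*.json"))):
        with open(path) as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                batch = json.loads(line)
                if "state_out_0" in batch:
                    # state_out_0 -> hidden state, state_out_1 -> cell state
                    hidden.append(np.asarray(batch["state_out_0"], np.float32))
                    cell.append(np.asarray(batch["state_out_1"], np.float32))
    return hidden, cell
```

Each entry in `hidden`/`cell` is then a `[batch_size, lstm_cell_size]` array. A batch may span several episodes, so select rows via the batch's `eps_id` column if you need the state at a specific point.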

Hope this helps


Hello @Lars_Simon_Zehnder,

That sounds good to me! I’ve already been doing this but hadn’t noticed the states in there :sweat_smile:

What do you think: is it reasonable to load some of the stored internal RNN states when deploying the previously trained policy model? Or what do you use as the initial state when you deploy a trained model?

Hi @klausk55,

This is a great question! As far as I understand it, the hidden state and cell state get reset during evaluation. This makes sense in my opinion, as the hidden state and cell state should at most keep information from the current episode in memory.

A new episode describes a new path through the MDP, so the memory makes sense within an episode but not across episodes. This is an interesting discussion, though. Maybe @mannyv, @arturn and @sven1977 have different opinions or something to add here? I’m curious to hear.
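To make the reset-per-episode idea concrete, here is a minimal NumPy sketch of a frozen LSTM cell rolled over several episodes. It is not RLlib code: `lstm_step`, `run_episodes` and the random weights are hypothetical stand-ins for a trained policy’s LSTM. The only point is that h and c start at zero for each episode and are carried from step to step within it.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def lstm_step(x, h, c, W, U, b):
    """One step of a standard LSTM cell (gates stacked as i, f, g, o)."""
    z = W @ x + U @ h + b           # shape (4 * cell_size,)
    i, f, g, o = np.split(z, 4)
    i, f, o = sigmoid(i), sigmoid(f), sigmoid(o)
    g = np.tanh(g)
    c_new = f * c + i * g           # cell state: gated, longer-term memory
    h_new = o * np.tanh(c_new)      # hidden state: the cell's output
    return h_new, c_new


def run_episodes(episodes, W, U, b, cell_size):
    """Roll a fixed-weight LSTM over several episodes, resetting the state per episode."""
    final_states = []
    for obs_seq in episodes:
        # Fresh zero state: "no information from prior steps".
        h = np.zeros(cell_size, np.float32)
        c = np.zeros(cell_size, np.float32)
        for x in obs_seq:
            # The state is carried across steps within this episode only.
            h, c = lstm_step(x, h, c, W, U, b)
        final_states.append((h, c))
    return final_states


# Usage with random (untrained) weights, purely for illustration.
rng = np.random.default_rng(0)
obs_dim, cell_size = 3, 4
W = rng.normal(size=(4 * cell_size, obs_dim))
U = rng.normal(size=(4 * cell_size, cell_size))
b = np.zeros(4 * cell_size)
episodes = [rng.normal(size=(5, obs_dim)), rng.normal(size=(7, obs_dim))]
states = run_episodes(episodes, W, U, b, cell_size)
```

Because every episode starts from zeros, the second episode’s final state is the same whether or not the first episode ran before it; carrying (h, c) over between episodes would break that independence.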


Hi @klausk55, hi @Lars_Simon_Zehnder,

Without having done any experiments, I would intuitively argue that bringing the hidden state over to the next episode would hinder learning rather than promote it, because the state would encode information that is simply wrong for the new episode. So if your cell has learned to encode some cool information and you carry the hidden state over to the next episode, the gradients produced by training on it will likely work against what your network has previously learned about what is encoded in that hidden state. Instead, the first hidden state should encode only that there is no information from prior steps, for example with a vector of 0s.

Again, this is only my intuition :slight_smile:

Cheers!


Hi @arturn,

Your argument sounds convincing to me.
However, what about the cell state? Does the same reasoning apply to the cell state as to the hidden state? My intuition is that the cell state encodes some information about previous steps (parts of the recent history). So, again, one could argue that this “old information” in a “new situation” is more counterproductive than helpful for a trained model.

I would appreciate hearing some further thoughts on this!


Hi @arturn,

Interesting answer! Thanks for replying here! My intuition was quite similar. I thought that walking a different path through the MDP usually involves very different information due to the randomness. Keeping the state would carry over stale information that is not useful on the next path through the MDP. Your interpretation with the gradients and the initial state encoding that no information exists yet is really nicely laid out.

@klausk55, I hope that helps you make your decision.


Hi @klausk55,

Great question again! In my opinion, the same argument holds here. The cell state usually keeps information a little longer, but it is also fed by information from the current episode and is therefore specific to that episode. My intuition tells me that this information in the cell state will still be stale in another episode, especially if episodes are very long. It would be an interesting experiment, though. I’d argue that the longest memory is actually contained in the weights, which have absorbed the gradient information from training, and that is often quite good. So re-initialising the cell state and hidden state at the start of each episode, while keeping the weights fixed, should give the best results.
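For reference, the standard LSTM cell update (as used in most deep learning libraries; the thread does not spell out the exact parameterization inside RLlib’s default LSTM wrapper) shows both halves of this argument. Here $x_t$ is the observation at step $t$, $\sigma$ the logistic sigmoid and $\odot$ elementwise multiplication:

$$
\begin{aligned}
i_t &= \sigma(W_i x_t + U_i h_{t-1} + b_i) \\
f_t &= \sigma(W_f x_t + U_f h_{t-1} + b_f) \\
o_t &= \sigma(W_o x_t + U_o h_{t-1} + b_o) \\
g_t &= \tanh(W_g x_t + U_g h_{t-1} + b_g) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
$$

When the forget gate $f_t$ stays close to 1, $c_t$ carries information over many steps, which is why the cell state acts as the longer memory. But every step also writes $i_t \odot g_t$, which depends on the current episode’s observations. With $h_0 = c_0 = 0$ at the start of an episode, the first update reduces to $c_1 = i_1 \odot g_1$, which depends only on that episode’s first observation $x_1$.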


Hi @Lars_Simon_Zehnder,

My episodes are really, really long! Currently, I reset my env after 24 hours of simulated time.
Thanks, guys, for the interesting and helpful discussion! It seems that the best choice is to start each episode from zero states:

```python
import numpy as np

def get_initial_state(self):
    # Model method: initial hidden and cell state, all zeros
    return [np.zeros(self.cell_size, np.float32),
            np.zeros(self.cell_size, np.float32)]
```

@klausk55,

Great discussion indeed! Thanks for bringing this up!
