How to reset RNN states at episode end in a torch model?

I have a simple model with a few Fully Connected layers and an LSTM layer in between. The method interfaces are well explained in the documentation (e.g. forward_rnn) and I was able to implement them. However, I wonder how the hidden state of the recurrent model is reset at the end of an episode. Does RLlib do this automatically (and if so where) or do I have to take care of it myself? What are best practices here?


Hey @LukasNothhelfer , RLlib automatically resets the internal state at the beginning of an episode. Note that the internal state is not saved inside the model, but “carried” by the RolloutWorkers and their SampleCollectors. At the beginning of an episode, RLlib uses the initial state defined either by the model via its get_initial_state method, or by the model’s view-requirements dict.
In other words, you should be fine. To confirm, you could print out the state tensors being passed into your forward passes and check that they are all 0.0 (or whatever init value you defined) at the beginning of each episode.

Hello @sven1977 , thanks for the quick feedback. How can I find out from the input_dict passed to the forward method, or from inputs passed to the forward_rnn method, at which index in the batch a new episode was started and thus zeros are to be expected in the corresponding states? I only get the done flag in input_dict and in the inputs for forward_rnn this info is not present at all. Additionally, I would be interested to know where in the source code this mechanism is built in to reset the hidden state at the beginning of an episode. Can you help me here?

BTW, I think a hint in the documentation would be helpful for many, as it would answer a lot of questions, especially for newcomers like me. Tensorforce, for example, gives such hints (see the Layers page of the Tensorforce 0.6.3 documentation, which says: “RNN consequently maintains a temporal internal state over the course of an episode.”).

The function that ultimately resets the state for a new episode is here:

Essentially, when an episode finishes, a new one is created at the code linked below. The new episode has no state entry for any agent yet, since none of them has stepped the environment. RLlib therefore calls get_initial_state on the policy mapped to that agent.
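To make the mechanism described above concrete, here is a minimal toy sketch in plain Python. It is not RLlib code: ToyPolicy, rollout, and the "agent_0" key are hypothetical names made up for illustration. It only shows the idea that the state lives outside the model, and that a fresh episode with no state entry falls back to get_initial_state.

```python
from typing import Dict, List


class ToyPolicy:
    """Hypothetical stand-in for a recurrent policy (not an RLlib class)."""

    def __init__(self) -> None:
        self.initial_state_calls = 0

    def get_initial_state(self) -> List[float]:
        # Mirrors the idea of a model's get_initial_state: all zeros.
        self.initial_state_calls += 1
        return [0.0, 0.0]

    def step(self, obs: float, state: List[float]) -> List[float]:
        # Toy "recurrence": accumulate the observation and a step counter.
        return [state[0] + obs, state[1] + 1.0]


def rollout(policy: ToyPolicy, episode_lengths: List[int]) -> List[List[float]]:
    """Return the state fed to the policy at every step, across episodes."""
    states_in: List[List[float]] = []
    for length in episode_lengths:
        # A new episode object starts with no state entry for any agent.
        episode_state: Dict[str, List[float]] = {}
        for _ in range(length):
            if "agent_0" not in episode_state:
                # No state yet, so fall back to the policy's initial state.
                episode_state["agent_0"] = policy.get_initial_state()
            state = episode_state["agent_0"]
            states_in.append(list(state))
            episode_state["agent_0"] = policy.step(1.0, state)
    return states_in


policy = ToyPolicy()
states = rollout(policy, episode_lengths=[3, 2])
```

In this toy setup the state passed in at the first step of each episode is all zeros, and get_initial_state is called once per episode. The model itself never has to reset anything.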

As for knowing when you are at the beginning of an episode when you are training I do not know a good way to get that info.
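One possible workaround, offered only as a sketch: if the done flags you receive are in time order for a single trajectory, every index that directly follows a done=True marks the first step of a new episode. The helper below (episode_start_indices is a hypothetical name, not an RLlib function) assumes exactly that. It will not be valid if the batch has been shuffled or padded, and it cannot tell whether index 0 starts an episode or continues one from a previous batch.

```python
from typing import List, Sequence


def episode_start_indices(dones: Sequence[bool]) -> List[int]:
    """Indices whose preceding timestep had done=True.

    Assumes `dones` is a flat, time-ordered sequence from one trajectory.
    Index 0 is never reported, because the flags alone do not reveal
    whether the batch begins at an episode boundary.
    """
    starts: List[int] = []
    for i in range(1, len(dones)):
        if dones[i - 1]:
            starts.append(i)
    return starts


# Example: episodes end at indices 2 and 4.
example_dones = [False, False, True, False, True, False]
example_starts = episode_start_indices(example_dones)
```

At the reported indices you would expect the incoming state to equal the initial state (zeros by default), which gives a way to cross-check the reset behaviour.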

Hello @mannyv, thanks for your reply. I recently stumbled across that code snippet as well. What puzzled me is that during my debug sessions the get_initial_state method was only called when the trainer was created and never afterwards. I would expect it to be called more often (namely, whenever a new episode starts). I set breakpoints with the ray debugger, but during the run the method was simply never called. I will look into this and report back here once I can confirm it. Thanks for your help.

Yeah, there was a bug that has been fixed in the nightly builds but has not made it into a release yet.

Details are here:


@mannyv Thanks, that helps a lot