[rllib] SampleBatch "state_in_0" dimension shorter than expected

Hi guys,

I am using a kind of Bayesian filter (a VAE, to be exact) instead of an RNN, but I would like to use the “state” variable in the RNN interface to store the belief state, and to access the entire history of (prior) belief states directly for training.

I thought the “state_in_0” column in SampleBatch was meant exactly for this purpose, so I added it to my view requirements as follows:

```python
self.view_requirements["state_in_0"] = ViewRequirement(
```

and I initialized a dummy initial state with zeros and replaced it with a belief state in the action sampler function:

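```python
import torch

def get_initial_state(self):
    # RLlib expects a list of initial state tensors; dummy_dim is user-defined.
    return [torch.zeros(1, dummy_dim)]

def action_sampler_fn(policy, model, input_dict, state, timestep):
    # Assumption: `timestep` may be the policy's global timestep, not the step within the episode.
    if timestep == 0:
        state = make_actual_belief_state()
    action, logp = model(input_dict, state)
    return action, logp, state
```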

However, when I inspect the sample batch while computing the losses, the length of state_in_0 equals the number of rollout episodes, while the length of state_out_0 equals the number of timesteps, which is what I wanted.

Because this is very confusing, I would like to get some clarity on the stored variables, when they are stored, and what they are used for.


Your state cannot change shape. It must always be (B, 1, dummy_dim), the shape from get_initial_state.

My expected shape is (episode_timesteps, 1, dummy_dim), but, as you said, I get (num_rollout, 1, dummy_dim).

What is your max_seq_len? Before the data is passed into forward_rnn, two things happen. 1. The data in the sample batch is padded so that all sequences are as long as your longest episode. So if you had sequences of length [5, 12, 10, 4], they would all be padded with zeros to 12 steps, and you would end up with a total of 12 * 4 = 48 timesteps in your sample batch.

2. The episodes are also chopped into sequences no longer than max_seq_len. Let's say max_seq_len in your model config was 20. Then your state_in would be of size [4, cell_size], since every episode fits into a single sequence and, from the RNN's perspective, there is no need to truncate backprop through time. This is likely what you are seeing. When you are using an RNN, you only get the initial state of each sequence; the other states are generated internally by the RNN logic. If your max_seq_len were 5 in the example above, you would likely have seq_lens of [5, 5, 5, 2, 5, 5, 4], a sample batch padded to 7 * 5 timesteps, and a state_in with a shape of [7, cell_size].
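To check the arithmetic in the example above, here is a plain-Python sketch. `chop_into_chunks` is a hypothetical helper written for illustration, not RLlib's implementation (RLlib does this internally); it assumes each chunk gets exactly one initial state and that chunks are padded to the longest chunk:

```python
def chop_into_chunks(episode_lens, max_seq_len):
    """Split each episode into chunks of at most max_seq_len steps.

    Returns the chunk lengths (seq_lens), the padded batch size and the
    number of state_in rows (one initial state per chunk).
    """
    seq_lens = []
    for ep_len in episode_lens:
        full, rest = divmod(ep_len, max_seq_len)
        seq_lens.extend([max_seq_len] * full)
        if rest:
            seq_lens.append(rest)
    padded_len = max(seq_lens)  # every chunk is padded to the longest chunk
    return seq_lens, len(seq_lens) * padded_len, len(seq_lens)


for max_seq_len in (20, 5):
    seq_lens, batch_size, num_state_in = chop_into_chunks([5, 12, 10, 4], max_seq_len)
    print(max_seq_len, seq_lens, batch_size, num_state_in)
# 20 [5, 12, 10, 4] 48 4
# 5 [5, 5, 5, 2, 5, 5, 4] 35 7
```

The number of state_in rows tracks the number of chunks, not the number of timesteps, which is why a long max_seq_len gives one state_in per episode.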

The other thing to keep in mind is that there are several passes through the model with dummy data before training starts, to calculate the view requirements and other values needed by compute_actions and the loss functions. I have found that those passes sometimes use shapes that I never see during actual training.


Hi, thanks for the clarification. I think I've got the hang of it now. So state_in only stores the initial hidden cell, and the hidden cells generated internally by the RNN are stored in state_out, is that correct? And when episodes are of different lengths, they will be padded to equal length by the time they appear in my compute_loss_fn().

Yes. Just a few points of clarification.

During the collection of new episodes, each time compute_actions is called, state_in holds the state_out from the previous call to compute_actions. If this is the first step of a new episode, it will be whatever the policy returns from get_initial_state. Inside compute_actions, the model produces logits (not the actual actions; those are determined by the action distribution, which is applied separately to these values) together with the resulting state, which is returned as state_out. State is chained this way from timestep to timestep throughout an episode.
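A toy sketch of this chaining in plain Python. `get_initial_state` and `toy_forward` are hypothetical stand-ins for the policy's initial state and the model's forward pass, the "belief" is just a running sum, and picking the larger logit stands in for sampling from the action distribution:

```python
def get_initial_state():
    # Hypothetical stand-in for the policy's get_initial_state().
    return [0.0]


def toy_forward(obs, state):
    """Hypothetical model step: returns (logits, state_out)."""
    new_state = [state[0] + obs]  # running sum as a toy "belief"
    logits = [new_state[0], -new_state[0]]
    return logits, new_state


def run_episode(observations):
    state = get_initial_state()  # state_in for t == 0
    trajectory = []
    for t, obs in enumerate(observations):
        logits, state_out = toy_forward(obs, state)
        # Stand-in for sampling from the action distribution built on logits.
        action = 0 if logits[0] >= logits[1] else 1
        trajectory.append(
            {"t": t, "action": action, "state_in": state, "state_out": state_out}
        )
        state = state_out  # becomes the next step's state_in
    return trajectory


traj = run_episode([1.0, 2.0, 3.0])
for step in traj:
    print(step["t"], step["state_in"], step["state_out"])
# 0 [0.0] [1.0]
# 1 [1.0] [3.0]
# 2 [3.0] [6.0]
```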

Initially, all of the state_in and state_out values are saved, but when a sample batch is constructed for the loss function, usually only the state_in values at timesteps where t % max_seq_len == 0 (the first step of each sequence) are kept, and the states between them are discarded to save space.
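The filtering of state_in can be sketched in plain Python. `keep_chunk_start_states` is a hypothetical helper, not RLlib's code; it assumes the batch is ordered episode by episode and keeps a state only at the first step of each max_seq_len chunk:

```python
def keep_chunk_start_states(state_in, episode_ids, max_seq_len):
    """Keep state_in only where t % max_seq_len == 0 within each episode."""
    kept = []
    t = 0
    prev_episode = None
    for state, episode in zip(state_in, episode_ids):
        if episode != prev_episode:
            t = 0  # a new episode starts, reset the per-episode timestep
            prev_episode = episode
        if t % max_seq_len == 0:
            kept.append(state)
        t += 1
    return kept


# Two episodes of length 7 and 3; labels show where each state came from.
episode_ids = [0] * 7 + [1] * 3
state_in = [f"ep{e}_t{t}" for e, n in ((0, 7), (1, 3)) for t in range(n)]
print(keep_chunk_start_states(state_in, episode_ids, max_seq_len=5))
# ['ep0_t0', 'ep0_t5', 'ep1_t0']
```

Ten timesteps shrink to three initial states, one per chunk (seq_lens [5, 2, 3]), which matches the short state_in_0 seen in the original question.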

The other thing that happens, as you correctly noted, is that all of the sequences are padded to the same length as the longest sequence in the sample batch.
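A minimal sketch of the zero-padding step, using the [5, 12, 10, 4] example from earlier in the thread. `pad_sequences` is a hypothetical helper, and the batch is assumed to be flattened sequence by sequence:

```python
def pad_sequences(sequences, pad_value=0.0):
    """Zero-pad every sequence to the length of the longest one.

    Returns the flattened padded batch and the original sequence lengths.
    """
    seq_lens = [len(s) for s in sequences]
    max_len = max(seq_lens)
    padded = []
    for seq in sequences:
        # Append pad_value until this sequence is max_len steps long.
        padded.extend(list(seq) + [pad_value] * (max_len - len(seq)))
    return padded, seq_lens


# Four sequences of length 5, 12, 10 and 4 (dummy per-step values of 1.0).
episodes = [[1.0] * n for n in (5, 12, 10, 4)]
batch, seq_lens = pad_sequences(episodes)
print(len(batch), seq_lens)  # 48 [5, 12, 10, 4]
```

Keeping seq_lens alongside the padded batch is what later allows the padded steps to be told apart from real ones.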

In the loss functions of algorithms that support RNNs, there is a step after the per-timestep loss is computed that zeros out the padded timesteps so that they do not contribute to the loss or to the gradients computed in the backward pass. Here is an example from PPO.
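The masking idea can be illustrated in plain Python. This is not the PPO code; `masked_mean_loss` is a hypothetical helper that assumes a batch-major layout (each sequence's padded steps are contiguous) and averages only over the real timesteps:

```python
def masked_mean_loss(per_step_loss, seq_lens, padded_len):
    """Average a flattened [num_seqs * padded_len] loss over valid steps only."""
    # mask[i] is True for real timesteps and False for padding.
    mask = [
        step < seq_len
        for seq_len in seq_lens
        for step in range(padded_len)
    ]
    valid = [loss for loss, keep in zip(per_step_loss, mask) if keep]
    return sum(valid) / len(valid)


# Two sequences padded to 3 steps: lengths 3 and 1.
per_step_loss = [1.0, 2.0, 3.0,   # sequence 0, all real
                 4.0, 9.0, 9.0]   # sequence 1, last two steps are padding
print(masked_mean_loss(per_step_loss, seq_lens=[3, 1], padded_len=3))  # 2.5
```

Dividing by the number of valid steps rather than the full batch size keeps the loss scale independent of how much padding a batch happens to contain.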