Understanding state_batches in compute_actions


I am working on a custom policy that uses the state_batches in compute_actions to keep track of an internal policy state, which gets updated at each timestep of the environment (think of an updated expectation value of observations). I use the Trajectory View API with the following settings in my policy’s __init__():

self.view_requirements['state_in_0'] = \

The initial state is defined as:

import numpy as np

# Initial state in custom policy:
def get_initial_state(self):
    return [np.zeros(8, dtype=np.float64)]

This initial state, by the way, has already been reshaped to the following by the time it arrives in compute_actions():

# What happened here? 
[array([[0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)]

Furthermore, as the second return value of my policy's compute_actions() function, I return a list of BATCH_SIZE numpy arrays, each of shape (STATE_SIZE,).

When I analyze the SampleBatch of an episode after training, I can see that state_out_0 is not identical to state_in_0 of the next timestep (why? normalization?):


[[ 0.          0.          0.          0.          0.          0.
   0.          0.        ]
 [-0.05122958  1.07819295  1.02973902  1.037444    0.89701152 -0.04827012
  -0.0305887  -0.01290726]]  # <- this should equal

[[0.02288876 0.58843416 0.31873515 0.56406111 0.29287788 0.
  0.         0.        ] # <- this
 [0.0270972  0.74560356 0.45348084 0.71006745 0.40950587 0.0121695
  0.00722509 0.00228068]]

I looked at the definition of compute_actions(), which says the new state_batches are returned with shape [STATE_SIZE, BATCH_SIZE] and type List[TensorType]. So I assumed I had to change the output shape, and did so:

[array([0.02288876]), array([0.58843414]), array([0.31873516]), array([0.56406113]), array([0.29287789]), array([0.]), array([0.]), array([0.])]

However, in the next timestep the state_in_0 variable has the shape:


which necessarily causes an error. I am confused. Can anyone tell me how to correctly define the initial state and return the state_batches? (Maybe also give a hint where in the source code the state_batches are processed.)

Thanks for your help


Hi @Lars_Simon_Zehnder,

I saw you liked my post on states in rllib.

In there I mentioned

“When you are using an rnn you only get the initial state of the sequence. The other states are generated internally by the rnn logic.”

This is likely the issue. To verify this, check the shapes of obs, seq_lens, and state_in_0[0] and compare them. If this is the issue, you will see that the obs batch size is larger than that of seq_lens, and that state_in_0 has the same batch size as seq_lens.
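A minimal sketch of that check, assuming the batch has been turned into a plain dict of numpy arrays (a toy dict stands in for a real SampleBatch here). `compare_state_shapes` is a hypothetical helper, not part of RLlib:

```python
import numpy as np

def compare_state_shapes(batch):
    """Hypothetical helper: report the leading (batch) dimension of obs,
    seq_lens and state_in_0 in a plain dict of numpy arrays."""
    sizes = {k: len(batch[k]) for k in ("obs", "seq_lens", "state_in_0")}
    # Symptom described above: more obs rows than sequences, and exactly
    # one state_in_0 row per sequence instead of one per timestep.
    subsampled = (sizes["obs"] > sizes["seq_lens"]
                  and sizes["state_in_0"] == sizes["seq_lens"])
    return sizes, subsampled

# Toy batch: 40 timesteps split into 2 sequences of length 20.
toy_batch = {
    "obs": np.zeros((40, 4)),
    "seq_lens": np.array([20, 20]),
    "state_in_0": np.zeros((2, 8)),
}
sizes, subsampled = compare_state_shapes(toy_batch)
print(sizes, subsampled)
```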

An option that might work is to give the trajectory view variable that tracks state_in a different name, one that does not include "state" in the key. This might circumvent the special state handling when the sample batch is prepared. I am just guessing here.


@mannyv thank you very much for your reply! Yes, I liked your post on states; it clarified a lot for me while I am still starting to dig into RLlib. As I am using the state_batches without an RNN, simply to keep track of the internal state of a policy, the explanation about sequences does not directly apply to my problem. However, as so often, you spelled it out for me: max_seq_len turned out to be the solution.

Maybe this helps others to better understand what is happening with the state_in_0 and state_out_0 variables. I first investigated the function build() in simple_list_collector.py. While debugging it, I found that the batch_repeat_value attribute of the ViewRequirement of state_in_0 had already been set to the default of 20. I then set batch_repeat_value to 1 in my ViewRequirement (as @mannyv pointed out in his linked post, this should make the sample collector take every step's state_in_0 at the end of a rollout), but this did not change anything.
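To see why a batch_repeat_value of 20 produces exactly the mismatch shown in the question, here is a simplified model. It is an assumption for illustration, not RLlib's actual collector code: state_in at step t is the state_out of step t-1, and only every batch_repeat_value-th row is kept. `collect_state_in` is hypothetical:

```python
import numpy as np

def collect_state_in(state_outs, initial_state, batch_repeat_value):
    """Hypothetical, simplified model of assembling a state_in_0 column:
    state_in[t] = state_out[t-1] (step 0 uses the initial state), and
    only every `batch_repeat_value`-th row is kept."""
    state_ins = np.concatenate([initial_state[None], state_outs[:-1]], axis=0)
    return state_ins[::batch_repeat_value]

T, STATE_SIZE = 40, 8
rng = np.random.default_rng(0)
state_outs = rng.normal(size=(T, STATE_SIZE))
init = np.zeros(STATE_SIZE)

repeat_20 = collect_state_in(state_outs, init, batch_repeat_value=20)
repeat_1 = collect_state_in(state_outs, init, batch_repeat_value=1)

print(repeat_20.shape)  # (2, 8): one state per 20-step chunk
print(repeat_1.shape)   # (40, 8): one state per step
# With batch_repeat_value=20, row 1 is state_out of step 19, not of step 0:
print(np.allclose(repeat_20[1], state_outs[0]))  # False
print(np.allclose(repeat_1[1], state_outs[0]))   # True
```

In this model, row 0 is the zero initial state and row 1 comes from a much later step. That matches the printout in the question, where the second state_in_0 row does not equal the first state_out_0 row.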

Next, I looked for code that handles the ViewRequirement of state_in_0 and found the part most relevant to my problem in the function _update_model_view_requirements_from_init_state(), specifically these lines:

view_reqs["state_in_{}".format(i)] = ViewRequirement(
    batch_repeat_value=self.config.get("model", {}).get(
        "max_seq_len", 1),
    # ...
)

This is where the ViewRequirement for state_in_0 is defined again, overwriting any user-defined values. That may well be intended by the authors, though, as this function automates the Trajectory View API for the user and simplifies implementing RNNs or attention nets (@sven1977, is this intended?). Because batch_repeat_value is taken from max_seq_len, the solution for my problem is to set max_seq_len in the model config (to 1, so that every step is kept). Now all steps' state_in_0 values in a rollout are kept, and state_out_0 at timestep t is identical to state_in_0 at timestep t+1.
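The quoted expression means that the batch_repeat_value of state_in_0 simply follows model.max_seq_len. The small sketch below reproduces that lookup on plain dicts. Only the config expression comes from the RLlib excerpt above; `resolve_state_in_repeat` and the example configs are hypothetical:

```python
def resolve_state_in_repeat(config):
    """Hypothetical mirror of the quoted expression: the batch_repeat_value
    for state_in_{i} is taken from the model config's max_seq_len."""
    return config.get("model", {}).get("max_seq_len", 1)

# Hypothetical policy configs (only the relevant key shown).
default_like = {"model": {"max_seq_len": 20}}  # matches the default of 20 observed above
fixed = {"model": {"max_seq_len": 1}}          # keep every step's state_in_0

print(resolve_state_in_repeat(default_like))  # 20
print(resolve_state_in_repeat(fixed))         # 1
print(resolve_state_in_repeat({}))            # 1 (fallback when there is no "model" key)
```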

One question remains open: as the return shape of state_batches in compute_actions() I use [BATCH_SIZE, STATE_SIZE] = [1, 8]. However, this appears to contradict the documentation, which specifies the reverse, namely [STATE_SIZE, BATCH_SIZE]. Which is correct?


Hey @Lars_Simon_Zehnder, the documentation is incorrect here. State vectors are always [B, state-dim], not the other way around. Will fix this in the comment. Thanks for raising this!
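As an illustration of that layout, the sketch below updates a running mean of the observations as the policy state and returns it the way described here: a list with one entry per state component, each of shape [B, STATE_SIZE]. `update_state_batches`, `alpha`, and the assumption that the observation size equals STATE_SIZE are hypothetical and only serve this example:

```python
import numpy as np

STATE_SIZE = 8

def update_state_batches(obs_batch, state_batches, alpha=0.1):
    """Hypothetical stand-in for the state part of compute_actions():
    keeps an exponential running mean of the observations as the state.
    Assumes obs_batch has shape [B, STATE_SIZE]."""
    prev = np.asarray(state_batches[0], dtype=np.float64)  # [B, STATE_SIZE]
    new = (1.0 - alpha) * prev + alpha * obs_batch         # [B, STATE_SIZE]
    # One list entry per state component, batch dimension first.
    return [new]

B = 1
obs_batch = np.ones((B, STATE_SIZE))
initial = [np.zeros((B, STATE_SIZE))]  # initial state, already batched to [1, 8]
state_out = update_state_batches(obs_batch, initial)
print(len(state_out), state_out[0].shape)  # 1 (1, 8)
```

By contrast, the [STATE_SIZE, BATCH_SIZE] layout tried in the question becomes a list of 8 arrays of shape (1,), which the next step then interprets as 8 separate state components.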


Hi @sven1977,
thanks for clarifying the shape of the state vector. That explains the behavior I see.

What about the overwriting of ViewRequirements in _update_model_view_requirements_from_init_state()? Is this intended?

Hey @sven1977 @mannyv,

I have to reopen this thread, because I now want to include two shifts of my state in the compute_actions() function. In my policy's __init__() method I write the following:

self.view_requirements['state_in_0'] = \
self.view_requirements['state_out_0'] = \

When I debug my code, my policy shows:


which shows me that my view_requirements were overwritten. This points to my earlier reply, where I stated:

Next, I looked for code that handles the ViewRequirement of state_in_0 and found the part most relevant to my problem in the function _update_model_view_requirements_from_init_state(), specifically these lines: (see the reply above).

Question: is there a way to prevent a policy's own view_requirements from being overwritten?

Thanks for any time and effort you put into this.

This has been fixed in pull request 17867, which is good to know for anyone using bare-metal policies (i.e., without a model, inheriting directly from Policy).

See also pull request 17896 for an example of a bare-metal policy with user-specified view_requirements for its state.


@Lars_Simon_Zehnder nice work. I was away for a few weeks and just wanted to add one more comment on another way you could do this if you were not using a fully custom policy.

In that case you can create your own custom ViewRequirement with a unique key name that is derived from state_. If you did it this way, it would not matter that state_ is being overridden. That approach would also work when using an RLlib agent algorithm with the standard RNN interface. Let's say you wanted to use PPO but also add a custom loss that required the last 5 states. PPO could use the standard state_* keys with only 1 past state, while the custom loss uses the new ViewRequirement.
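As a rough illustration of what such an extra view would hand to a custom loss, the numpy sketch below builds, for every timestep, the window of the previous 5 states. It only emulates the data layout. `last_k_states` is hypothetical. In RLlib itself this would be declared as a ViewRequirement under a new key, as described above, and the exact arguments depend on your RLlib version:

```python
import numpy as np

def last_k_states(state_outs, initial_state, k):
    """Hypothetical helper: for every timestep t, gather
    state_out[t-k .. t-1], padding steps before the episode start with
    the initial state. Returns an array of shape [T, k, STATE_SIZE]."""
    T = len(state_outs)
    # Prepend k copies of the initial state, so padded[t] == state_out[t-k].
    padded = np.concatenate(
        [np.repeat(initial_state[None], k, axis=0), state_outs], axis=0)
    # The window for step t covers padded rows t .. t+k-1.
    return np.stack([padded[t:t + k] for t in range(T)], axis=0)

T, STATE_SIZE, K = 6, 8, 5
state_outs = np.arange(T * STATE_SIZE, dtype=float).reshape(T, STATE_SIZE)
windows = last_k_states(state_outs, np.zeros(STATE_SIZE), K)
print(windows.shape)  # (6, 5, 8)
# At t=5 the window holds state_out[0..4]; its last entry equals state_in at t=5.
print(np.array_equal(windows[5][-1], state_outs[4]))  # True
```

The last entry of each window is the usual one-step state_in, so the standard state_* keys and this longer view stay consistent with each other.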
