Initialise loss from dummy batch method in policy.py

Hello, I have a question regarding the _initialize_loss_from_dummy_batch method in policy.py.

For RNN models, it seems the state_in_* and state_out_* entries are sliced to keep only the first 4 elements (the 4 is hardcoded). Is there a reason it is done this way?

This would prevent the remaining states from reaching the forward method, with all the consequences that entails.

Code below.

Thanks a lot, Roberto

    if state_outs:
        B = 4  # For RNNs, have B=4, T=[depends on sample_batch_size]
        i = 0
        while "state_in_{}".format(i) in postprocessed_batch:
            postprocessed_batch["state_in_{}".format(i)] = \
                postprocessed_batch["state_in_{}".format(i)][:B]
            if "state_out_{}".format(i) in postprocessed_batch:
                postprocessed_batch["state_out_{}".format(i)] = \
                    postprocessed_batch["state_out_{}".format(i)][:B]
            i += 1
        seq_len = sample_batch_size // B
        seq_lens = np.array([seq_len for _ in range(B)], dtype=np.int32)
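
A minimal sketch of what the excerpt does to a flat dummy batch, assuming one `state_in_0`/`state_out_0` pair and made-up sizes (`sample_batch_size = 32`, `state_size = 8`). The helper name `slice_states_for_rnn` is hypothetical, not RLlib's API. The point it illustrates: the batch is reinterpreted as B sequences of length T = sample_batch_size // B, and an RNN needs only one initial state per sequence. So the state columns shrink to B rows, while per-timestep columns such as `obs` keep all rows.

```python
import numpy as np

# Hypothetical flat dummy batch: sample_batch_size rows, with the RNN state
# stored once per timestep (as before the batch is split into sequences).
sample_batch_size = 32
state_size = 8
B = 4  # number of sequences, mirroring the hardcoded value in the excerpt

batch = {
    "obs": np.zeros((sample_batch_size, 3), dtype=np.float32),
    "state_in_0": np.zeros((sample_batch_size, state_size), dtype=np.float32),
    "state_out_0": np.zeros((sample_batch_size, state_size), dtype=np.float32),
}


def slice_states_for_rnn(batch, sample_batch_size, B):
    """Hypothetical helper: view the batch as B sequences of length
    T = sample_batch_size // B.

    Only one initial state per sequence is kept, so every state_in_i (and,
    mirroring the excerpt, state_out_i) is cut to its first B rows.
    Per-timestep columns like "obs" are left untouched.
    """
    i = 0
    while f"state_in_{i}" in batch:
        batch[f"state_in_{i}"] = batch[f"state_in_{i}"][:B]
        if f"state_out_{i}" in batch:
            batch[f"state_out_{i}"] = batch[f"state_out_{i}"][:B]
        i += 1
    seq_len = sample_batch_size // B
    seq_lens = np.array([seq_len for _ in range(B)], dtype=np.int32)
    return batch, seq_lens


batch, seq_lens = slice_states_for_rnn(batch, sample_batch_size, B)
print(batch["obs"].shape)         # (32, 3)
print(batch["state_in_0"].shape)  # (4, 8)
print(seq_lens)                   # [8 8 8 8]
```

Note that `seq_lens` sums to 32, matching the number of `obs` rows. The 4 state rows are therefore initial states for 4 sequences, not the first 4 timesteps of one sequence.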

Hi @rob65, welcome to the community,

_initialize_loss_from_dummy_batch is not used during training, only during the initial setup of the policy. It tries to automatically determine which items, and how many timesteps of them, are required in the sample_batch during rollouts and training. This process uses dummy data that is not obtained from the environment.
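One way to picture "determine which items are required" is to run the loss once on dummy data and record which batch keys it reads. The following is a simplified, hypothetical illustration of that general idea, not RLlib's actual implementation. `AccessTrackingBatch` and `toy_loss` are made-up names, and the column names are assumptions.

```python
import numpy as np


class AccessTrackingBatch(dict):
    """Hypothetical dict that records every key read via batch[key]."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.accessed_keys = set()

    def __getitem__(self, key):
        self.accessed_keys.add(key)
        return super().__getitem__(key)


def toy_loss(batch):
    # Hypothetical policy-gradient-style loss that reads only two columns.
    return -np.mean(batch["action_logp"] * batch["advantages"])


# Dummy data: zeros, not collected from any environment.
dummy = AccessTrackingBatch(
    obs=np.zeros((4, 3)),
    actions=np.zeros(4, dtype=np.int64),
    action_logp=np.zeros(4),
    advantages=np.zeros(4),
    rewards=np.zeros(4),
)
toy_loss(dummy)
print(sorted(dummy.accessed_keys))  # ['action_logp', 'advantages']
```

Because the values are dummies, only the access pattern matters here. Columns the loss never reads (here `obs`, `actions`, `rewards`) could then be left out of the batches collected later.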

During actual training, the entire sample_batch is used. Does this help, or are you still seeing an issue?


Hi @mannyv, thanks a lot for the info. This makes much more sense now.