Initialise loss from dummy batch method in policy.py

Hello, I have a question regarding the _initialize_loss_from_dummy_batch method in policy.py.

For RNN models it seems the state_in_* and state_out_* entries are sliced to keep only the first 4 elements (the 4 is hardcoded). Is there a reason it is done this way?

This would keep the remaining states from ever reaching the forward method, with all the consequences that entails.

Code below.

Thanks a lot, Roberto

    if state_outs:
        B = 4  # For RNNs, have B=4, T=[depends on sample_batch_size]
        i = 0
        while "state_in_{}".format(i) in postprocessed_batch:
            postprocessed_batch["state_in_{}".format(i)] = \
                postprocessed_batch["state_in_{}".format(i)][:B]
            if "state_out_{}".format(i) in postprocessed_batch:
                postprocessed_batch["state_out_{}".format(i)] = \
                    postprocessed_batch["state_out_{}".format(i)][:B]
            i += 1
        seq_len = sample_batch_size // B
        seq_lens = np.array([seq_len for _ in range(B)], dtype=np.int32)

Hi @rob65, welcome to the community!

_initialize_loss_from_dummy_batch is not used during training, only during the initial setup of the policy. It is used to try to automatically determine which items, and how many timesteps of each, are required in the sample_batch during rollouts and training. This process uses dummy data that is not collected from the environment.

During actual training the entire sample_batch is used. Does this help, or are you still seeing an issue?
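
Concretely, here is a rough sketch of the shapes that slicing produces, assuming sample_batch_size = 32, a 10-dim observation and a single LSTM state pair of size 64 (the dict below is illustrative, not the exact internals of the dummy batch):

    import numpy as np

    sample_batch_size = 32
    B = 4                        # hardcoded number of dummy sequences
    T = sample_batch_size // B   # 8 timesteps per dummy sequence

    # Fake dummy batch: every key initially has sample_batch_size rows.
    dummy_batch = {
        "obs": np.zeros((sample_batch_size, 10), np.float32),
        "state_in_0": np.zeros((sample_batch_size, 64), np.float32),
        "state_in_1": np.zeros((sample_batch_size, 64), np.float32),
    }

    # The slicing in the snippet above keeps one initial state row per sequence.
    for i in (0, 1):
        dummy_batch["state_in_{}".format(i)] = \
            dummy_batch["state_in_{}".format(i)][:B]

    seq_lens = np.array([T] * B, dtype=np.int32)

    print(dummy_batch["obs"].shape)         # (32, 10): flattened B * T rows
    print(dummy_batch["state_in_0"].shape)  # (4, 64): one state per sequence
    print(seq_lens)                         # [8 8 8 8]

So the obs keep all B * T = 32 rows, while each state tensor is cut down to one row per sequence; the two are tied together by seq_lens.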


Hi @mannyv, thanks a lot for the info. This makes much more sense now.

But since “B” represents the batch size while “T” represents the timestep length, why is “B” hardcoded to 4 while “T”, according to the comment in the code, depends on sample_batch_size?

According to my understanding, “B” should be the one that depends on sample_batch_size. Could you kindly shed some light on this?

Thank you for your attention to this question.

I understand this is used only for the dummy loss calculation, but if we change the batch size of just the LSTM states (state_in_0/1 and state_out_0/1), wouldn’t that create a mismatch between the batch size of the LSTM states and that of the other inputs passed to the model, such as ‘obs’?
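
To illustrate the mismatch I mean with a standalone toy example (the sizes and names below are made up, not my real model):

    import torch
    import torch.nn as nn

    B, T, feat, hidden = 4, 8, 16, 64       # 4 dummy sequences of length 8

    obs = torch.zeros(B * T, feat)          # like input_dict["obs"]: 32 rows
    h0 = torch.zeros(1, B, hidden)          # like state_in_0: only 4 rows
    c0 = torch.zeros(1, B, hidden)          # like state_in_1: only 4 rows

    lstm = nn.LSTM(feat, hidden, batch_first=True)
    try:
        # Treating each of the 32 obs rows as its own length-1 sequence
        # clashes with the 4 initial states:
        lstm(obs.unsqueeze(1), (h0, c0))
    except RuntimeError as err:
        print(err)  # the hidden state batch (4) does not match the input batch (32)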

I am actually getting this error now: my input_dict[‘obs’][‘frames’] has a batch size of 32, while the state has a batch size of 4 in my model’s forward. How can I resolve this? I see there is a config option that disables this initial loss check on the dummy batch:

_disable_initialize_loss_from_dummy_batch = False,

but is it the only way to go?
Thank you.
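
Edit: for reference, this is how I am trying that flag for now. A minimal sketch, assuming the flag is a plain top-level key in the trainer config of my RLlib version, with use_lstm standing in for my custom RNN model:

    import ray
    from ray.rllib.agents.ppo import PPOTrainer  # assuming PPO just as an example

    config = {
        "env": "CartPole-v0",          # placeholder env standing in for my real one
        "framework": "torch",
        "model": {"use_lstm": True},   # stand-in for my custom RNN model
        # The flag quoted above, flipped on to skip the dummy-batch loss initialization:
        "_disable_initialize_loss_from_dummy_batch": True,
    }

    ray.init()
    trainer = PPOTrainer(config=config)
    print(trainer.train()["episode_reward_mean"])

But I would still prefer to understand whether the slicing itself is expected behaviour rather than just turning the check off.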