Hello, I have a question regarding the _initialize_loss_from_dummy_batch method in policy.py.
For RNN models, it seems that the state_in_* and state_out_* entries are sliced so that only their first 4 rows are kept (B = 4 is hardcoded). Is there a reason it is done this way?
As far as I can tell, this would keep the remaining states from ever reaching the forward method, with all the consequences that follow from that.
The relevant code is below.
Thanks a lot, Roberto
```python
if state_outs:
    B = 4  # For RNNs, have B=4, T=[depends on sample_batch_size]
    i = 0
    while "state_in_{}".format(i) in postprocessed_batch:
        postprocessed_batch["state_in_{}".format(i)] = \
            postprocessed_batch["state_in_{}".format(i)][:B]
        if "state_out_{}".format(i) in postprocessed_batch:
            postprocessed_batch["state_out_{}".format(i)] = \
                postprocessed_batch["state_out_{}".format(i)][:B]
        i += 1
    seq_len = sample_batch_size // B
    seq_lens = np.array([seq_len for _ in range(B)], dtype=np.int32)
```
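To make the shape change concrete, here is a standalone NumPy sketch that re-applies the same slicing to a made-up dummy batch. The helper name `slice_dummy_states`, the observation width of 10, the cell size of 64, the two state tensors, and `sample_batch_size = 32` are all hypothetical choices for illustration; only the slicing loop and the `seq_lens` computation mirror the snippet quoted from policy.py.

```python
import numpy as np


def slice_dummy_states(batch, sample_batch_size, B=4):
    """Hypothetical helper mirroring the quoted snippet on a plain dict."""
    batch = dict(batch)  # shallow copy so the caller's dict is left untouched
    i = 0
    while "state_in_{}".format(i) in batch:
        # Keep only the first B rows of every state input/output column.
        batch["state_in_{}".format(i)] = batch["state_in_{}".format(i)][:B]
        if "state_out_{}".format(i) in batch:
            batch["state_out_{}".format(i)] = batch["state_out_{}".format(i)][:B]
        i += 1
    # B sequences, each of length sample_batch_size // B.
    seq_len = sample_batch_size // B
    seq_lens = np.array([seq_len for _ in range(B)], dtype=np.int32)
    return batch, seq_lens


sample_batch_size = 32  # hypothetical
cell_size = 64  # hypothetical RNN cell size

# Dummy batch: every column starts with sample_batch_size rows.
dummy = {"obs": np.zeros((sample_batch_size, 10), dtype=np.float32)}
for i in range(2):  # two state tensors, e.g. an LSTM's h and c
    dummy["state_in_{}".format(i)] = np.zeros((sample_batch_size, cell_size), dtype=np.float32)
    dummy["state_out_{}".format(i)] = np.zeros((sample_batch_size, cell_size), dtype=np.float32)

sliced, seq_lens = slice_dummy_states(dummy, sample_batch_size)
print(sliced["obs"].shape)         # (32, 10)
print(sliced["state_in_0"].shape)  # (4, 64)
print(seq_lens)                    # [8 8 8 8]
print(seq_lens.sum())              # 32
```

In this toy run, the non-state column keeps all 32 rows, while each state column is cut down to 4 rows, which is exactly the number of entries in `seq_lens`.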