I’m working on a custom model that has a tuple obs space of length 5. It is a recurrent model, so I’m following the implementation for recurrent net here. When debugging, I found that the dummy batch created at the beginning has a batch size of 32, but seq_lens was [1,1,1,1,1]. If my understanding of seq_lens is correct, this looks wrong, since the sequence lengths sum to 5*1 = 5, not 32.
Tracing the error, I found this might come from compute_actions_from_input_dict in torch_policy.py. On line 297, seq_lens is created with:
np.array([1] * len(input_dict["obs"])) \
if state_batches else None
I think this is supposed to use the batch size, but it doesn’t handle a tuple obs space and takes the length of the tuple instead. What should happen is probably something like this:
np.array([1] * len(input_dict["obs"][0])) \
if state_batches else None
Just using index 0 should suffice, since every item of the tuple should have the same batch size.
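To make the difference concrete, here is a minimal standalone sketch that uses only NumPy and does not call RLlib. The names `obs` and `state_batches` and the component shapes are assumptions standing in for `input_dict["obs"]` and the real recurrent state; the only thing taken from the code above is the two `seq_lens` expressions.

```python
import numpy as np

batch_size = 32

# Hypothetical stand-in for input_dict["obs"]: a tuple of 5 components,
# each batched along axis 0 (the feature size 4 is arbitrary).
obs = tuple(np.zeros((batch_size, 4)) for _ in range(5))

# Hypothetical stand-in for the recurrent state; non-empty, so truthy.
state_batches = [np.zeros((batch_size, 8))]

# Expression quoted from the post: len() of a tuple is the number of components.
seq_lens_current = np.array([1] * len(obs)) if state_batches else None

# Proposed expression: len() of the first component is the batch size.
seq_lens_proposed = np.array([1] * len(obs[0])) if state_batches else None

print(seq_lens_current.shape, int(seq_lens_current.sum()))
print(seq_lens_proposed.shape, int(seq_lens_proposed.sum()))
```

With the proposed expression there is one length-1 sequence per batch row, so the lengths sum to the batch size. With the current expression they sum to the number of tuple components.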
Is this the expected behavior? Am I missing anything here?