RecurrentNetwork and Trajectory View API

I’ve been reading over the attention-net file containing both the GTrXL and AttentionNetWrapper implementations, ray/rllib/models/torch/attention_net.py (master branch of ray-project/ray on GitHub).

A few things stood out to me. First, it doesn’t seem like memory_training is actually used anywhere; only memory_inference is used to configure the trajectory view API. You can see this in the for loop on line 194 (my paraphrase of that loop is below). So how does RLlib differentiate between training and inference?
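For reference, my reading of that loop is that it builds the state view requirements roughly like this (paraphrased from memory, so the exact arguments may differ from the source; only memory_inference appears):

    from gymnasium.spaces import Box  # `gym.spaces` on older Ray versions
    from ray.rllib.policy.view_requirement import ViewRequirement

    # Paraphrase of the loop around line 194 (may not match the source exactly).
    # Each transformer unit asks for the last `memory_inference` state_out_i rows
    # as its state_in_i; memory_training never shows up here.
    for i in range(self.num_transformer_units):
        space = Box(-1.0, 1.0, shape=(self.memory_inference, self.attention_dim))
        self.view_requirements["state_in_{}".format(i)] = ViewRequirement(
            "state_out_{}".format(i),
            shift="-{}:-1".format(self.memory_inference),
            batch_repeat_value=self.max_seq_len,
            space=space,
        )
        self.view_requirements["state_out_{}".format(i)] = ViewRequirement(
            space=space,
            used_for_training=False,
        )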

Second, the Tr-XL (on which GTrXL is based), as described in the paper, passes past latents into future forward passes to help with long-range dependencies. This appears to be the memory_outs array, i.e., the second element of the two-tuple returned by the forward pass of the GTrXL implementation. What is less clear is how that is actually used by either RLlib or the trajectory view API when caching for future forward passes, i.e., where/how is past memory stored? A sketch of my mental model follows below.
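To make the question concrete, this is the loop I imagine happening somewhere in the sampling/rollout code (purely illustrative pseudocode on my part, not the actual RLlib plumbing):

    # My mental model only -- not actual RLlib code.
    state = model.get_initial_state()       # one zero "memory" tensor per unit
    for t in range(episode_len):
        logits, state = model(
            {"obs": obs_at_t},               # input_dict
            state,                           # state_in_* = previous memory_outs
            seq_lens,
        )
        # I assume the returned `state` (the memory_outs) is written into the
        # sample batch as state_out_* and re-surfaces as state_in_* on the next
        # call -- but where exactly that caching happens is what I'm asking.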

Third, assuming that we can always follow the abstractions provided in these implementations for returning model latents (e.g., as done on line 446), what considerations, if any, are needed with regard to passing in past memory when implementing a similar model with the trajectory view API? Specifically, these forward functions all seem to share the signature below (a rough skeleton of the kind of model I have in mind follows after it):

def forward(
    self,
    input_dict: Dict[str, TensorType],
    state: List[TensorType],
    seq_lens: TensorType,
) -> (TensorType, List[TensorType]):
    ...
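For concreteness, here is the kind of skeleton I have in mind for my own model. Everything here (the class name, my_memory_len, the single state_in_0 requirement) is hypothetical and made up by me, not taken from RLlib:

    import torch.nn as nn
    from gymnasium.spaces import Box  # `gym.spaces` on older Ray versions
    from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
    from ray.rllib.policy.view_requirement import ViewRequirement


    class MyAttentionModel(TorchModelV2, nn.Module):
        """Hypothetical sketch of a model that wants its past latents as memory."""

        def __init__(self, obs_space, action_space, num_outputs, model_config, name):
            TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                                  model_config, name)
            nn.Module.__init__(self)
            self.my_memory_len = 10  # made-up hyperparameter
            self.latent_dim = 64
            self.encoder = nn.Linear(obs_space.shape[0], self.latent_dim)  # flat obs assumed
            self.logits_layer = nn.Linear(self.latent_dim, num_outputs)
            self.value_layer = nn.Linear(self.latent_dim, 1)
            # Ask the trajectory view API for the last `my_memory_len` latents.
            self.view_requirements["state_in_0"] = ViewRequirement(
                "state_out_0",
                shift="-{}:-1".format(self.my_memory_len),
                space=Box(-1.0, 1.0, shape=(self.my_memory_len, self.latent_dim)),
            )

        def forward(self, input_dict, state, seq_lens):
            latent = self.encoder(input_dict["obs"].float())
            self._value = self.value_layer(latent).squeeze(-1)
            # Return the current latent as "memory", mirroring line 446 -- whether
            # this alone is enough for the view API to cache it is my question.
            return self.logits_layer(latent), [latent]

        def value_function(self):
            return self._value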

Finally, it is also unclear to me what shape the batches containing past observations have. It appears that past rewards and the like are concatenated along the first dim, but how are past latents handled, and are past states (in their entirety) also concatenated along the first dim (i.e., as additional features)? A guess at the shapes is below.
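For example, these are the shapes I would guess (the key name prev_n_obs and all the numbers are mine, purely to illustrate the question):

    # Pure guesswork about shapes -- this is exactly what I'd like confirmed.
    B, obs_dim, latent_dim = 32, 8, 64

    # Past rewards: concatenated along the first (batch) dim?
    #   input_dict["prev_rewards"].shape  -> (B,)
    # A shift range such as "-3:-1" on observations: does it add a time axis,
    # or are the 3 past obs flattened in as extra features?
    #   input_dict["prev_n_obs"].shape    -> (B, 3, obs_dim) or (B, 3 * obs_dim)?
    # Past latents requested via state_in_0 with shift "-10:-1":
    #   state[0].shape                    -> (B, 10, latent_dim)?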

Thanks
