It’s quite unclear how recurrent models are supposed to work with the new trajectory view API.
```python
def forward_rnn(
    self,
    inputs: TensorType,
    state: List[TensorType],
    seq_lens: TensorType,
):
    print("got state", state)
    ...
    state = [torch.tensor([1, 2, 3])]
    return logits, state

def get_initial_state(self):
    return []
```
The state printed in `forward_rnn` always comes back empty:

```
got state
got state
got state
...
```
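If I understand the RLlib internals correctly, a model whose `get_initial_state()` returns an empty list is treated as stateless, so no state is tracked between calls; returning one zero tensor per state slot makes the state come back. A minimal self-contained sketch of that contract (numpy stands in for torch, and the class is a hypothetical stand-in for a `RecurrentNetwork` subclass):

```python
import numpy as np

class StatefulModel:
    # Hypothetical stand-in for an RLlib RecurrentNetwork subclass;
    # numpy replaces torch so the sketch runs on its own.
    def get_initial_state(self):
        # One zero array per state slot -- returning [] here is what
        # (as far as I can tell) makes the state arrive empty.
        return [np.zeros(3, dtype=np.float32)]

    def forward_rnn(self, inputs, state, seq_lens):
        logits = inputs                  # placeholder computation
        new_state = [state[0] + 1.0]     # carry updated state forward
        return logits, new_state

# Simulate how a rollout worker threads state between calls:
model = StatefulModel()
state = model.get_initial_state()
for _ in range(3):
    _, state = model.forward_rnn(np.ones(2), state, seq_lens=None)
```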
I want to build an accumulating representation that scales with the number of episodic timesteps, so various entries in a batch could have different state shapes.
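As far as I know, state tensors get batched and padded across episodes, so each state slot needs a fixed shape. One workaround (a sketch under that assumption, with hypothetical `MAX_T`/`FEAT` sizes) is a preallocated max-size buffer plus a write counter instead of a tensor that grows per timestep:

```python
import numpy as np

MAX_T = 8   # assumed cap on episode length (hypothetical)
FEAT = 4    # per-timestep feature size (hypothetical)

def initial_state():
    # Fixed-shape buffer plus a scalar write counter, instead of a
    # tensor whose shape depends on the episodic timestep.
    return [np.zeros((MAX_T, FEAT), dtype=np.float32),
            np.zeros(1, dtype=np.float32)]

def step(obs, state):
    buf, count = state
    t = int(count[0])
    buf = buf.copy()
    buf[t] = obs                      # write this timestep's entry
    return [buf, count + 1.0]

state = initial_state()
for t in range(3):
    state = step(np.full(FEAT, float(t)), state)
```

Every batch entry then carries the same `(MAX_T, FEAT)` state shape regardless of how far into the episode it is.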
Would anyone be able to provide some guidance on how to do this, or any insight into how the trajectory view API interacts with recurrent state?
If my model inherits from TorchModelV2 instead of RecurrentNetwork, can I still propagate state?
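For context on why I think this might be possible: my understanding is that `TorchModelV2.forward` itself already takes and returns state (`forward(input_dict, state, seq_lens) -> (outputs, new_state)`), and `RecurrentNetwork` mainly adds the time-major reshaping before calling `forward_rnn`. A self-contained sketch of that plain-`forward` contract (names illustrative, numpy standing in for torch):

```python
import numpy as np

class PlainModel:
    # Sketch of the TorchModelV2-style contract, where forward() itself
    # receives and returns the state list -- no forward_rnn involved.
    def get_initial_state(self):
        # Non-empty so state is actually tracked between calls.
        return [np.zeros(3, dtype=np.float32)]

    def forward(self, input_dict, state, seq_lens):
        obs = input_dict["obs"]
        logits = obs                          # placeholder computation
        new_state = [state[0] + obs.sum()]    # fold the obs into the state
        return logits, new_state

model = PlainModel()
state = model.get_initial_state()
logits, state = model.forward({"obs": np.ones(3)}, state, None)
```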