This is hopefully a quick question:
My custom model is dying on the third query because the hidden state of the GRU doesn't match the input in the batch dimension: the hidden state's batch size appears to be len(seq_lens) rather than sum(seq_lens). (Similarly, in e.g. ray\rllib\agents\ppo\ppo_torch_policy.py, after the forward call, if state is returned, the batch size is taken as the length of the seq_lens tensor.) Here is the logged output:
(pid=29704) seq_len tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
(pid=29704)         1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.int32)
(pid=29704) x.shape torch.Size([32, 2704]), h_in.shape torch.Size([32, 2704])
(pid=29704) h.shape torch.Size([32, 2704]), logits.shape torch.Size([32, 38]), self.value.shape torch.Size([32, 1])
(pid=29704) ____________________________________________________________
(pid=29704) seq_len tensor([1], dtype=torch.int32)
(pid=29704) x.shape torch.Size([1, 2704]), h_in.shape torch.Size([1, 2704])
(pid=29704) h.shape torch.Size([1, 2704]), logits.shape torch.Size([1, 38]), self.value.shape torch.Size([1, 1])
(pid=29704) ____________________________________________________________
(pid=29704) seq_len tensor([8, 8, 8, 8], dtype=torch.int32)
(pid=29704) x.shape torch.Size([32, 2704]), h_in.shape torch.Size([4, 2704])
In the third step, why is seq_len four 8s rather than thirty-two 1s?
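To make the question concrete, here is the shape bookkeeping as I read it from the logs (a minimal sketch of my assumption, not RLlib code): the 32 collected timesteps seem to get chunked into len(seq_lens) sequences, and the recurrent state is carried per sequence rather than per timestep.

    import torch

    # Minimal sketch (my reading of the logs, not RLlib internals): how seq_lens
    # relates to the rows of the flattened input vs. the rows of the hidden state
    # in the third log block above.
    seq_lens = torch.tensor([8, 8, 8, 8], dtype=torch.int32)

    input_rows = int(seq_lens.sum())   # 32 -> batch dim of the flattened x
    state_rows = len(seq_lens)         # 4  -> batch dim of h_in (one row per sequence)

    print(input_rows, state_rows)      # 32 4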
I followed the convention from the RecurrentNetwork class example here: ray\rllib\models\torch\recurrent_net.py. My forward_rnn looks like this:
    @override(RecurrentNetwork)
    def forward_rnn(self, x, state, seq_lens):
        # Flatten batch and time into one dimension before the conv stack
        x = x.reshape(-1, self._num_objects, self.obs_space.shape[0], self.obs_space.shape[1])
        x = self.conv(x)
        x = self.flat(x)
        # The incoming hidden state has one row per sequence in the batch
        h_in = state[0].reshape(-1, self.cell_size)
        print(f"seq_len {seq_lens}")
        print(f"x.shape {x.shape}, h_in.shape {h_in.shape}")
        h = self.gru(x, h_in)  # dies here on the third call: 32 rows of x vs. 4 rows of h_in
        self.value = self.val(h)
        logits = self.fc(h)
        print(f"h.shape {h.shape}, logits.shape {logits.shape}, self.value.shape {self.value.shape}")
        print('_'*60)
        return logits, [h]
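To double-check that it really is the batch-dimension mismatch that kills it, here is a small standalone repro (assuming self.gru is an nn.GRUCell, as the single-tensor return above suggests; the cell size is shrunk just to keep the snippet light):

    import torch
    import torch.nn as nn

    # Standalone repro of the failure mode (assumes an nn.GRUCell; in the real
    # model the cell size is 2704, shrunk here to 8 for brevity).
    cell_size = 8
    gru = nn.GRUCell(cell_size, cell_size)

    x = torch.zeros(32, cell_size)    # flattened features: one row per timestep
    h_in = torch.zeros(4, cell_size)  # hidden state: one row per sequence

    try:
        gru(x, h_in)
    except RuntimeError as e:
        print(e)  # batch-size mismatch between the input (32) and the hidden state (4)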
A hacky way to get around that would be to vstack enough zero rows onto the hidden-state tensor so that its batch dimension matches the batch size of the input data:
    @override(RecurrentNetwork)
    def forward_rnn(self, x, state, seq_lens):
        x = x.reshape(-1, self._num_objects, self.obs_space.shape[0], self.obs_space.shape[1])
        x = self.conv(x)
        x = self.flat(x)
        h_in = state[0].reshape(-1, self.cell_size)
        print(f"seq_len {seq_lens}")
        print(f"x.shape {x.shape}, h_in.shape {h_in.shape}")
        if h_in.shape[0] != x.shape[0]:
            # Pad the hidden state with zero rows until its batch dimension
            # matches the flattened input's batch dimension.
            missing_h = self.conv.weight.new(x.shape[0] - h_in.shape[0], h_in.shape[1]).zero_()
            h_in = torch.vstack((h_in, missing_h))
        h = self.gru(x, h_in)
        self.value = self.val(h)
        logits = self.fc(h)
        print(f"h.shape {h.shape}, logits.shape {logits.shape}, self.value.shape {self.value.shape}")
        print('_'*60)
        return logits, [h]
This seems to work, at least from the perspective of getting all the shapes to line up.