This is hopefully a quick question:
My custom model is dying on the third forward call because the hidden state of the GRU doesn't match the input in the batch dimension: the hidden state's batch size seems to be len(seq_lens) rather than sum(seq_lens). (Similarly, in ray\rllib\agents\ppo\ppo_torch_policy.py, after the forward call, if a state is returned, the batch size is taken as the length of the seq_lens tensor.) Here are my debug prints:
(pid=29704) seq_len tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
(pid=29704) 1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.int32)
(pid=29704) x.shape torch.Size([32, 2704]), h_in.shape torch.Size([32, 2704])
(pid=29704) h.shape torch.Size([32, 2704]), logits.shape torch.Size([32, 38]), self.value.shape torch.Size([32, 1])
(pid=29704) ____________________________________________________________
(pid=29704) seq_len tensor([1], dtype=torch.int32)
(pid=29704) x.shape torch.Size([1, 2704]), h_in.shape torch.Size([1, 2704])
(pid=29704) h.shape torch.Size([1, 2704]), logits.shape torch.Size([1, 38]), self.value.shape torch.Size([1, 1])
(pid=29704) ____________________________________________________________
(pid=29704) seq_len tensor([8, 8, 8, 8], dtype=torch.int32)
(pid=29704) x.shape torch.Size([32, 2704]), h_in.shape torch.Size([4, 2704])
In the third step, why is seq_len four 8s rather than thirty-two 1s?
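If I'm reading the shapes right, on the third call x has 32 rows because the 4 sequences of 8 timesteps are packed together, while the carried state has one row per sequence, hence the mismatch; the first two calls only line up because every sequence has length 1. In plain PyTorch terms:

import torch

# Numbers from the third call in the log above.
seq_lens = torch.tensor([8, 8, 8, 8], dtype=torch.int32)

rows_in_x = int(seq_lens.sum())  # 32: 4 sequences x 8 timesteps packed together
rows_in_state = len(seq_lens)    # 4: one hidden-state row per sequence
print(rows_in_x, rows_in_state)  # 32 4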
I followed the convention in the RecurrentNetwork class example here: ray\rllib\models\torch\recurrent_net.py
@override(RecurrentNetwork)
def forward_rnn(self, x, state, seq_lens):
    # Reshape the flat observations into image tensors for the conv stack.
    x = x.reshape(-1, self._num_objects, self.obs_space.shape[0], self.obs_space.shape[1])
    x = self.conv(x)
    x = self.flat(x)
    # One hidden-state row per sequence.
    h_in = state[0].reshape(-1, self.cell_size)
    print(f"seq_len {seq_lens}")
    print(f"x.shape {x.shape}, h_in.shape {h_in.shape}")
    h = self.gru(x, h_in)
    self.value = self.val(h)
    logits = self.fc(h)
    print(f"h.shape {h.shape}, logits.shape {logits.shape}, self.value.shape {self.value.shape}")
    print('_' * 60)
    return logits, [h]
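For comparison, my understanding of the built-in recurrent_net.py example is that it keeps a [batch, time, feature] layout around its LSTM rather than flattening batch and time into one dimension. A minimal sketch of what that might look like with a GRU, in plain PyTorch (the batch_first layout, the 2704 sizes, and the lack of masking for padded timesteps are my own simplifications):

import torch
import torch.nn as nn

feat, cell_size = 2704, 2704
gru = nn.GRU(feat, cell_size, batch_first=True)

seq_lens = torch.tensor([8, 8, 8, 8], dtype=torch.int32)
B, T = len(seq_lens), int(seq_lens.max())

x = torch.randn(B * T, feat)            # flattened conv features, 32 rows
x = x.view(B, T, feat)                  # restore [batch, time, feature]
h_in = torch.randn(B, cell_size)        # one carried-state row per sequence
h_out, h_n = gru(x, h_in.unsqueeze(0))  # nn.GRU expects hidden as [num_layers, B, cell]

flat_out = h_out.reshape(B * T, cell_size)  # back to 32 rows for the logits/value heads
new_state = h_n.squeeze(0)                  # [B, cell_size], returned as state[0]
print(flat_out.shape, new_state.shape)      # torch.Size([32, 2704]) torch.Size([4, 2704])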
A hacky way around the mismatch would be to pad the hidden state with zero rows (via torch.vstack) until its batch dimension matches that of the input data:
@override(RecurrentNetwork)
def forward_rnn(self, x, state, seq_lens):
    x = x.reshape(-1, self._num_objects, self.obs_space.shape[0], self.obs_space.shape[1])
    x = self.conv(x)
    x = self.flat(x)
    h_in = state[0].reshape(-1, self.cell_size)
    print(f"seq_len {seq_lens}")
    print(f"x.shape {x.shape}, h_in.shape {h_in.shape}")
    # Hack: pad the hidden state with zero rows so its batch dimension
    # matches the number of rows coming out of the conv stack.
    if h_in.shape[0] != x.shape[0]:
        missing_h = self.conv.weight.new(x.shape[0] - h_in.shape[0], h_in.shape[1]).zero_()
        h_in = torch.vstack((h_in, missing_h))
    h = self.gru(x, h_in)
    self.value = self.val(h)
    logits = self.fc(h)
    print(f"h.shape {h.shape}, logits.shape {logits.shape}, self.value.shape {self.value.shape}")
    print('_' * 60)
    return logits, [h]
This seems to work, at least in the sense that all the shapes line up.
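As a sanity check of the padding hack outside RLlib (toy sizes here; my real cell_size is 2704, and the GRUCell stands in for self.gru):

import torch
import torch.nn as nn

cell_size = 8
gru = nn.GRUCell(cell_size, cell_size)

x = torch.randn(32, cell_size)    # rows coming out of the conv stack
h_in = torch.randn(4, cell_size)  # one carried-state row per sequence

# Zero-pad the state up to the input's batch size, as in the hack above.
missing_h = h_in.new_zeros(x.shape[0] - h_in.shape[0], h_in.shape[1])
h_in = torch.vstack((h_in, missing_h))

h = gru(x, h_in)
print(h.shape)  # torch.Size([32, 8])

The shapes line up, but the padded rows start from a zero hidden state, and the state returned from forward_rnn now has 32 rows rather than the 4 that came in.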