GRU hidden_state tensor batch dimension is incompatible with sample_batch

This is hopefully a quick question:

My custom model is dying on the third query because the hidden state of the GRU doesn't match the input in the batch dimension: the hidden state's batch size seems to be len(seq_lens) rather than sum(seq_lens). (Similarly, in e.g. ray\rllib\agents\ppo\ppo_torch_policy.py, after the forward call, if a state is returned, the batch size is taken as the length of the seq_lens tensor.) Here are the logs:

(pid=29704) seq_len tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
(pid=29704)         1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.int32)
(pid=29704) x.shape torch.Size([32, 2704]), h_in.shape torch.Size([32, 2704])
(pid=29704) h.shape torch.Size([32, 2704]), logits.shape torch.Size([32, 38]), self.value.shape torch.Size([32, 1])
(pid=29704) ____________________________________________________________
(pid=29704) seq_len tensor([1], dtype=torch.int32)
(pid=29704) x.shape torch.Size([1, 2704]), h_in.shape torch.Size([1, 2704])
(pid=29704) h.shape torch.Size([1, 2704]), logits.shape torch.Size([1, 38]), self.value.shape torch.Size([1, 1])
(pid=29704) ____________________________________________________________
(pid=29704) seq_len tensor([8, 8, 8, 8], dtype=torch.int32)
(pid=29704) x.shape torch.Size([32, 2704]), h_in.shape torch.Size([4, 2704])

In the third step, why is seq_len four 8s rather than thirty-two 1s?

I followed the convention in the RecurrentNetwork class example here: ray\rllib\models\torch\recurrent_net.py

    @override(RecurrentNetwork)
    def forward_rnn(self, x, state, seq_lens):
        # Flatten batch and time together before the conv stack.
        x = x.reshape(-1, self._num_objects, self.obs_space.shape[0], self.obs_space.shape[1])
        x = self.conv(x)
        x = self.flat(x)
        # The hidden state arrives with one row per sequence.
        h_in = state[0].reshape(-1, self.cell_size)
        print(f"seq_len {seq_lens}")
        print(f"x.shape {x.shape}, h_in.shape {h_in.shape}")
        h = self.gru(x, h_in)
        self.value = self.val(h)
        logits = self.fc(h)
        print(f"h.shape {h.shape}, logits.shape {logits.shape}, self.value.shape {self.value.shape}")
        print('_' * 60)
        return logits, [h]
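
For context, the self.gru call is where this dies: assuming self.gru is an nn.GRUCell, the cell requires its input and hidden state to share the same batch dimension. A standalone sketch (stand-in sizes copied from the logs above, not my real model) reproduces the mismatch:

    import torch
    import torch.nn as nn

    cell = nn.GRUCell(input_size=2704, hidden_size=2704)
    x = torch.zeros(32, 2704)  # 32 flattened timesteps (sum(seq_lens))
    h = torch.zeros(4, 2704)   # one hidden-state row per sequence (len(seq_lens))

    try:
        cell(x, h)
    except RuntimeError as e:
        print(e)  # e.g. "Input batch size 32 doesn't match hidden0 batch size 4"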

A hacky way to get around that would be to pad the hidden state with zero rows (via torch.vstack) until its batch dimension matches that of the input data.

    @override(RecurrentNetwork)
    def forward_rnn(self, x, state, seq_lens):
        x = x.reshape(-1, self._num_objects, self.obs_space.shape[0], self.obs_space.shape[1])
        x = self.conv(x)
        x = self.flat(x)
        h_in = state[0].reshape(-1, self.cell_size)
        print(f"seq_len {seq_lens}")
        print(f"x.shape {x.shape}, h_in.shape {h_in.shape}")
        if h_in.shape[0] != x.shape[0]:
            # Pad the hidden state with zero rows until its batch dim matches x's.
            missing_h = self.conv.weight.new(x.shape[0] - h_in.shape[0], h_in.shape[1]).zero_()
            h_in = torch.vstack((h_in, missing_h))
        h = self.gru(x, h_in)
        self.value = self.val(h)
        logits = self.fc(h)
        print(f"h.shape {h.shape}, logits.shape {logits.shape}, self.value.shape {self.value.shape}")
        print('_' * 60)
        return logits, [h]

This seems to work, at least in the sense of getting all the shapes to line up.


With this change, the code seems to run using the PPOTrainer until it hits this ValueError:

ValueError('cannot reshape array of size 8 into shape (2704)\nIn tower 0 on device cpu')

There are no additional details about the stack trace, even in local mode.
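
The message itself is plain numpy: an 8-element array being forced into a 2704-wide shape, which makes me suspect something sequence-length-sized is landing in the cell-sized state slot. A minimal repro of just the message:

    import numpy as np

    # 8 elements cannot fill a 2704-element shape.
    np.zeros(8).reshape(2704)
    # ValueError: cannot reshape array of size 8 into shape (2704,)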

Hey @aadharna, could you post either:
a) the get_initial_state method of your custom Model (this must return a list of non-batched internal state tensors, i.e. no batch dim!), or
b) the self.view_requirements that your custom Model generates in its constructor (if applicable)?

I think the error lies somewhere in there.
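
For reference, a minimal sketch of what a) should look like (torch.zeros is just a placeholder initializer):

    @override(ModelV2)
    def get_initial_state(self):
        # No batch dim here! RLlib adds the batch dimension itself, so each
        # state tensor should be shaped (cell_size,), not (1, cell_size).
        return [torch.zeros(self.cell_size)]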

Here is the get_initial_state method:

    @override(ModelV2)
    def get_initial_state(self):  # -> List[np.ndarray]:
        """Get the initial recurrent state values for the model.

        Returns:
            List[np.ndarray]: List of np.array objects containing the initial
                hidden state of an RNN, if applicable.

        """
        h = [self.conv.weight.new(
                1, self.cell_size).zero_()]
        return h

And here are the view requirements:

example_network_factory_build_info = {
        'action_space': MultiDiscrete([15, 15,  6,  2]),
        'obs_space': Box(0.0, 255.0, (15, 15, 4), np.float64),
        'model_config': {'length': 15, 'width': 15, 'placements': 50},
        'num_outputs': sum([15, 15,  6,  2]),
        'name': 'pcgrl'
    }
adversary.view_requirements
Out[3]: {'obs': <ray.rllib.policy.view_requirement.ViewRequirement at 0x1c2f2a43a08>}
adversary.view_requirements['obs'].space
Out[4]: Box(0.0, 255.0, (15, 15, 4), float64)

When I return [self.conv.weight.new(1, self.cell_size).zero_().squeeze(0)] from the get_initial_state function and remove the hacky batch-fixing code, it results in the error: Input batch size 32 doesn't match hidden0 batch size 4

When I add back in the

        if h_in.shape[0] != x.shape[0]:
            missing_h = self.conv.weight.new(x.shape[0] - h_in.shape[0], h_in.shape[1]).zero_()
            h_in = torch.vstack((h_in, missing_h))

and keep the squeeze, it seems to run. But that still leaves my original question: why is seq_len four 8s rather than thirty-two 1s, given that the batch size seems to be determined by len(seq_lens) rather than sum(seq_lens)?
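
If I'm reading ray\rllib\models\torch\recurrent_net.py right, the base class's forward() runs the flat batch through add_time_dimension before calling forward_rnn, so the inputs arrive chunked into sequences of shape [len(seq_lens), max_seq_len, feat], and the hidden state carries one row per sequence. A sketch of those shapes, assuming an nn.GRU running on the un-flattened inputs (sizes copied from the logs above):

    import torch
    import torch.nn as nn

    # 32 timesteps chunked into 4 sequences of length 8: the hidden state has
    # one row per *sequence* (len(seq_lens)), not per timestep (sum(seq_lens)).
    seq_lens = torch.tensor([8, 8, 8, 8])
    B, T, feat = len(seq_lens), int(seq_lens.max()), 2704

    gru = nn.GRU(input_size=feat, hidden_size=feat, batch_first=True)
    x = torch.zeros(B, T, feat)     # [4, 8, 2704] -- still 32 timesteps in total
    h_in = torch.zeros(1, B, feat)  # [num_layers=1, 4, 2704]

    out, h_out = gru(x, h_in)
    print(out.shape)    # torch.Size([4, 8, 2704])
    print(h_out.shape)  # torch.Size([1, 4, 2704])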