Branched torch model with use_lstm - Shape mismatch

I am trying to create a custom RL model for PPO with the use_lstm flag set to True. The model has two backbone branches that take in different modalities of features; their outputs are then concatenated and passed to a subsequent fully-connected layer. Everything works fine until I set use_lstm to True. The model is as follows:

# Imports assumed for context (module paths vary across Ray versions;
# TorchCNN/TorchFC are assumed to be RLlib's VisionNetwork/FullyConnectedNetwork):
import numpy as np
import torch
import torch.nn as nn

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.models.torch.visionnet import VisionNetwork as TorchCNN
from ray.rllib.utils.torch_utils import FLOAT_MIN  # ray.rllib.utils.torch_ops in older Ray


class BranchedLSTMModel(TorchModelV2, nn.Module):
    """With a separate task head to process task embedding."""

    def __init__(
        self,
        obs_space,
        action_space,
        num_outputs,
        model_config,
        name,
        **kwargs,
    ):
        orig_space = getattr(obs_space, "original_space", obs_space)

        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name, **kwargs
        )
        nn.Module.__init__(self)

        # CNN branch for the image-like "observations" component.
        self.obs_internal_model = TorchCNN(
            obs_space=orig_space['observations'],
            action_space=action_space,
            num_outputs=None,
            model_config=model_config,
            name=name + "_obs_internal",
        )

        # FC branch for the "tasks" (task-embedding) component.
        self.task_internal_model = TorchFC(
            obs_space=orig_space['tasks'],
            action_space=action_space,
            num_outputs=None,
            model_config=model_config,
            name=name + "_task_internal",
        )

        # FC trunk that consumes the concatenated branch embeddings and
        # produces the action logits.
        self.internal_model = TorchFC(
            np.zeros((self.obs_internal_model.num_outputs
                      + self.task_internal_model.num_outputs, 1)),
            action_space,
            num_outputs,
            model_config,
            name + "_internal",
        )

        # disable action masking --> will likely lead to invalid actions
        self.no_masking = False
        if "no_masking" in model_config["custom_model_config"]:
            self.no_masking = model_config["custom_model_config"]["no_masking"]

    def forward(self, input_dict, state, seq_lens):
        action_mask = input_dict['obs']["action_mask"]

        # Compute the concatenated obs and task embeddings.
        obs_embed = self.obs_internal_model({"obs": input_dict['obs']["observations"]})
        task_embed = self.task_internal_model({"obs": input_dict['obs']["tasks"]})
        self._output = torch.cat((obs_embed[0], task_embed[0]), dim=-1)

        # Compute the unmasked logits.
        logits, _ = self.internal_model({"obs": self._output})

        # If action masking is disabled, directly return the unmasked logits.
        if self.no_masking:
            return logits, state

        # Convert action_mask into a [0.0 || -inf]-type mask.
        inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)

        masked_logits = logits + inf_mask
        # if(torch.all(action_mask == 1.0) == False):
        #     print(masked_logits)

        # Return masked logits.
        return masked_logits, state

    def value_function(self):
        assert self._output is not None, "must call forward first!"
        return self.internal_model.value_function()

Here’s the error trace:

(RolloutWorker pid=98306)   File "/usr/local/lib/python3.6/dist-packages/ray/rllib/models/modelv2.py", line 243, in __call__
(RolloutWorker pid=98306)     res = self.forward(restored, state or [], seq_lens)
(RolloutWorker pid=98306)   File "/usr/local/lib/python3.6/dist-packages/ray/rllib/models/torch/recurrent_net.py", line 231, in forward
(RolloutWorker pid=98306)     return super().forward(input_dict, state, seq_lens)
(RolloutWorker pid=98306)   File "/usr/local/lib/python3.6/dist-packages/ray/rllib/models/torch/recurrent_net.py", line 85, in forward
(RolloutWorker pid=98306)     output, new_state = self.forward_rnn(inputs, state, seq_lens)
(RolloutWorker pid=98306)   File "/usr/local/lib/python3.6/dist-packages/ray/rllib/models/torch/recurrent_net.py", line 248, in forward_rnn
(RolloutWorker pid=98306)     torch.unsqueeze(state[1], 0)])

(RolloutWorker pid=98306)   File "/home/oaas/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
(RolloutWorker pid=98306)     result = self.forward(*input, **kwargs)
(RolloutWorker pid=98306)   File "/home/oaas/.local/lib/python3.6/site-packages/torch/nn/modules/rnn.py", line 579, in forward
(RolloutWorker pid=98306)     self.check_forward_args(input, hx, batch_sizes)
(RolloutWorker pid=98306)   File "/home/oaas/.local/lib/python3.6/site-packages/torch/nn/modules/rnn.py", line 530, in check_forward_args
(RolloutWorker pid=98306)     self.check_input(input, batch_sizes)
(RolloutWorker pid=98306)   File "/home/oaas/.local/lib/python3.6/site-packages/torch/nn/modules/rnn.py", line 180, in check_input
(RolloutWorker pid=98306)     self.input_size, input.size(-1)))
(RolloutWorker pid=98306) RuntimeError: input.size(-1) must be equal to input_size.
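
If I read the trace correctly, this RuntimeError comes from PyTorch's own input check inside nn.LSTM: the last dimension of whatever the LSTM wrapper feeds in does not match the input_size the LSTM was constructed with. The same check fails in this standalone snippet (the sizes here are made up purely to illustrate the error, they are not the ones from my model):

import torch
import torch.nn as nn

# An LSTM built for 256 features per timestep, fed 130 features instead.
lstm = nn.LSTM(input_size=256, hidden_size=64, batch_first=True)
x = torch.zeros(4, 1, 130)  # [batch, time, features]; 130 != 256
h0 = (torch.zeros(1, 4, 64), torch.zeros(1, 4, 64))
lstm(x, h0)  # raises RuntimeError: input.size(-1) must be equal to input_size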

Here’s the model config:

        "model" : {"custom_model":"task_embeded_ta_mask_lstm_model",
                    "fcnet_hiddens": [64, 64],
                    "use_lstm": True
        },
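
In case it's relevant: the custom model name in that dict is registered with RLlib's ModelCatalog before the trainer is built, roughly like this (reconstructed from memory rather than copied from my script):

from ray.rllib.models import ModelCatalog

# Maps the "custom_model" string in the config to the model class above.
ModelCatalog.register_custom_model(
    "task_embeded_ta_mask_lstm_model", BranchedLSTMModel)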

Any help is appreciated, as I am new to RLlib. I am not sure what is going wrong here. Does the use_lstm flag not support a branched model like this? If so, how should I go about making it work? Thanks a lot in advance for your help.

Hey @richielo, thanks for posting this issue. Could you provide a fully self-sufficient reproduction script that contains all the code necessary to reproduce the error you are seeing? That would help us a lot in debugging and fixing the problem. Thanks! :slight_smile:

Hello @sven1977, thank you for getting back to me. I will try to come back with an env placeholder to reproduce this, as I am prohibited from putting up my current environment. I think the key problem here is a shape issue that arises when the use_lstm flag is combined with action masking (I am using the approach suggested in RLlib's sample code for masking). Specifically, when the use_lstm flag is True, the following line breaks due to a shape mismatch; I believe I am no longer getting logits at that point:

masked_logits = logits + inf_mask

It seems like I have to override some other functions to perform masking when there is a recurrent layer, but it is unclear to me which functions I should override. If you could provide an example of action masking when use_lstm is True, that would be super helpful. Thanks a lot!
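
In the meantime, to make the question more concrete, here is a rough, untested sketch of what I imagine the manual alternative would look like: drop use_lstm and let the custom model own its nn.LSTM and recurrent state, so that the mask can be applied to the final logits after the recurrent layer. The single encoder below is only a stand-in for my two branches, and all names and sizes are assumptions rather than working code from my project:

import numpy as np
import torch
import torch.nn as nn

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.torch_utils import FLOAT_MIN  # ray.rllib.utils.torch_ops in older Ray


class MaskedManualLSTMModel(TorchModelV2, nn.Module):
    """Owns its own nn.LSTM and state instead of relying on use_lstm=True."""

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        orig_space = getattr(obs_space, "original_space", obs_space)
        self.cell_size = model_config.get("lstm_cell_size", 64)

        # Stand-in for the two branch networks: a single encoder over the
        # flattened "observations" component. In the real model this would be
        # the concatenation of the CNN and FC branch embeddings.
        in_size = int(np.prod(orig_space["observations"].shape))
        self.encoder = nn.Sequential(nn.Linear(in_size, 64), nn.ReLU())

        self.lstm = nn.LSTM(64, self.cell_size, batch_first=True)
        self.logits_branch = nn.Linear(self.cell_size, num_outputs)
        self.value_branch = nn.Linear(self.cell_size, 1)
        self._features = None

    def get_initial_state(self):
        # One (h, c) pair per sample; RLlib adds the batch dimension itself.
        return [torch.zeros(self.cell_size), torch.zeros(self.cell_size)]

    def forward(self, input_dict, state, seq_lens):
        action_mask = input_dict["obs"]["action_mask"]
        obs = input_dict["obs"]["observations"].float()
        feats = self.encoder(obs.reshape(obs.shape[0], -1))

        # Fold the padded, flattened batch [B*T, F] back into [B, T, F].
        max_seq_len = feats.shape[0] // seq_lens.shape[0]
        lstm_in = feats.reshape(-1, max_seq_len, feats.shape[-1])
        h_in = torch.unsqueeze(state[0], 0)
        c_in = torch.unsqueeze(state[1], 0)
        lstm_out, (h_out, c_out) = self.lstm(lstm_in, (h_in, c_in))
        self._features = lstm_out.reshape(-1, self.cell_size)

        # Apply the mask to the FINAL logits, i.e. after the recurrent layer.
        logits = self.logits_branch(self._features)
        inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
        return logits + inf_mask, [torch.squeeze(h_out, 0), torch.squeeze(c_out, 0)]

    def value_function(self):
        assert self._features is not None, "must call forward() first"
        return torch.reshape(self.value_branch(self._features), [-1])

I am not sure whether this is the intended pattern, though, or whether the LSTM wrapper has a supported hook for masking, hence the question.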