I am trying to create a custom RL model for PPO with the use_lstm flag set to True. The model has two backbone branches that take in different modalities of features; their outputs are concatenated and fed into a subsequent fully-connected layer. Everything works fine until I set the use_lstm flag to True. The model is as follows:
# Imports are not copied from my original file; TorchCNN / TorchFC are the usual
# aliases for RLlib's vision and fully-connected torch networks.
import numpy as np
import torch
import torch.nn as nn

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.models.torch.visionnet import VisionNetwork as TorchCNN
from ray.rllib.utils.torch_utils import FLOAT_MIN  # ray.rllib.utils.torch_ops on older Ray versions


class BranchedLSTMModel(TorchModelV2, nn.Module):
    """With a separate task head to process the task embedding."""

    def __init__(
        self,
        obs_space,
        action_space,
        num_outputs,
        model_config,
        name,
        **kwargs,
    ):
        orig_space = getattr(obs_space, "original_space", obs_space)
        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name, **kwargs
        )
        nn.Module.__init__(self)

        # Branch 1: CNN over the image-like observations.
        self.obs_internal_model = TorchCNN(
            obs_space=orig_space["observations"],
            action_space=action_space,
            num_outputs=None,
            model_config=model_config,
            name=name + "_obs_internal",
        )
        # Branch 2: fully-connected net over the task embedding.
        self.task_internal_model = TorchFC(
            obs_space=orig_space["tasks"],
            action_space=action_space,
            num_outputs=None,
            model_config=model_config,
            name=name + "_task_internal",
        )
        # Head: fully-connected net over the concatenated branch outputs.
        self.internal_model = TorchFC(
            np.zeros((self.obs_internal_model.num_outputs + self.task_internal_model.num_outputs, 1)),
            action_space,
            num_outputs,
            model_config,
            name + "_internal",
        )

        # Disable action masking --> will likely lead to invalid actions.
        self.no_masking = False
        if "no_masking" in model_config["custom_model_config"]:
            self.no_masking = model_config["custom_model_config"]["no_masking"]

    def forward(self, input_dict, state, seq_lens):
        action_mask = input_dict["obs"]["action_mask"]

        # Compute and concatenate the obs and task embeddings.
        obs_embed = self.obs_internal_model({"obs": input_dict["obs"]["observations"]})
        task_embed = self.task_internal_model({"obs": input_dict["obs"]["tasks"]})
        self._output = torch.cat((obs_embed[0], task_embed[0]), dim=-1)

        # Compute the unmasked logits.
        logits, _ = self.internal_model({"obs": self._output})

        # If action masking is disabled, directly return the unmasked logits.
        if self.no_masking:
            return logits, state

        # Convert action_mask into a [0.0 || -inf]-type mask.
        inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
        masked_logits = logits + inf_mask

        # Return masked logits.
        return masked_logits, state

    def value_function(self):
        assert self._output is not None, "must call forward first!"
        return self.internal_model.value_function()
Here’s the error traceback:
(RolloutWorker pid=98306) File "/usr/local/lib/python3.6/dist-packages/ray/rllib/models/modelv2.py", line 243, in __call__
(RolloutWorker pid=98306) res = self.forward(restored, state or [], seq_lens)
(RolloutWorker pid=98306) File "/usr/local/lib/python3.6/dist-packages/ray/rllib/models/torch/recurrent_net.py", line 231, in forward
(RolloutWorker pid=98306) return super().forward(input_dict, state, seq_lens)
(RolloutWorker pid=98306) File "/usr/local/lib/python3.6/dist-packages/ray/rllib/models/torch/recurrent_net.py", line 85, in forward
(RolloutWorker pid=98306) output, new_state = self.forward_rnn(inputs, state, seq_lens)
(RolloutWorker pid=98306) File "/usr/local/lib/python3.6/dist-packages/ray/rllib/models/torch/recurrent_net.py", line 248, in forward_rnn
(RolloutWorker pid=98306) torch.unsqueeze(state[1], 0)])
(RolloutWorker pid=98306) File "/home/oaas/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
(RolloutWorker pid=98306) result = self.forward(*input, **kwargs)
(RolloutWorker pid=98306) File "/home/oaas/.local/lib/python3.6/site-packages/torch/nn/modules/rnn.py", line 579, in forward
(RolloutWorker pid=98306) self.check_forward_args(input, hx, batch_sizes)
(RolloutWorker pid=98306) File "/home/oaas/.local/lib/python3.6/site-packages/torch/nn/modules/rnn.py", line 530, in check_forward_args
(RolloutWorker pid=98306) self.check_input(input, batch_sizes)
(RolloutWorker pid=98306) File "/home/oaas/.local/lib/python3.6/site-packages/torch/nn/modules/rnn.py", line 180, in check_input
(RolloutWorker pid=98306) self.input_size, input.size(-1)))
(RolloutWorker pid=98306) RuntimeError: input.size(-1) must be equal to input_size.
Here’s the relevant model config:
"model": {
    "custom_model": "task_embeded_ta_mask_lstm_model",
    "fcnet_hiddens": [64, 64],
    "use_lstm": True,
},
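For reference, the custom model name in the config is registered with the ModelCatalog roughly like this (a sketch; the exact registration line is not copied from my script):
from ray.rllib.models import ModelCatalog

# Register the custom model under the name referenced by "custom_model" above.
ModelCatalog.register_custom_model("task_embeded_ta_mask_lstm_model", BranchedLSTMModel)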
Any help is appreciated, as I am new to RLlib and not sure what is going wrong here. Does the use_lstm flag not support a branched model like this? If so, how should I go about making it work? Thanks a lot in advance for your help.