How severely does this issue affect your experience of using Ray?
- High: It blocks me from completing my task.
I am relatively new to RLlib. I have restored a checkpoint and saved my (non-RNN) custom DQN model for inference using export_model. This produces a TorchScript model (model.pt). One subtlety is that the exported model expects actual tensor inputs for (state, seq_lens), so passing ([], None) throws an error. However, once appropriate dummy inputs are provided, the forward pass runs.
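For reference, the export step looked roughly like this (the checkpoint and export paths are placeholders, and the exact restore API may differ between Ray versions):

```python
from ray.rllib.algorithms.algorithm import Algorithm

# Restore the trained DQN algorithm from a checkpoint (path is a placeholder).
algo = Algorithm.from_checkpoint("/path/to/checkpoint")

# Export the default policy's torch model; this writes model.pt (TorchScript)
# into the given directory.
algo.get_policy().export_model("/tmp/exported_dqn_model")
```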
The problem is that the code:
```python
import numpy as np
import torch

# Load the TorchScript model produced by export_model.
model = torch.jit.load(model_path)
model.eval()

# Dummy observation batch matching the model's input shape.
obs = torch.rand((1, 128, 111))
obs_dict = {"obs": obs}

# Dummy (non-empty) state and seq_lens; the scripted forward rejects ([], None).
state = [torch.from_numpy(np.array([1.0]))]
seq_lens = torch.from_numpy(np.ones((1,)))

model(obs_dict, state, seq_lens)
```
returns an output of the wrong shape, in this case 256 rather than the model's num_outputs. Can anyone point me towards a solution?
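Concretely, I am unpacking the result like this (the two return values follow the Tuple[Tensor, List[Tensor]] signature shown in model.code below):

```python
# The scripted forward returns (output, state_out).
out, state_out = model(obs_dict, state, seq_lens)
print(out.shape)  # last dimension is 256 here, not the expected num_outputs
```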
If it helps, here are the final layers in print(model):
```
(to_logits): RecursiveScriptModule(
  original_name=Sequential
  (0): RecursiveScriptModule(original_name=LayerNorm)
  (1): RecursiveScriptModule(original_name=Linear)
)
(value_branch): RecursiveScriptModule(
  original_name=Sequential
  (0): RecursiveScriptModule(original_name=LayerNorm)
  (1): RecursiveScriptModule(original_name=Linear)
)
(advantage_module): RecursiveScriptModule(
  original_name=Sequential
  (dueling_A_0): RecursiveScriptModule(
    original_name=SlimFC
    (_model): RecursiveScriptModule(
      original_name=Sequential
      (0): RecursiveScriptModule(original_name=Linear)
      (1): RecursiveScriptModule(original_name=ReLU)
    )
  )
  (A): RecursiveScriptModule(
    original_name=SlimFC
    (_model): RecursiveScriptModule(
      original_name=Sequential
      (0): RecursiveScriptModule(original_name=Linear)
    )
  )
)
```
and this is the output of print(model.code):
```python
def forward(self,
    input_dict: Dict[str, Tensor],
    state: List[Tensor],
    seq_lens: Tensor) -> Tuple[Tensor, List[Tensor]]:
  to_logits = self.to_logits
  hidden_layers = self.hidden_layers
  obs = input_dict["obs"]
  _0, = state
  _1 = (to_logits).forward((hidden_layers).forward(obs, ), )
  return (_1, [_0])
```
Since this model is meant for inference, shouldn't the output of
model(obs_dict, state, seq_lens)
have shape num_outputs, where (I assume) the argmax corresponds exactly to policy.compute_single_action(obs)[0]?
In other words, how can I get this single action output from a forward pass of the saved TorchScript model?
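To make the goal concrete, this is roughly the equivalence I am hoping for; it assumes the forward output can be treated as per-action values (which may be exactly the wrong assumption here), ignores exact batching/dtype details, and `policy` is the policy object restored from the same checkpoint:

```python
import torch

# Deterministic action from the restored RLlib policy, for comparison.
action_from_policy = policy.compute_single_action(obs, explore=False)[0]

# What I hoped the exported TorchScript model would give me: per-action
# values whose argmax is the greedy action.
out, _ = model(obs_dict, state, seq_lens)
action_from_torchscript = torch.argmax(out, dim=-1)

# Ideally these two would match.
```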