Loaded TorchScript models have wrong output shape

How severely does this issue affect your experience of using Ray?

  • High: It blocks me from completing my task.

I am relatively new to RLlib. I have restored a checkpoint and saved my (non-RNN) custom DQN model for inference using export_model. This generates a TorchScript model (model.pt). One subtlety is that this model expects tensor inputs for (state, seq_lens), so ([ ], None) will throw an error. However, if appropriate dummy inputs are provided, we can do a forward pass.
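For context, the export step looked roughly like this (config, env name, and paths below are placeholders rather than my actual setup):

from ray.rllib.agents.dqn import DQNTrainer

# Placeholder config/env/paths for illustration only.
trainer = DQNTrainer(config={"framework": "torch"}, env="CartPole-v0")
trainer.restore("/path/to/checkpoint")
trainer.get_policy().export_model("/path/to/export_dir")  # writes model.pt into this directory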

The problem is that the code:

import numpy as np
import torch

model = torch.jit.load(model_path)
model.eval()

# Dummy observation batch matching the model's expected input shape
obs = torch.rand((1, 128, 111))
obs_dict = {"obs": obs}

# Dummy state and seq_lens so the traced forward() does not error out
state = np.array([1.0])
state = [torch.from_numpy(state)]
seq_lens = np.ndarray((1,))
seq_lens = torch.from_numpy(seq_lens)

model(obs_dict, state, seq_lens)

produces an output with the wrong shape, in this case 256 rather than the expected model output shape. Can anyone point me towards a solution?

If it helps, here are the final layers in print(model):

(to_logits): RecursiveScriptModule(
  original_name=Sequential
  (0): RecursiveScriptModule(original_name=LayerNorm)
  (1): RecursiveScriptModule(original_name=Linear)
)
(value_branch): RecursiveScriptModule(
  original_name=Sequential
  (0): RecursiveScriptModule(original_name=LayerNorm)
  (1): RecursiveScriptModule(original_name=Linear)
)
(advantage_module): RecursiveScriptModule(
  original_name=Sequential
  (dueling_A_0): RecursiveScriptModule(
    original_name=SlimFC
    (_model): RecursiveScriptModule(
      original_name=Sequential
      (0): RecursiveScriptModule(original_name=Linear)
      (1): RecursiveScriptModule(original_name=ReLU)
    )
  )
  (A): RecursiveScriptModule(
    original_name=SlimFC
    (_model): RecursiveScriptModule(
      original_name=Sequential
      (0): RecursiveScriptModule(original_name=Linear)
    )
  )
)

and this is the output of print(model.code):

def forward(self,
    input_dict: Dict[str, Tensor],
    state: List[Tensor],
    seq_lens: Tensor) -> Tuple[Tensor, List[Tensor]]:
  to_logits = self.to_logits
  hidden_layers = self.hidden_layers
  obs = input_dict["obs"]
  _0, = state
  _1 = (to_logits).forward((hidden_layers).forward(obs, ), )
  return (_1, [_0])

Since this exported model is meant for inference, shouldn't the output of

model(obs_dict, state, seq_lens)

be of shape num_outputs, where (I would guess) the argmax corresponds exactly to policy.compute_single_action(obs)[0]?

In other words, how can I get this single-action output through a forward pass of a saved TorchScript model?
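Roughly what I expected (just a sketch, reusing model, obs_dict, state, and seq_lens from the snippet above; num_outputs here stands for the action-space size):

# Expectation: the traced forward returns per-action values of shape [1, num_outputs],
# so the greedy action should match the policy's deterministic action.
logits, state_out = model(obs_dict, state, seq_lens)
action = torch.argmax(logits, dim=-1).item()
# expected to equal policy.compute_single_action(obs, explore=False)[0]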

Uh oh, that’s not great. @gjoliver is working on improving the inference path. Can someone on the RLlib side take a look?

BTW @GallantWood, do you think you could provide a copy of the actual model somewhere? I.e., upload the contents of model_path somewhere we can download them?

I suspect what you get are the raw logits.
RLlib policies have built-in action distribution functions that sample an action from the distribution constructed from these raw logit outputs.
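For a discrete action space, the step on top of the raw logits looks roughly like this (just a sketch; logits below is a placeholder standing in for the model output):

import torch

# Placeholder logits of shape [batch, num_actions] standing in for the model output.
logits = torch.randn(1, 2)
dist = torch.distributions.Categorical(logits=logits)
sampled_action = dist.sample()                  # stochastic, like explore=True
greedy_action = torch.argmax(logits, dim=-1)    # deterministic, like explore=False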
Can you open a new GitHub issue with a stripped-down script demonstrating this problem? With the actual model config, we can help make sure.

BTW, we will be providing utils to make it much easier to run a checkpointed policy; having to come up with dummy state and seq_lens is painful.
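In the meantime, a rough sketch of a workaround is to go through the restored trainer/policy instead of the exported TorchScript file (config, env, and checkpoint path below are placeholders):

import gym
import ray
from ray.rllib.agents.dqn import DQNTrainer

ray.init()
# Rebuild the trainer with the same config that was used for training, then restore it.
trainer = DQNTrainer(config={"framework": "torch"}, env="CartPole-v0")
trainer.restore("/path/to/checkpoint")

env = gym.make("CartPole-v0")
obs = env.reset()
# The policy handles preprocessing, dummy state/seq_lens, and action sampling internally.
action = trainer.compute_single_action(obs, explore=False)
print(action)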

Thank you @rliaw @gjoliver
I have built a simple reproducible example using CartPole. I will open a GitHub issue with the code.

@gjoliver @rliaw
Here is a link to the issue with a complete script demonstrating the problem:


ok cool, let’s move our discussion there.