I want to export a TorchModelV2 (ray.rllib.models.torch.fcnet.FullyConnectedNetwork) to TorchScript. I’ve tried the two usual methods, torch.jit.script and torch.jit.trace, but each one fails with a different error.
torch.jit.script
(Pdb) torch.jit.script(ppo_model)
*** RuntimeError:
Unsupported operation: indexing tensor with unsupported index type 'str'. Only ints, slices, lists and tensors are supported:
  File "/home/eadlam/Environments/ray/lib/python3.7/site-packages/ray/rllib/models/torch/fcnet.py", line 115
    @override(TorchModelV2)
    def forward(self, input_dict, state, seq_lens):
        obs = input_dict["obs_flat"].float()
              ~~~~~~~~~~~~~ <--- HERE
        self._last_flat_in = obs.reshape(obs.shape[0], -1)
        self._features = self._hidden_layers(self._last_flat_in)
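My reading of this error: TorchScript treats every unannotated argument of forward as a Tensor, so input_dict["obs_flat"] is compiled as indexing a tensor with a string. A minimal sketch of the same pattern with a type annotation, using a hypothetical stand-in module (DictInputNet is not an RLlib class, and the layer sizes are made up), scripts without that error. Whether the rest of TorchModelV2 scripts cleanly is something I have not verified.

```python
from typing import Dict

import torch
import torch.nn as nn


class DictInputNet(nn.Module):
    """Hypothetical stand-in, not an RLlib class: one linear layer fed from a dict."""

    def __init__(self, obs_dim: int = 4, hidden: int = 2):
        super().__init__()
        self.fc = nn.Linear(obs_dim, hidden)

    # The annotation tells TorchScript that input_dict is a dict of tensors.
    # Without it, the argument would be assumed to be a Tensor.
    def forward(self, input_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
        obs = input_dict["obs_flat"].float()
        flat = obs.reshape(obs.shape[0], -1)
        return self.fc(flat)


scripted = torch.jit.script(DictInputNet())
out = scripted({"obs_flat": torch.rand(3, 4)})
out_shape = tuple(out.shape)
```

Since fcnet.py itself has no annotations, something like a thin wrapper module with an annotated forward would presumably be needed on the RLlib side.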
torch.jit.trace
(Pdb) data = torch.tensor(ppo_model.obs_space.sample())
(Pdb) input = {"obs": data, "obs_flat": data}
(Pdb) traced = torch.jit.trace(ppo_model, input)
*** RuntimeError: mat1 and mat2 shapes cannot be multiplied (260x1 and 260x256)
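The shapes in this error suggest that the sample has no batch dimension: obs_space.sample() appears to return a flat vector of length 260, and forward does obs.reshape(obs.shape[0], -1), which turns it into (260, 1) instead of (1, 260). A minimal sketch with a hypothetical stand-in module (TinyFC is not the RLlib model; it only copies the reshape and the 260 to 256 layer sizes from the error message) reproduces the failure and traces once a batch dimension is added. It does not cover the dict input or the (output, state) return value of the real model.

```python
import torch
import torch.nn as nn


class TinyFC(nn.Module):
    """Hypothetical stand-in: same reshape as fcnet.py, one linear layer."""

    def __init__(self, obs_dim: int = 260, hidden: int = 256):
        super().__init__()
        self.hidden = nn.Linear(obs_dim, hidden)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        # Dimension 0 is treated as the batch dimension.
        flat = obs.float().reshape(obs.shape[0], -1)
        return self.hidden(flat)


model = TinyFC()
sample = torch.rand(260)  # unbatched, like a single observation sample

# Without a batch dimension the reshape gives (260, 1), which cannot be
# multiplied with the 260 x 256 weight.
unbatched_shape = tuple(sample.reshape(sample.shape[0], -1).shape)
try:
    model(sample)
    failed = False
except RuntimeError:
    failed = True

# Adding a leading batch dimension gives (1, 260), and tracing goes through.
batched = sample.unsqueeze(0)
traced = torch.jit.trace(model, batched)
out_shape = tuple(traced(batched).shape)
```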
I’m looking for some ideas on how to make this work.