[RLlib] Exporting a TorchModelV2 to TorchScript

I want to export a TorchModelV2 (ray.rllib.models.torch.fcnet.FullyConnectedNetwork) to TorchScript. I’ve tried the two standard methods, torch.jit.script and torch.jit.trace, but both fail with a RuntimeError.

torch.jit.script

(Pdb) torch.jit.script(ppo_model)
*** RuntimeError: 
Unsupported operation: indexing tensor with unsupported index type 'str'. Only ints, slices, lists and tensors are supported:
  File "/home/eadlam/Environments/ray/lib/python3.7/site-packages/ray/rllib/models/torch/fcnet.py", line 115
    @override(TorchModelV2)
    def forward(self, input_dict, state, seq_lens):
        obs = input_dict["obs_flat"].float()
              ~~~~~~~~~~~~~ <--- HERE
        self._last_flat_in = obs.reshape(obs.shape[0], -1)
        self._features = self._hidden_layers(self._last_flat_in)
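
The message suggests TorchScript is treating input_dict as a Tensor: when compiling, it assumes any unannotated argument is a Tensor, so the string index is rejected. Below is a minimal sketch of that behavior using a hypothetical stand-in module. DictInputNet is not RLlib code, and its layer sizes are only illustrative. Annotating the argument as Dict[str, torch.Tensor] lets the same indexing compile. Whether the full TorchModelV2 class hierarchy scripts cleanly after such a change is not verified here.

```python
from typing import Dict

import torch
import torch.nn as nn


class DictInputNet(nn.Module):
    """Hypothetical stand-in for the RLlib model; not RLlib code."""

    def __init__(self, obs_dim: int = 260, hidden: int = 256):
        super().__init__()
        self.hidden_layers = nn.Linear(obs_dim, hidden)

    # Without the Dict[str, torch.Tensor] annotation, TorchScript assumes
    # input_dict is a Tensor and refuses the string index.
    def forward(self, input_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
        obs = input_dict["obs_flat"].float()
        flat = obs.reshape(obs.shape[0], -1)
        return self.hidden_layers(flat)


scripted = torch.jit.script(DictInputNet())
out = scripted({"obs_flat": torch.zeros(4, 260)})  # batch of 4 observations
```

For the real model this would probably mean scripting a small typed wrapper or subclass rather than the class as shipped; that part is an assumption.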

torch.jit.trace

(Pdb) data = torch.tensor(ppo_model.obs_space.sample())
(Pdb) input = {"obs": data, "obs_flat": data}
(Pdb) traced = torch.jit.trace(ppo_model, input)
*** RuntimeError: mat1 and mat2 shapes cannot be multiplied (260x1 and 260x256)
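
The shapes in this error are consistent with a missing batch dimension: obs_space.sample() returns a single observation of shape (260,), so obs.reshape(obs.shape[0], -1) produces (260, 1) instead of (1, 260). A minimal sketch of one possible workaround is to add a batch dimension and trace a thin wrapper that takes a plain tensor and builds the dict internally. StandInModel and TraceWrapper are hypothetical names. StandInModel only mimics the forward() shown in the traceback above and is not the RLlib class, so the real model may need further changes.

```python
import torch
import torch.nn as nn


class StandInModel(nn.Module):
    """Hypothetical stand-in mimicking the forward() from the traceback."""

    def __init__(self, obs_dim=260, hidden=256):
        super().__init__()
        self._hidden_layers = nn.Linear(obs_dim, hidden)

    def forward(self, input_dict, state, seq_lens):
        obs = input_dict["obs_flat"].float()
        flat = obs.reshape(obs.shape[0], -1)
        return self._hidden_layers(flat), state


class TraceWrapper(nn.Module):
    """Hypothetical wrapper exposing a tensor-in, tensor-out forward()."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, obs):
        # Build the dict inside the wrapper and drop the state output,
        # so the traced function only sees tensors at its boundary.
        out, _ = self.model({"obs": obs, "obs_flat": obs}, [], None)
        return out


sample = torch.zeros(260)      # stands in for obs_space.sample()
batched = sample.unsqueeze(0)  # shape (1, 260): adds the batch dimension
traced = torch.jit.trace(TraceWrapper(StandInModel()), batched)
traced_out = traced(batched)
```

The wrapper keeps the dict and the state list out of the traced signature, which sidesteps the question of how the tracer handles those argument types.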

I’m looking for ideas on how to get either method to work with this model.