[RLlib] Exporting a PyTorch policy for TorchScript

Hello all,

I’m trying to use Ray/RLlib to train a policy, and I’m running into trouble when it comes time to export it. I’m using Ray 1.0.1.post1. My workflow uses the Python API and a custom environment, and iteratively calls the train() method on a Trainer object. After training completes, the model is buried somewhere inside the trainer object. Here’s what I’ve tried in order to extract it:

  • If I call trainer.export_model(ExportFormat.MODEL, args.export_dir) I get a NotImplementedError:

File "/home/amcleod/.local/lib/python3.6/site-packages/ray/rllib/policy/torch_policy.py", line 590, in export_model
raise NotImplementedError

  • I can get a policy object with trainer.get_policy(), but this object can’t be pickled.

Ideally, I would like to end up with a torch module containing only the trained policy, free of the Ray class hierarchy. How would I go about doing this?
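(For anyone hitting the same wall: one workaround that avoids export_model entirely is to grab `trainer.get_policy().model` — in RLlib's TorchPolicy that attribute is a plain nn.Module — and save its state_dict, which contains only tensors and therefore pickles fine. A minimal sketch; the `nn.Sequential` below is a made-up stand-in for whatever network your policy actually uses:)

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in for trainer.get_policy().model: in RLlib, TorchPolicy's `model`
# attribute is the underlying TorchModelV2, which is a plain nn.Module.
# (This architecture is invented for illustration only.)
policy_model = nn.Sequential(nn.Linear(4, 32), nn.Tanh(), nn.Linear(32, 2))

# The state_dict holds only tensors -- no Ray classes -- so it serializes fine.
path = os.path.join(tempfile.mkdtemp(), "policy_weights.pt")
torch.save(policy_model.state_dict(), path)

# Later, even in a process that never imports ray: rebuild the same
# architecture and load the weights back in.
fresh = nn.Sequential(nn.Linear(4, 32), nn.Tanh(), nn.Linear(32, 2))
fresh.load_state_dict(torch.load(path))

obs = torch.randn(1, 4)
assert torch.equal(policy_model(obs), fresh(obs))
```

The catch is that you must be able to re-instantiate the same architecture outside Ray; this gives you the weights, not a self-describing TorchScript file.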


Did you ever figure out how to do this?

Implementing this missing method for TorchPolicy is harder than I had hoped.

The seemingly simple solution is to add this to the TorchPolicy class:

    def export_model(self, export_dir: str) -> None:
        """Exports the Policy's Model to a local directory for serving.

        Creates a TorchScript model and saves it.

        Args:
            export_dir (str): Local writable directory or filename.
        """
        dummy_inputs = self._lazy_tensor_dict(self._dummy_batch.data)
        # Provide dummy state inputs if not an RNN (torch cannot jit with
        # an empty list of internal states).
        if "state_in_0" not in dummy_inputs:
            dummy_inputs["state_in_0"] = dummy_inputs["seq_lens"] = \
                np.array([1.0])
        # Materialize into a plain dict for tracing.
        dummy_inputs = {k: dummy_inputs[k] for k in dummy_inputs.keys()}
        traced = torch.jit.trace(self.model, dummy_inputs)
        if os.path.isfile(export_dir):
            file_name = export_dir
        else:
            file_name = os.path.join(export_dir, "model.pt")
        traced.save(file_name)

However, torch jit requires that all return values of the nn.Module be tensors, which is not the case for our TorchModelV2: if the model is not an RNN, it returns an empty list of internal states as its second return value. Removing this for non-RNNs would break our entire ModelV2 API. :confused:

Hmm, I actually did find a way, but it requires you to pass in a fake state_in_0 tensor (not []!) and a fake seq_lens tensor (not None!). I guess this is better than not having this work at all.
We may have to change the Model API (or provide an alternative one) at some point to make this work properly.
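(To illustrate the calling convention this implies on the consumer side — with a toy network, since the traced RLlib model's exact signature here is an assumption — whoever loads the exported file passes dummy state_in_0 and seq_lens tensors even for a stateless policy:)

```python
import os
import tempfile

import torch
import torch.nn as nn


class ToyPolicy(nn.Module):
    """Toy stateless policy: state_in_0/seq_lens are accepted but unused."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, obs, state_in_0, seq_lens):
        return self.fc(obs)


# Trace with fake state_in_0 / seq_lens tensors (not [] and not None).
dummy = (torch.randn(1, 4), torch.tensor([1.0]), torch.tensor([1.0]))
traced = torch.jit.trace(ToyPolicy(), dummy)
path = os.path.join(tempfile.mkdtemp(), "model.pt")
traced.save(path)

# A Ray-free consumer loads the file and must call it the same way.
restored = torch.jit.load(path)
logits = restored(torch.randn(1, 4), torch.tensor([1.0]), torch.tensor([1.0]))
```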
I’ll PR. …
