Best practice to share a torch model across actors

How severely does this issue affect your experience of using Ray?

  • Medium: It contributes to significant difficulty to complete my task, but I can work around it.

I have a scenario where the torch model is trained in the main process and then evaluated in n ray actors. How can I have the ray actors always use the up-to-date model without explicitly sending it to them after each training step (which would fit very badly with my current code base)?

I found two possible ways, but I’m unsure about them:

  • Use the object store by putting the model there with model = ray.put(MyModel()) and keeping a reference to it in the workers, so that at any time they can retrieve the up-to-date model by calling ray.get() (a sketch of what I mean follows this list).
    • Does backpropagation work correctly in this case? I.e., if from the main process I call ray.get(model).forward(inputs), then loss = loss_function(outputs, targets), loss.backward(), and finally optimizer.step(), where optimizer = Adam(ray.get(model).parameters())?
  • Otherwise, I could create a remote actor that holds the model and performs the training, and the evaluation workers can get the model from it and use it for their evaluation.
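To illustrate what I mean by the first option, here is a rough sketch of the mechanics only (Evaluator is just a made-up example class):

import ray
import torch

ray.init()

model_ref = ray.put(MyModel())  # put the torch model in the object store

@ray.remote
class Evaluator:
    def __init__(self, refs):
        # The ref is wrapped in a list so Ray hands over the ObjectRef itself
        # instead of automatically resolving it to the model.
        self.model_ref = refs[0]

    def evaluate(self, inputs):
        model = ray.get(self.model_ref)  # materialize the stored model
        with torch.no_grad():
            return model(inputs)

evaluator = Evaluator.remote([model_ref])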

Which one is better? Are there other ways?

Hi @fedetask!

Some questions for you:

  • Could you explain what your overall goal is?
  • Are you doing reinforcement learning?
  • Are you training your torch model on GPUs or CPUs?
  • Are you evaluating your model on GPUs or CPUs?
  • Do you need to do evaluation in parallel?

There are a few different technical approaches depending on your answers.

Sure!

I am integrating Ray into my RL implementation, which until now was running in a single process. The paradigm is the standard one, i.e. collect experience → update policy → collect new experience → update policy, etc. What I want to do is to parallelize/distribute experience collection, and keep the policy update (i.e. compute loss, backprop, gradient step) in a single process.

My code consists of a Trainer class that takes as input the env, a loss function instance, and an exploration policy instance. The Trainer does not access the models directly: the loss is computed by the loss function class, and actions are chosen by the policy class. It can be summarized as follows:

class Trainer:

    def __init__(self, env, loss, policy_explore):
        self.env = env
        self.loss = loss
        self.policy_explore = policy_explore

        self.replay_buffer = ...  # replay buffer instance
        self.optimizer = ...      # e.g. Adam over the q-network parameters

    def train(self, training_steps: int, rollout_len: int):
        obs = self.env.reset()
        for step in range(training_steps):
            # Collect a rollout with the exploration policy
            for _ in range(rollout_len):
                action = self.policy_explore.act(obs)
                next_obs, rew, done, info = self.env.step(action)
                self.replay_buffer.remember(obs, action, rew, next_obs, done)
                obs = next_obs if not done else self.env.reset()

            # One gradient step on a sampled batch
            batch = self.replay_buffer.sample()
            self.optimizer.zero_grad()
            loss = self.loss.compute(batch)
            loss.backward()
            self.optimizer.step()


class EpsilonGreedyPolicy:

    def __init__(self, q_net: torch.nn.Module, epsilon: float):
        self.q_net = q_net
        self.epsilon = epsilon

    def act(self, obs):
        # Return a random action with p = epsilon, or self.q_net(obs).argmax() with p = 1 - epsilon
        ...


class DQNLoss:

    def __init__(self, q_net: torch.nn.Module, discount: float):
        self.q_net = q_net
        self.q_net_target = deepcopy(q_net)
        self.discount = discount

    def compute(self, batch) -> torch.Tensor:
        q_values = self.q_net(batch['observations']).gather(1, batch['actions'])
        target_values = ...  # compute target values from q_net_target and discount
        return ((q_values - target_values) ** 2).mean()


if __name__ == '__main__':
    env = ...  # get environment
    q_net = MyModel()  # extends torch.nn.Module
    policy = EpsilonGreedyPolicy(q_net, 0.2)
    loss = DQNLoss(q_net, 0.99)

    trainer = Trainer(env, loss, policy)
    trainer.train(training_steps=1000, rollout_len=100)

Integrating Ray would likely entail creating some Ray actors that collect the rollouts inside Trainer.train(), while the main process gathers all the results, adds them to the replay buffer, and performs the training step.

I now realize that it would be much more convenient not to keep any reference to the models in the loss and policy classes, but to take them as arguments in the compute() and act() methods.
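For example, the policy and loss would stop storing the network and take it as an argument instead (just a rough sketch, the exact signatures are still up in the air):

class EpsilonGreedyPolicy:

    def __init__(self, epsilon: float):
        self.epsilon = epsilon

    def act(self, q_net: torch.nn.Module, obs):
        # Random action with p = epsilon, otherwise q_net(obs).argmax()
        ...


class DQNLoss:

    def __init__(self, discount: float):
        self.discount = discount

    def compute(self, q_net: torch.nn.Module, batch) -> torch.Tensor:
        # Same DQN loss as before; target-network handling omitted for brevity
        ...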

On the Ray side, the rollout collection would then be something like:

@ray.remote
class RolloutCollector:

    def __init__(self, policy, env):
        self.policy = policy
        self.env = env
        self._cur_obs = env.reset()

    def collect_rollout(self, q_net: torch.nn.Module, rollout_len: int):
        steps = []
        for i in range(rollout_len):
            action = self.policy.act(q_net, self._cur_obs)
            next_obs, rew, done, info = self.env.step(action)
            steps.append((self._cur_obs, action, rew, next_obs, done))
            self._cur_obs = next_obs if not done else self.env.reset()
        return steps

and in the Trainer I would instantiate several RolloutCollectors and call them, passing the model in collect_rollout.remote(q_net, ...), roughly as sketched below.
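In the Trainer, that could look roughly like this (make_env, num_workers, and loss_fn are placeholders for whatever my code actually uses):

collectors = [RolloutCollector.remote(policy_explore, make_env()) for _ in range(num_workers)]

for step in range(training_steps):
    # Put the current network in the object store once, so each task reuses the same copy
    q_net_ref = ray.put(q_net)
    rollouts = ray.get([c.collect_rollout.remote(q_net_ref, rollout_len) for c in collectors])

    for rollout in rollouts:
        for transition in rollout:
            replay_buffer.remember(*transition)

    batch = replay_buffer.sample()
    optimizer.zero_grad()
    loss = loss_fn.compute(q_net, batch)
    loss.backward()
    optimizer.step()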

I actually think I answered my own question by writing this reply, but let me know if you see any issues or have any suggestions. My main trouble was that a reference to the torch model was kept around in many classes, and I couldn’t find a proper way to always keep it up to date without explicitly passing it to the workers every time or breaking my code structure. I think what I described above makes more sense.

Just a few questions:

  1. Can I call a remote method of an actor by passing custom objects as arguments? What are the serialization limits?
  2. What can the remote method return without incurring serialization errors?
  3. What if I just keep the models in the Ray object store? From any class that has a reference to them I could just call ray.get() on the reference and always access the up-to-date model.
    • Does this break backward() on the loss or the optimizer step()? In other words, can such an operation happen directly on an object that resides in the object store, or would I need to a) retrieve the model with ray.get(), b) optimize it, c) put the optimized model back into the object store?

From my limited knowledge of RL your approach looks fine. Feel free to ask @avnishn / @arturn (or create a post in the RLlib category) for more in-depth RL answers.

  1. Can I call a remote method of an actor by passing custom objects as arguments? What are the serialization limits?

Yep! If you find any errors during object serialization you should file a GitHub issue. It’s pretty rare to get serialization errors: certain objects, like exception traceback objects, fail to serialize, but most things should just work. You should be able to serialize your torch model that way (and if not, just send over the state_dict tensors, e.g. as numpy arrays). Note that Arrow-compatible data incurs lower performance overhead than arbitrary Python objects due to zero-copy serialization.
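For example (EvalWorker and set_weights are placeholder names, not a prescribed API), an actor can keep a local copy of the architecture and receive only the weights:

import ray
import torch

@ray.remote
class EvalWorker:
    def __init__(self, model: torch.nn.Module):
        # The architecture is serialized once, at actor creation.
        self.model = model

    def set_weights(self, state_dict):
        # Afterwards only the (CPU) tensors of the state_dict travel over the wire.
        self.model.load_state_dict(state_dict)

    def evaluate(self, inputs):
        with torch.no_grad():
            return self.model(inputs)

worker = EvalWorker.remote(MyModel())
worker.set_weights.remote({k: v.cpu() for k, v in q_net.state_dict().items()})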

In terms of performance, how big is your model? If it’s large, you may consider using collective communication primitives (such as Gloo for CPU-to-CPU or NCCL for GPU-to-GPU communication). This is an advanced use case and off the beaten path, so only go there if you need the performance.
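As a very rough sketch of that direction (the group setup here is an assumption, and the exact API should be checked against the Ray collective communication docs), Ray ships a ray.util.collective module that can broadcast tensors between actors:

import ray
import torch
import ray.util.collective as col

@ray.remote
class Worker:
    def __init__(self, rank: int, world_size: int):
        self.weights = torch.zeros(1024, 1024)
        # Join a collective group; "gloo" for CPU tensors, "nccl" for GPU tensors.
        col.init_collective_group(world_size, rank, backend="gloo", group_name="weights")

    def sync_weights(self):
        # Rank 0 holds the up-to-date weights; the others receive them in place.
        col.broadcast(self.weights, src_rank=0, group_name="weights")

workers = [Worker.remote(rank=i, world_size=4) for i in range(4)]
ray.get([w.sync_weights.remote() for w in workers])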

  2. What can the remote method return without incurring serialization errors?

Return values follow the same serialization rules as task (actor method) arguments.

  3. What if I just keep the models in the Ray object store? From any class that has a reference to them I could just call ray.get() on the reference and always access the up-to-date model.
    • Does this break backward() on the loss or the optimizer step()? In other words, can such an operation happen directly on an object that resides in the object store, or would I need to a) retrieve the model with ray.get(), b) optimize it, c) put the optimized model back into the object store?

The object store is immutable, so without nasty hacks this won’t work. It seems simplest to me to 1) optimize the model in an actor, then 2) ship the optimized model to the various rollout workers, then 3) repeat 1-2 until convergence. Either the rollout workers themselves request the latest model, or some coordinator actor (which could be the trainer) sends the latest model to each rollout worker; a rough sketch follows.
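As an illustration only (RolloutWorker, set_weights, and the loop bounds are assumptions, not a prescribed API), the coordinator-push variant could look roughly like this:

import ray
import torch

@ray.remote
class RolloutWorker:
    def __init__(self, model: torch.nn.Module):
        self.model = model  # local copy of the architecture

    def set_weights(self, state_dict):
        self.model.load_state_dict(state_dict)

    def rollout(self, rollout_len: int):
        ...  # step the environment with self.model and return the transitions

workers = [RolloutWorker.remote(MyModel()) for _ in range(4)]
model = MyModel()
optimizer = torch.optim.Adam(model.parameters())

for iteration in range(num_iterations):
    # 1) The driver (or a training actor) optimizes the model, then
    # 2) the fresh weights are put in the object store once and shipped to every worker.
    weights_ref = ray.put({k: v.cpu() for k, v in model.state_dict().items()})
    ray.get([w.set_weights.remote(weights_ref) for w in workers])

    # 3) Collect rollouts in parallel, train on them, and repeat.
    rollouts = ray.get([w.rollout.remote(100) for w in workers])
    ...  # add to replay buffer, compute loss, loss.backward(), optimizer.step()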

Hi @cade, hi @fedetask,

Don’t have much to add here. Keep in mind that (de)serializing models is an expensive process. If you are writing your own RL codebase, you’ll sometimes want to do multiple rollouts under the same policy, and synchronizing weights once for several rollouts might just be quicker; a small variation on the earlier sketch is shown below.
Here is our implementation of what you are solving.
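To make the weight-reuse point concrete (a small variation of the loop sketched above; rollouts_per_sync is a made-up knob):

for iteration in range(num_iterations):
    # Pay the weight-sync cost once...
    weights_ref = ray.put({k: v.cpu() for k, v in model.state_dict().items()})
    ray.get([w.set_weights.remote(weights_ref) for w in workers])

    # ...then collect several rollouts under the same policy before syncing again.
    for _ in range(rollouts_per_sync):
        rollouts = ray.get([w.rollout.remote(rollout_len) for w in workers])
        ...  # train on the collected data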
