Sure!
I am integrating Ray into my RL implementation, which until now was running in a single process. The paradigm is the standard one, i.e. collect experience → update policy → collect new experience → update policy, etc. What I want to do is to parallelize/distribute experience collection, and keep the policy update (i.e. compute loss, backprop, gradient step) in a single process.
My code consists of a Trainer class that takes as input the env, a loss function instance, and an exploration policy instance. The Trainer does not access the models directly, as the loss is computed by the loss function class and actions are chosen by the policy class. We could summarize it as follows:
class Trainer:
    def __init__(self, env, loss, policy_explore):
        self.env = env
        self.loss = loss
        self.policy_explore = policy_explore
        self.replay_buffer = ...  # replay buffer instance
        self.optimizer = ...      # torch optimizer over the q-network parameters

    def train(self, training_steps: int, rollout_len: int):
        obs = self.env.reset()
        for step in range(training_steps):
            # Collect a rollout with the exploration policy
            for _ in range(rollout_len):
                action = self.policy_explore.act(obs)
                next_obs, rew, done, info = self.env.step(action)
                self.replay_buffer.remember(obs, action, rew, next_obs, done)
                obs = next_obs if not done else self.env.reset()
            # Single policy update step
            batch = self.replay_buffer.sample()
            self.optimizer.zero_grad()
            loss = self.loss.compute(batch)
            loss.backward()
            self.optimizer.step()

class EpsilonGreedyPolicy:
    def __init__(self, q_net: torch.nn.Module, epsilon: float):
        self.q_net = q_net
        self.epsilon = epsilon

    def act(self, obs):
        # Return a random action with p = epsilon, or self.q_net(obs).argmax() with p = 1 - epsilon
        ...

class DQNLoss:
    def __init__(self, q_net: torch.nn.Module, discount: float):
        self.q_net = q_net
        self.q_net_target = deepcopy(q_net)
        self.discount = discount

    def compute(self, batch) -> torch.Tensor:
        q_values = self.q_net(batch['observation']).gather(1, batch['actions'])
        target_values = ...  # compute target values with q_net_target and discount
        return ((q_values - target_values) ** 2).mean()

if __name__ == '__main__':
    env = ...  # get environment
    q_net = MyModel()  # extends torch.nn.Module
    policy = EpsilonGreedyPolicy(q_net, 0.2)
    loss = DQNLoss(q_net, 0.99)
    trainer = Trainer(env, loss, policy)
    trainer.train()
Integrating Ray would likely entail creating some Ray actor that collects the rollout inside Trainer.train(); the main process would then collect all the results, add them to the replay buffer, and perform the training step.
I now realize that it would be much more convenient not to keep any reference to the models in the loss and policy classes, but to take them as arguments in the compute() and act() methods.
The result would be something like:
@ray.remote
class RolloutCollector:
    def __init__(self, policy, env):
        self.policy = policy
        self.env = env
        self._cur_obs = env.reset()

    def collect_rollout(self, q_net: torch.nn.Module, rollout_len: int):
        steps = []
        for _ in range(rollout_len):
            action = self.policy.act(q_net, self._cur_obs)
            next_obs, rew, done, info = self.env.step(action)
            steps.append((self._cur_obs, action, rew, next_obs, done))
            self._cur_obs = next_obs if not done else self.env.reset()
        return steps
and in the trainer I would instantiate several RolloutCollectors and call them by passing them the model in collect_rollout.remote(q_net, ...).
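To make that concrete, the trainer-side loop would look roughly like this (just a sketch: num_collectors, make_env and the replay buffer/optimizer setup are placeholders, and I assume the policy and loss have been refactored to take q_net as an argument instead of holding a reference to it):

collectors = [RolloutCollector.remote(policy, make_env()) for _ in range(num_collectors)]
for step in range(training_steps):
    # Ship the current weights to each worker and collect the rollouts in parallel
    futures = [c.collect_rollout.remote(q_net, rollout_len) for c in collectors]
    for rollout in ray.get(futures):
        for obs, action, rew, next_obs, done in rollout:
            replay_buffer.remember(obs, action, rew, next_obs, done)
    # Single-process update step, as before
    batch = replay_buffer.sample()
    optimizer.zero_grad()
    loss = loss_fn.compute(q_net, batch)  # compute() now takes q_net as an argument
    loss.backward()
    optimizer.step()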
I actually think I answered myself by writing you this reply, but let me know if you see any issues or if you have any suggestions. My main trouble was that a reference to the torch model was kept around in many classes, and I couldn’t find a proper way to always keep it up-to-date without explicitly passing it around to workers every time, or breaking my code structure. I think what I described above makes more sense.
Just a few questions:
- Can I call a remote method of an actor and pass custom objects as arguments? What are the serialization limits?
- What can the remote method return without incurring serialization errors?
- What if I just keep the models in the Ray object store? From any class that holds a reference to them I could just do ray.get(self.model_ref) and always access the up-to-date model (see the sketch after these questions for what I mean).
- Does this break backward() on the loss or the optimizer step()? In other words, can such an operation happen directly on an object that resides in the object store, or would I need to a) retrieve the model with ray.get(), b) optimize it, c) put the optimized model back into the object store?
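To make the last two questions more concrete, this is roughly what I have in mind (a sketch, with placeholder names):

# Driver: put the model in the object store once
q_net_ref = ray.put(q_net)

# Any class/worker that stores only the reference could then do
q_net_local = ray.get(q_net_ref)   # returns a copy of the model
action = q_net_local(obs).argmax()

# After an optimizer step on the driver, would I need to re-put the model,
# i.e. q_net_ref = ray.put(q_net), and hand the new reference to the workers?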