I am using a distributed setup to train a DQN. When I use the following replay buffer, the code works fine:
    import random
    import time

    import numpy as np
    import ray

    @ray.remote
    class ReplayBuffer:
        def __init__(self, capacity, batch_size=128):
            self.capacity = capacity
            self.buffer = []
            self.batch_size = batch_size

        def push(self, data):
            # data is a list of experience tuples collected by a worker
            for experience in data:
                self.buffer.append(experience)

        def sample(self):
            return random.sample(self.buffer, self.batch_size)

        def __len__(self):
            return len(self.buffer)
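For context, the buffer actor is created once and its handle is shared with every worker, roughly like this (a simplified sketch; the capacity and the Worker class here are illustrative, not my exact code):

    # Sketch of how the buffer actor is created and shared (illustrative
    # names; "Worker" stands in for my actual worker actor class).
    ray.init()
    replay_buffer = ReplayBuffer.remote(100_000)
    # every worker holds the same actor handle, so all pushes go to one actor
    workers = [Worker.remote(replay_buffer) for _ in range(4)]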
However, when I try to use the following prioritised buffer, the code slows down to the point of being unusable:
    @ray.remote
    class PrioritizedReplayBuffer:
        def __init__(self, capacity, alpha=0.6, beta=0.4,
                     beta_increment_per_sampling=0.001, batch_size=128):
            self.capacity = capacity
            self.alpha = alpha
            self.beta = beta
            self.beta_increment_per_sampling = beta_increment_per_sampling
            self.buffer = []
            self.pos = 0
            self.priorities = []
            self.batch_size = batch_size

        def push(self, data):
            for experience in data:
                # new experiences get the current maximum priority so they
                # are sampled at least once
                max_priority = max(self.priorities) if self.buffer else 1.0
                if len(self.buffer) < self.capacity:
                    self.buffer.append(experience)
                    self.priorities.append(max_priority)
                else:
                    self.buffer[self.pos] = experience
                    self.priorities[self.pos] = max_priority
                # advance the write pointer whether we appended or overwrote,
                # so priorities[:self.pos] stays in sync while the buffer fills
                self.pos = (self.pos + 1) % self.capacity

        def sample(self):
            start = time.time()
            N = len(self.buffer)
            if N == self.capacity:
                priorities = np.array(self.priorities)
            else:
                priorities = np.array(self.priorities[:self.pos])
            self.beta = min(1.0, self.beta + self.beta_increment_per_sampling)
            sampling_probabilities = priorities ** self.alpha
            sampling_probabilities = sampling_probabilities / sampling_probabilities.sum()
            indices = random.choices(range(N), k=self.batch_size, weights=sampling_probabilities)
            experiences = [self.buffer[idx] for idx in indices]
            # importance-sampling weights w_i = (N * P(i)) ** -beta,
            # normalised by the maximum weight
            weights = (N * sampling_probabilities[indices]) ** -self.beta
            weights = weights / weights.max()
            end = time.time()
            print(f"sampling took {(end - start) / 60} minutes")
            return experiences, np.array(indices), weights

        def update_priorities(self, indices, priorities):
            for idx, priority in zip(indices, priorities):
                self.priorities[idx] = priority

        def __len__(self):
            return len(self.buffer)
I’ve tested the prioritised buffer without Ray and don’t notice any performance issues; it runs as I would expect. What I don’t understand is why it slows down so much in a distributed setting.
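For example, this is roughly how I timed it locally (a minimal sketch using the same class without the @ray.remote decorator; the dummy transition and the sizes are illustrative):

    # Local timing sketch (same class body, minus the @ray.remote decorator,
    # so it can be instantiated directly). Dummy data, illustrative sizes.
    buf = PrioritizedReplayBuffer(capacity=10_000)
    dummy = (0, 0, 0.0, 0, False)  # (state, action, reward, next_state, done)
    for _ in range(100):
        buf.push([dummy] * 100)    # fill the buffer in batches of 100
    start = time.time()
    buf.sample()
    print(f"one local sample took {time.time() - start:.4f} s")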
The method that runs my distributed workers looks like this:
    def run(self):
        state = self.env.reset()
        episode_reward = 0
        episode = 0
        ep_length = 0
        grad_steps = 0
        intermediate_memory = []  # this is what we push to the buffer in one call
        while grad_steps < self.num_grad_steps:
            ep_length += 1
            action = self.act(state)
            next_state, reward, done, _ = self.env.step(action)
            intermediate_memory.append((state, action, reward, next_state, done))
            if len(intermediate_memory) >= self.push_size:
                self.replay_buffer.push.remote(intermediate_memory)
                intermediate_memory = []
                self.sync_with_param_server()
                grad_steps = ray.get(self.param_server.return_grad_steps.remote())
                time.sleep(60 * 5)
            episode_reward += reward
            if done:
                # print results locally
                # print(f"Episode {episode}: {episode_reward}")
                # print_status(self.env, time_step)
                # prepare new rollout
                episode += 1
                episode_reward = 0
                ep_length = 0
                next_state = self.env.reset()
            state = next_state
In the learner, I interact with the buffer when sampling a batch and updating priorities in the experience replay method, roughly as in the sketch below. I noticed that if I comment out the line where the workers push to the buffer, the issue goes away (though obviously this is not a fix). I thought the problem might be multiple processes trying to interact with the buffer at once, so I tried adding some locks, but this didn't help.
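For completeness, the learner's interaction with the buffer looks roughly like this (a simplified sketch; learn_from_batch is a placeholder for my actual update step, not its real name):

    # Sketch of the learner side; learn_from_batch is a placeholder that
    # returns new per-sample priorities (e.g. absolute TD errors).
    def experience_replay(self):
        experiences, indices, weights = ray.get(self.replay_buffer.sample.remote())
        new_priorities = self.learn_from_batch(experiences, weights)
        self.replay_buffer.update_priorities.remote(indices, new_priorities)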
Any help would be greatly appreciated!