My prioritised replay buffer slows down my code massively

I am using a distributed setup to train a DQN. With the following replay buffer, the code works perfectly fine:

import random

import ray


@ray.remote
class ReplayBuffer:
    def __init__(self, capacity, batch_size=128):
        self.capacity = capacity
        self.buffer = []
        self.batch_size = batch_size

    def push(self, data):
        # data is a list of transitions collected by a worker
        for experience in data:
            self.buffer.append(experience)

    def sample(self):
        # uniform random sampling without replacement
        return random.sample(self.buffer, self.batch_size)

    def __len__(self):
        return len(self.buffer)

However, when I use the following prioritised buffer, the code slows down to the point where it is unusable:

import random
import time

import numpy as np
import ray


@ray.remote
class PrioritizedReplayBuffer:
    def __init__(self, capacity, alpha=0.6, beta=0.4, beta_increment_per_sampling=0.001, batch_size=128):
        self.capacity = capacity
        self.alpha = alpha
        self.beta = beta
        self.beta_increment_per_sampling = beta_increment_per_sampling
        self.buffer = []
        self.pos = 0
        self.priorities = []
        self.batch_size = batch_size

    def push(self, data):
        for experience in data:
            # give new experiences the current max priority so they get sampled at least once
            max_priority = max(self.priorities) if self.buffer else 1.0

            if len(self.buffer) < self.capacity:
                self.buffer.append(experience)
                self.priorities.append(max_priority)
            else:
                self.buffer[self.pos] = experience
                self.priorities[self.pos] = max_priority

            self.pos = (self.pos + 1) % self.capacity

    def sample(self):
        start = time.time()
        N = len(self.buffer)
        if N == self.capacity:
            priorities = np.array(self.priorities)
        else:
            priorities = np.array(self.priorities[:self.pos])

        # anneal beta towards 1 and convert priorities into sampling probabilities
        self.beta = min(1.0, self.beta + self.beta_increment_per_sampling)
        sampling_probabilities = priorities ** self.alpha
        sampling_probabilities = sampling_probabilities / sampling_probabilities.sum()
        indices = random.choices(range(N), k=self.batch_size, weights=sampling_probabilities)

        experiences = [self.buffer[idx] for idx in indices]
        # importance-sampling weights, normalised by the largest weight in the batch
        weights = np.array([(self.capacity * priorities[i]) ** -self.beta for i in indices])
        weights = weights / weights.max()
        end = time.time()
        print(f"sampling took {(end - start) / 60} minutes")
        return experiences, np.array(indices), weights

    def update_priorities(self, indices, priorities):
        # called by the learner with new priorities for the last sampled batch
        for idx, priority in zip(indices, priorities):
            self.priorities[idx] = priority

    def __len__(self):
        return len(self.buffer)

I’ve tested the prioritised buffer code without Ray and don’t notice any performance issues; it runs as I would expect. I don’t know why it would slow down like this in a distributed setting.
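Roughly, the kind of local check I mean is this (with the @ray.remote decorator removed so the class can be instantiated directly; the capacity and push sizes are just example values):

buf = PrioritizedReplayBuffer(capacity=10_000)
dummy = (np.zeros(4), 0, 0.0, np.zeros(4), False)
for _ in range(100):
    buf.push([dummy] * 100)  # fill the buffer in chunks, like the workers do
start = time.time()
buf.sample()
print(f"local sample took {time.time() - start:.4f} seconds")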

The method that runs my distributed workers looks like this:

    def run(self):
        state = self.env.reset()
        episode_reward = 0
        episode = 0
        ep_length = 0
        grad_steps = 0
        intermediate_memory = []  # this is what we will push to the buffer at once

        while grad_steps < self.num_grad_steps:
            ep_length += 1
            action = self.act(state)
            next_state, reward, done, _ = self.env.step(action)
            intermediate_memory.append((state, action, reward, next_state, done))
            if len(intermediate_memory) >= self.push_size:
                self.replay_buffer.push.remote(intermediate_memory)
                intermediate_memory = []
                self.sync_with_param_server()
                grad_steps = ray.get(self.param_server.return_grad_steps.remote())
                time.sleep(60 * 5)
            episode_reward += reward

            if done:
                # print results locally
                # print(f"Episode {episode}: {episode_reward}")
                # print_status(self.env, time_step)

                # prepare new rollout
                episode += 1
                episode_reward = 0
                ep_length = 0
                next_state = self.env.reset()

            state = next_state

In the learner, I interact with the buffer when sampling a batch and updating priorities in the experience replay method. I noticed that if I comment out the line where the workers push to the buffer, the issue goes away (though obviously this is not a fix). I thought the problem might be multiple processes trying to interact with the buffer at once, but adding locks didn’t help.
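Roughly, the learner-side interaction looks like this (simplified sketch; learn_on_batch is a placeholder name, not my exact code):

# simplified learner step: sample a batch, train on it, then send back updated priorities
batch, indices, weights = ray.get(self.replay_buffer.sample.remote())
new_priorities = self.learn_on_batch(batch, weights)  # placeholder for my actual update step
self.replay_buffer.update_priorities.remote(indices, new_priorities)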

Any help would be greatly appreciated!

Hey @direland3,
Can you compare the run-time of an RLlib-implemented DQN with the two replay buffers to see if the problem persists there?
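Something along these lines should set up the prioritised variant in recent Ray releases (treat this as a sketch; the exact config keys depend on the RLlib version):

from ray.rllib.algorithms.dqn import DQNConfig

config = (
    DQNConfig()
    .environment("CartPole-v1")  # substitute your own environment
    .training(
        replay_buffer_config={
            "type": "MultiAgentPrioritizedReplayBuffer",
            "capacity": 100_000,
            "prioritized_replay_alpha": 0.6,
            "prioritized_replay_beta": 0.4,
        }
    )
)
algo = config.build()
for _ in range(5):
    result = algo.train()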

@kourosh,

He resolved his issue in the thread linked below.

I think the issue ended up being race conditions due to remote calls in a loop. The fix was to use ray.get as a synchronization barrier.
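In other words, instead of firing remote calls in a loop the worker blocks until the actor has processed each push; a minimal sketch of that kind of change, applied to the worker's push call:

# fire-and-forget: remote push calls can pile up in the actor's queue
self.replay_buffer.push.remote(intermediate_memory)

# with ray.get as a barrier: the worker waits until the actor has processed the push
ray.get(self.replay_buffer.push.remote(intermediate_memory))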

I have a feeling there might be something more fundamental going on for the core team to investigate, but I can't point to anything more specific, unfortunately.
