Code slows down when using prioritised replay buffer vs. vanilla replay buffer

I am using a distributed setup to train a DQN. When I use the following replay buffer, the code works perfectly fine:

@ray.remote
class ReplayBuffer:
    def __init__(self, capacity, batch_size=128):
        self.capacity = capacity
        self.buffer = []
        self.batch_size = batch_size

    def push(self, data):
        for experience in data:
            self.buffer.append(experience)

    def sample(self):
        return random.sample(self.buffer, self.batch_size)

    def __len__(self):
        return len(self.buffer)

However, when I use the following prioritised buffer, the code slows down to the point where it is unusable:

@ray.remote
class PrioritizedReplayBuffer:
    def __init__(self, capacity, alpha=0.6, beta=0.4, beta_increment_per_sampling=0.001, batch_size=128):
        self.capacity = capacity
        self.alpha = alpha
        self.beta = beta
        self.beta_increment_per_sampling = beta_increment_per_sampling
        self.buffer = []
        self.pos = 0
        self.priorities = []
        self.batch_size = batch_size

    def push(self, data):
        for experience in data:
            max_priority = max(self.priorities) if self.buffer else 1.0

            if len(self.buffer) < self.capacity:
                self.buffer.append(experience)
                self.priorities.append(max_priority)
            else:
                self.buffer[self.pos] = experience
                self.priorities[self.pos] = max_priority

            self.pos = (self.pos + 1) % self.capacity

    def sample(self):
        start = time.time()
        N = len(self.buffer)
        if N == self.capacity:
            priorities = np.array(self.priorities)
        else:
            priorities = np.array(self.priorities[:self.pos])

        self.beta = min(1.0, self.beta + self.beta_increment_per_sampling)
        sampling_probabilities = priorities ** self.alpha
        sampling_probabilities = sampling_probabilities / sampling_probabilities.sum()
        indices = random.choices(range(N), k=self.batch_size, weights=sampling_probabilities)

        experiences = [self.buffer[idx] for idx in indices]
        weights = np.array([(self.capacity * priorities[i]) ** -self.beta for i in indices])
        weights = weights / weights.max()
        end = time.time()
        print(f"sampling took {(end - start) / 60} minutes")
        return experiences, np.array(indices), weights

    def update_priorities(self, indices, priorities):
        for idx, priority in zip(indices, priorities):
            self.priorities[idx] = priority

    def __len__(self):
        return len(self.buffer)

I've tested the prioritised buffer code without Ray and don't notice any performance issues – it runs as I would expect it to – but I don't know why it would slow down in a distributed setting.

The method that runs my distributed workers looks like this:

    def run(self):
        state = self.env.reset()
        episode_reward = 0
        episode = 0
        ep_length = 0
        grad_steps = 0
        intermediate_memory = []  # this is what we will push to the buffer at once

        while grad_steps < self.num_grad_steps:
            ep_length += 1
            action = self.act(state)
            next_state, reward, done, _ = self.env.step(action)
            intermediate_memory.append((state, action, reward, next_state, done))
            if len(intermediate_memory) >= self.push_size:
                self.replay_buffer.push.remote(intermediate_memory)
                intermediate_memory = []
                self.sync_with_param_server()
                grad_steps = ray.get(self.param_server.return_grad_steps.remote())
                time.sleep(60 * 5)
            episode_reward += reward

            if done:
                # print results locally
                # print(f"Episode {episode}: {episode_reward}")
                # print_status(self.env, time_step)

                # prepare new rollout
                episode += 1
                episode_reward = 0
                ep_length = 0
                next_state = self.env.reset()

            state = next_state

In the learner, I interact with the buffer when sampling a batch and when updating priorities in the experience replay method. I noticed that if I comment out the line where the workers push to the buffer, the issue goes away (though obviously this is not a fix). I thought the problem could be multiple processes trying to interact with the buffer at once, but I tried adding some locks and that didn't help.

Any help would be greatly appreciated!



Hi @direland3, @cade,

This question is not really RLlib related.
You seem to be timing your code. The fact that sampling takes n minutes suggests that most of the time is being spent in the sample() calls, correct?
There is nothing in that function that looks like it should take that amount of time.
Maybe you should add some more timing to find out what's causing this issue.
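
For example, you could time the submission of the remote call and the ray.get() separately, something along these lines (a rough sketch, assuming the buffer handle is called self.memory in your learner):

import time
import ray

t0 = time.perf_counter()
sample_ref = self.memory.sample.remote()  # returns an ObjectRef immediately; work is only queued
t1 = time.perf_counter()
batch = ray.get(sample_ref)               # blocks until the actor has actually executed sample()
t2 = time.perf_counter()
print(f"submit: {t1 - t0:.4f}s, ray.get wait: {t2 - t1:.4f}s")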


Hi @arturn

I have profiled my code. The main difference appears to be the ray.get() calls. When testing with two grad steps, the results are the following:

  • with the vanilla replay buffer, ray.get() is called twice for a total time of 250ms (0.9% of the total time taken)
  • with the prioritised replay buffer, it is again called twice, but for a total time of 50,840ms (67.9% of the total time taken).

I would appreciate it if someone could help me identify what is causing the huge discrepancy in the time taken by ray.get() with the prioritised replay buffer. Thanks.

Hi @direland3,

Do you have an example reproduction script you can share?

It is not possible to see what the issue might be from what you have shared so far.

My best guess is that you are having a resource contention issue. Last time I checked, by default Ray will queue multiple remote calls to the same actor and ensure that only one remote call runs at a time. For example, if you have multiple workers storing experiences at the same time, the pushes will not happen in parallel but one at a time, sequentially. This is to prevent race conditions.

You can increase the number of parallel calls allowed on a remote function or actor. You will need to manage access to data structures to prevent data races.

Check out this documentation:
https://docs.ray.io/en/latest/ray-core/actors/async_api.html#threaded-actors
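
For illustration only, here is a minimal sketch of a threaded actor (my own example, not your buffer): max_concurrency allows several method calls to run at once, and you then guard the shared lists with a lock yourself:

import random
import threading
import ray

@ray.remote(max_concurrency=4)  # threaded actor: up to 4 method calls may run concurrently
class ConcurrentReplayBuffer:
    def __init__(self, capacity, batch_size=128):
        self.capacity = capacity
        self.batch_size = batch_size
        self.buffer = []
        self.lock = threading.Lock()  # protects self.buffer against data races

    def push(self, data):
        with self.lock:
            self.buffer.extend(data)
            if len(self.buffer) > self.capacity:
                # simple FIFO trim, just for the sketch
                self.buffer = self.buffer[-self.capacity:]

    def sample(self):
        with self.lock:
            return random.sample(self.buffer, self.batch_size)

    def size(self):
        with self.lock:
            return len(self.buffer)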

Hi @mannyv,

Thanks for the reply!

Here are the contents of the file with both the vanilla and prioritised replay buffers:

from collections import deque
import gym
import copy
from networks import Net
import numpy as np
import random
import ray
import torch
from torch.nn import MSELoss
import time


class DQN:
    def __init__(self, gamma, memory, batch_size, replay_start_size, exploration_min, exploration_decay, action_dim, state_dim, tau, device="cuda" if torch.cuda.is_available() else "cpu"):

        self.gamma = gamma
        self.action_dim = action_dim
        self.state_dim = state_dim
        self.memory = deque(maxlen=memory)
        self.batch_size = batch_size
        self.replay_start_size = replay_start_size
        self.exploration_min = exploration_min
        self.epsilon = 1
        self.exploration_decay = exploration_decay
        self.device = device

        self.net = Net(self.state_dim, self.action_dim).to(self.device)
        self.loss_fn = MSELoss()
        self.optimiser = torch.optim.Adam(self.net.parameters(), lr=0.0001)
        self.target_net = copy.deepcopy(self.net).to(self.device)
        self.loss = []
        self.count = 0
        self.tau = tau
        self.grad_steps = 0

    def act(self, state):
        if len(self.memory) < self.replay_start_size:
            return np.random.randint(0, self.action_dim)
        elif np.random.uniform() < self.epsilon:
            self.epsilon = max(self.epsilon * self.exploration_decay, self.exploration_min)
            return np.random.randint(0, self.action_dim)
        else:
            state = torch.FloatTensor(state).to(self.device)
            with torch.no_grad():
                values = self.net(state)
            action = torch.argmax(values)
            return int(action)

    def greedy_act(self, state):
        state = torch.FloatTensor(state).to(self.device)
        with torch.no_grad():
            values = self.net(state)
        return values.argmax().item()

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state.flatten().tolist(), action, reward, next_state.flatten().tolist(), done))

    def experience_replay(self):
        if len(self.memory) < self.replay_start_size:
            return

        states, actions, rewards, next_states, dones = self.get_batch()

        # make forward pass of the network
        q_vals = self.net.forward(states).gather(1, actions).squeeze(1)
        with torch.no_grad():
            targets = self.target_net(next_states)
            targets, _ = torch.max(targets, dim=1)
            targets = rewards + self.gamma * (1 - dones) * targets.view(self.batch_size, 1)
            targets = targets.squeeze()

        self.optimiser.zero_grad()
        loss = (q_vals - targets).pow(2).mean()
        loss.backward()
        self.optimiser.step()
        self.grad_steps += 1
        self.update_target()

    def get_batch(self):
        batch = random.sample(self.memory, self.batch_size)

        states = []
        actions = []
        rewards = []
        next_states = []
        dones = []
        for state, action, reward, next_state, done in batch:
            states.append(state)
            actions.append([action])
            rewards.append([reward])
            next_states.append(next_state)
            dones.append([done])

        return torch.FloatTensor(states).to(self.device), torch.LongTensor(actions).to(self.device), torch.FloatTensor(rewards).to(self.device), torch.FloatTensor(next_states).to(self.device), torch.FloatTensor(dones).to(self.device)

    def update_target(self):
        for real, target in zip(self.net.parameters(), self.target_net.parameters()):
            target.data.copy_(real.data * self.tau + target.data * (1 - self.tau))

    def save_model(self):
        torch.save(self.net, "dqn")


@ray.remote
class PrioritizedReplayBuffer:
    def __init__(self, capacity, alpha=0.6, beta=0.4, beta_increment_per_sampling=0.001, batch_size=128):
        self.capacity = capacity
        self.alpha = alpha
        self.beta = beta
        self.beta_increment_per_sampling = beta_increment_per_sampling
        self.buffer = []
        self.pos = 0
        self.priorities = []
        self.batch_size = batch_size

    def push(self, data):
        for experience in data:
            max_priority = max(self.priorities) if self.buffer else 1.0

            if len(self.buffer) < self.capacity:
                self.buffer.append(experience)
                self.priorities.append(max_priority)
            else:
                self.buffer[self.pos] = experience
                self.priorities[self.pos] = max_priority

            self.pos = (self.pos + 1) % self.capacity

    def sample(self):
        N = len(self.buffer)
        if N == self.capacity:
            priorities = np.array(self.priorities)
        else:
            priorities = np.array(self.priorities[:self.pos])

        self.beta = min(1.0, self.beta + self.beta_increment_per_sampling)
        sampling_probabilities = priorities ** self.alpha
        sampling_probabilities = sampling_probabilities / sampling_probabilities.sum()
        indices = random.choices(range(N), k=self.batch_size, weights=sampling_probabilities)

        experiences = [self.buffer[idx] for idx in indices]
        weights = np.array([(self.capacity * priorities[i]) ** -self.beta for i in indices])
        weights = weights / weights.max()
        return experiences, np.array(indices), weights

    def update_priorities(self, indices, priorities):
        for idx, priority in zip(indices, priorities):
            self.priorities[idx] = priority

    def __len__(self):
        return len(self.buffer)


@ray.remote
class ReplayBuffer:
    def __init__(self, capacity, batch_size=128):
        self.capacity = capacity
        self.buffer = []
        self.batch_size = batch_size

    def push(self, data):
        for experience in data:
            self.buffer.append(experience)

    def sample(self):
        return random.sample(self.buffer, self.batch_size)

    def __len__(self):
        return len(self.buffer)


class DistributedPDQN(DQN):
    def __init__(self, gamma, memory, batch_size, replay_start_size, exploration_min, exploration_decay, tau, device="cuda" if torch.cuda.is_available() else "cpu", beta=0.4, alpha=0.6, beta_increment_per_sampling=0.001,
                 num_workers=5, max_grad_steps=5e4, push_data_size=20, use_priority=True):
        self.env = gym.make('LunarLander-v2')
        self.use_priority = use_priority
        action_dim, state_dim = self.env.action_space.n, self.env.observation_space.shape[0]
        super(DistributedPDQN, self).__init__(gamma, memory, batch_size, replay_start_size, exploration_min, exploration_decay, action_dim, state_dim, tau, device)
        if self.use_priority:
            self.memory = PrioritizedReplayBuffer.remote(memory, alpha, beta, beta_increment_per_sampling, batch_size)
        else:
            self.memory = ReplayBuffer.remote(memory, batch_size)
        self.param_server = ParameterServer.remote()
        params = dict(self.net.named_parameters())
        with torch.no_grad():
            for name in params:
                params[name] = params[name].cpu().numpy()
        self.param_server.define_param_list.remote(params)
        self.actors = [Actor.remote(state_dim, action_dim, exploration_decay, exploration_min, i + 1, self.memory, self.param_server, push_data_size, max_grad_steps) for i in range(num_workers)]
        self.max_grad_steps = max_grad_steps

        self.num_actions = action_dim
        self.test_scores = []

    def learn(self):
        data = []
        burnin_ep = 0
        while len(data) < 50000:
            burnin_ep += 1
            state = self.env.reset()
            score = 0
            done = False
            while not done:
                action = np.random.randint(0, self.num_actions)
                next_state, reward, done, _ = self.env.step(action)
                data.append((state.flatten().tolist(), action, reward, next_state.flatten().tolist(), done))
                state = next_state
                score += reward
            print(f"Burn-in episode {burnin_ep + 1}")
        self.memory.push.remote(data)
        del data

        # start the actor
        [a.run.remote() for a in self.actors]

        # start learning
        while self.grad_steps < self.max_grad_steps:
            self.experience_replay()
            with torch.no_grad():
                params = dict(self.net.named_parameters())
                for name in params:
                    params[name] = params[name].cpu().numpy()
            self.param_server.update_params.remote(params, self.grad_steps)
            print(f"{self.grad_steps} grad steps taken")
            if (self.grad_steps + 1) % 500 == 0:
                self.run_test_episode()
            if (self.grad_steps + 1) % 500 == 0:
                self.save_model()

    def run_test_episode(self):
        state = self.env.reset()
        score = 0
        done = False
        while not done:
            action = self.greedy_act(state)
            next_state, reward, done, _ = self.env.step(action)
            state = next_state
            score += reward
        self.test_scores.append(score)

    def get_batch(self):
        start = time.time()
        if self.use_priority:
            batch, idx, weights = ray.get(self.memory.sample.remote())
        else:
            batch = ray.get(self.memory.sample.remote())
            idx = []
            weights = []

        states = []
        actions = []
        rewards = []
        next_states = []
        dones = []
        for state, action, reward, next_state, done in batch:
            states.append(state)
            actions.append([action])
            rewards.append([reward])
            next_states.append(next_state)
            dones.append([done])

        end = time.time()
        print(f"getting batch took {(end - start) / 60} minutes")
        return torch.FloatTensor(states).to(self.device), torch.LongTensor(actions).to(self.device), torch.FloatTensor(rewards).to(self.device), torch.FloatTensor(next_states).to(self.device), torch.FloatTensor(dones).to(self.device), idx, \
               torch.FloatTensor(weights).to(self.device)

    def experience_replay(self):
        start = time.time()
        states, actions, rewards, next_states, dones, sample_idx, weights = self.get_batch()

        # make forward pass of the network
        q_vals = self.net.forward(states).gather(1, actions).squeeze(1)
        with torch.no_grad():
            targets = self.target_net(next_states)
            targets, _ = torch.max(targets, dim=1)
            targets = rewards + self.gamma * (1 - dones) * targets.view(self.batch_size, 1)
            targets = targets.squeeze()

        td_error = (q_vals - targets).pow(2)
        if self.use_priority:
            priorities = td_error + 1e-5
            self.memory.update_priorities.remote(sample_idx, priorities.data.cpu().numpy())

        self.optimiser.zero_grad()
        if self.use_priority:
            loss = (td_error * weights).mean()
        else:
            loss = td_error.mean()
        loss.backward()
        self.optimiser.step()
        self.grad_steps += 1
        self.update_target()
        end = time.time()
        print(f"experience replay took {(end - start) / 60} minutes")


@ray.remote
class ParameterServer(object):
    def __init__(self):
        self.grad_steps = 0
        self.actor_params = None

    def define_param_list(self, actor_param_dict):
        self.actor_params = {}
        for name in actor_param_dict:
            self.actor_params[name] = actor_param_dict[name]

    def update_params(self, actor_params, grad_steps):
        for name in actor_params:
            self.actor_params[name] = actor_params[name]
        self.grad_steps = grad_steps

    def return_params(self):
        return self.actor_params

    def return_grad_steps(self):
        return self.grad_steps


@ray.remote
class Actor(object):

    def __init__(self, state_dim, action_dim, exploration_decay, exploration_min, worker_id=None, replay_buffer=None, param_server=None, push_size=20, num_grad_steps=1e6):
        self.worker_id = worker_id

        self.env = gym.make('LunarLander-v2')
        self.net = Net(state_dim, action_dim)

        # get ray_remote objects; centralized buffer and parameter server
        self.replay_buffer = replay_buffer
        self.param_server = param_server
        self.push_size = push_size  # this is how much data we need until we push to the centralised buffer
        self.num_grad_steps = num_grad_steps
        self.epsilon = 1
        self.exploration_decay = exploration_decay
        self.exploration_min = exploration_min
        self.action_dim = action_dim

    def act(self, state):
        if np.random.uniform() < self.epsilon:
            self.epsilon = max(self.epsilon * self.exploration_decay, self.exploration_min)
            return np.random.randint(0, self.action_dim)
        else:
            state = torch.FloatTensor(state)
            with torch.no_grad():
                values = self.net(state)
            action = torch.argmax(values)
            return int(action)

    def sync_with_param_server(self):
        new_actor_params = ray.get(self.param_server.return_params.remote())
        for param in new_actor_params:
            new_actor_params[param] = torch.from_numpy(new_actor_params[param]).float()

        self.net.load_state_dict(new_actor_params)

    def run(self):
        state = self.env.reset()
        episode_reward = 0
        episode = 0
        ep_length = 0
        grad_steps = 0
        intermediate_memory = []  # this is what we will push to the buffer at once

        while grad_steps < self.num_grad_steps:
            ep_length += 1
            action = self.act(state)
            next_state, reward, done, _ = self.env.step(action)
            intermediate_memory.append((state, action, reward, next_state, done))
            if len(intermediate_memory) >= self.push_size:
                self.replay_buffer.push.remote(intermediate_memory)
                intermediate_memory = []
                self.sync_with_param_server()
                grad_steps = ray.get(self.param_server.return_grad_steps.remote())
                # time.sleep(60 * 5)
            episode_reward += reward

            if done:
                # print results locally
                # print(f"Episode {episode}: {episode_reward}")
                # print_status(self.env, time_step)

                # prepare new rollout
                episode += 1
                episode_reward = 0
                ep_length = 0
                next_state = self.env.reset()

            state = next_state

You can then run an instance of this with:

if __name__ == "__main__":
    ddqn = DistributedPDQN(0.99, 50000, 256, 10000, 0.01, 0.9995, tau=0.0025, num_workers=5, push_data_size=20, use_priority=True, max_grad_steps=2)
    ddqn.learn()

Switching the use_priority flag to False will use the vanilla replay buffer instead of the prioritised buffer.

From what you have said, I'm not sure why this would cause a problem with the prioritised buffer and not the vanilla buffer. Both push experience in a similar manner, so it's not clear to me why the prioritised buffer would be causing such a big slowdown.

Hi @direland3,

You did not include your network, so I just made a two-layer dense network [64, 64].

I think you need to synchronize your remote calls with a ray.get(). Perhaps some of them do not need to be; I am not sure. That is something you can experiment with.

❯ python scripts/prioritized-replay-slowdown.py
No priority learn took 7.397060149
Priority learn took 5.913368073999999

Changes:

@@ -196,7 +196,7 @@ class DistributedPDQN(DQN):
         with torch.no_grad():
             for name in params:
                 params[name] = params[name].cpu().numpy()
-        self.param_server.define_param_list.remote(params)
+        ray.get(self.param_server.define_param_list.remote(params))
         self.actors = [Actor.remote(state_dim, action_dim, exploration_decay, exploration_min, i + 1, self.memory, self.param_server, push_data_size, max_grad_steps) for i in range(num_workers)]
         self.max_grad_steps = max_grad_steps
 
@@ -218,7 +218,7 @@ class DistributedPDQN(DQN):
                 state = next_state
                 score += reward
             print(f"Burn-in episode {burnin_ep + 1}")
-        self.memory.push.remote(data)
+        ray.get(self.memory.push.remote(data))
         del data
 
         # start the actor
@@ -231,7 +231,7 @@ class DistributedPDQN(DQN):
                 params = dict(self.net.named_parameters())
                 for name in params:
                     params[name] = params[name].cpu().numpy()
-            self.param_server.update_params.remote(params, self.grad_steps)
+            ray.get(self.param_server.update_params.remote(params, self.grad_steps))
             print(f"{self.grad_steps} grad steps taken")
             if (self.grad_steps + 1) % 500 == 0:
                 self.run_test_episode()
@@ -290,7 +290,7 @@ class DistributedPDQN(DQN):
         td_error = (q_vals - targets).pow(2)
         if self.use_priority:
             priorities = td_error + 1e-5
-            self.memory.update_priorities.remote(sample_idx, priorities.data.cpu().numpy())
+            ray.get(self.memory.update_priorities.remote(sample_idx, priorities.data.cpu().numpy()))
 
         self.optimiser.zero_grad()
         if self.use_priority:
@@ -379,7 +379,7 @@ class Actor(object):
             next_state, reward, done, _ = self.env.step(action)
             intermediate_memory.append((state, action, reward, next_state, done))
             if len(intermediate_memory) >= self.push_size:
-                self.replay_buffer.push.remote(intermediate_memory)
+                ray.get(self.replay_buffer.push.remote(intermediate_memory))
                 intermediate_memory = []
                 self.sync_with_param_server()
                 grad_steps = ray.get(self.param_server.return_grad_steps.remote())
@@ -402,11 +402,11 @@ class Actor(object):
 
 
+if __name__ == "__main__":
+    start=time.process_time()
+    ddqn = DistributedPDQN(0.99, 50000, 256, 10000, 0.01, 0.9995, tau=0.0025, num_workers=5, push_data_size=20, use_priority=False, max_grad_steps=2)
+    ddqn.learn()
+    end=time.process_time()
+    print(f"No priority learn took {end-start}")
+    start=time.process_time()
+    ddqn = DistributedPDQN(0.99, 50000, 256, 10000, 0.01, 0.9995, tau=0.0025, num_workers=5, push_data_size=20, use_priority=True, max_grad_steps=2)
+    ddqn.learn()
+    end=time.process_time()
+    print(f"Priority learn took {end-start}")

Hi @mannyv,

Thanks! I took your advice and wrapped what you suggested in a ray.get(). It didn't quite work once I used more than two grad steps, but I then wrapped every piece of code that interacts with an object in a different process in a ray.get(), and it seems to work well now!

Just for my own understanding, why would this be necessary for the prioritised buffer but not the vanilla buffer? I’m wondering if my understanding of what ray.get() does is wrong.