Memory leakage using distributed workers

I appear to have a memory leak when using distributed training for a DQN. Below is the code:

from collections import deque
import gym
import copy
import numpy as np
import random
import ray
import torch
import torch.nn as nn
from torch.nn import MSELoss


class Net(nn.Module):

    def __init__(self, state_dim, action_dim):

        super(Net, self).__init__()

        self.input_layer = nn.Linear(state_dim, 128)
        self.h1 = nn.Linear(128, 128)
        self.h2 = nn.Linear(128, 128)
        self.output_layer = nn.Linear(128, action_dim)

    def forward(self, h):
        h = torch.relu(self.input_layer(h))
        h = torch.relu(self.h1(h))
        h = torch.relu(self.h2(h))
        h = self.output_layer(h)
        return h



class DQN:
    def __init__(self, gamma, memory, batch_size, replay_start_size, exploration_min, exploration_decay, action_dim, state_dim, tau, device="cuda" if torch.cuda.is_available() else "cpu"):

        self.gamma = gamma
        self.action_dim = action_dim
        self.state_dim = state_dim
        self.memory = deque(maxlen=memory)
        self.batch_size = batch_size
        self.replay_start_size = replay_start_size
        self.exploration_min = exploration_min
        self.epsilon = 1
        self.exploration_decay = exploration_decay
        self.device = device

        self.net = Net(self.state_dim, self.action_dim).to(self.device)
        self.loss_fn = MSELoss()
        self.optimiser = torch.optim.Adam(self.net.parameters(), lr=0.0001)
        self.target_net = copy.deepcopy(self.net).to(self.device)
        self.loss = []
        self.count = 0
        self.tau = tau
        self.grad_steps = 0

    def act(self, state):
        if len(self.memory) < self.replay_start_size:
            return np.random.randint(0, self.action_dim)
        elif np.random.uniform() < self.epsilon:
            self.epsilon = max(self.epsilon * self.exploration_decay, self.exploration_min)
            return np.random.randint(0, self.action_dim)
        else:
            state = torch.FloatTensor(state).to(self.device)
            with torch.no_grad():
                values = self.net(state)
            action = torch.argmax(values)
            return int(action)

    def greedy_act(self, state):
        state = torch.FloatTensor(state).to(self.device)
        with torch.no_grad():
            values = self.net(state)
        return values.argmax().item()

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state.flatten().tolist(), action, reward, next_state.flatten().tolist(), done))

    def experience_replay(self):
        if len(self.memory) < self.replay_start_size:
            return

        states, actions, rewards, next_states, dones = self.get_batch()

        # make forward pass of the network
        q_vals = self.net(states).gather(1, actions).squeeze(1)
        with torch.no_grad():
            targets = self.target_net(next_states)
            targets, _ = torch.max(targets, dim=1)
            targets = rewards + self.gamma * (1 - dones) * targets.view(self.batch_size, 1)
            targets = targets.squeeze()

        self.optimiser.zero_grad()
        loss = (q_vals - targets).pow(2).mean()
        loss.backward()
        self.optimiser.step()
        self.grad_steps += 1
        self.update_target()

    def get_batch(self):
        batch = random.sample(self.memory, self.batch_size)

        states = []
        actions = []
        rewards = []
        next_states = []
        dones = []
        for state, action, reward, next_state, done in batch:
            states.append(state)
            actions.append([action])
            rewards.append([reward])
            next_states.append(next_state)
            dones.append([done])

        return torch.FloatTensor(states).to(self.device), torch.LongTensor(actions).to(self.device), torch.FloatTensor(rewards).to(self.device), torch.FloatTensor(next_states).to(self.device), torch.FloatTensor(dones).to(self.device)

    def update_target(self):
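        # Polyak averaging (soft update): target <- tau * online + (1 - tau) * target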
        for real, target in zip(self.net.parameters(), self.target_net.parameters()):
            target.data.copy_(real.data * self.tau + target.data * (1 - self.tau))

    def save_model(self):
        torch.save(self.net, "dqn")


@ray.remote
class PrioritizedReplayBuffer:
    def __init__(self, capacity, alpha=0.6, beta=0.4, beta_increment_per_sampling=0.001, batch_size=128):
        self.capacity = capacity
        self.alpha = alpha
        self.beta = beta
        self.beta_increment_per_sampling = beta_increment_per_sampling
        self.buffer = []
        self.pos = 0
        self.priorities = []
        self.batch_size = batch_size

    def push(self, data):
        max_priority = max(self.priorities) if self.buffer else 1.0
        for experience in data:
            if len(self.buffer) < self.capacity:
                self.buffer.append(experience)
                self.priorities.append(max_priority)
            else:
                self.buffer[self.pos] = experience
                self.priorities[self.pos] = max_priority

            self.pos = (self.pos + 1) % self.capacity

    def sample(self):
        N = len(self.buffer)
        if N == self.capacity:
            priorities = np.array(self.priorities)
        else:
            priorities = np.array(self.priorities[:self.pos])

        self.beta = min(1.0, self.beta + self.beta_increment_per_sampling)
        sampling_probabilities = priorities ** self.alpha
        sampling_probabilities = sampling_probabilities / sampling_probabilities.sum()
        indices = random.choices(range(N), k=self.batch_size, weights=sampling_probabilities)

        experiences = [self.buffer[idx] for idx in indices]
        # importance-sampling weights: w_i = (N * P(i)) ** -beta, normalised by the max below
        weights = np.array([(N * sampling_probabilities[i]) ** -self.beta for i in indices])
        weights = weights / weights.max()
        return experiences, np.array(indices), weights

    def update_priorities(self, indices, priorities):
        for idx, priority in zip(indices, priorities):
            self.priorities[idx] = priority

    def __len__(self):
        return len(self.buffer)


@ray.remote
class ReplayBuffer:
    def __init__(self, capacity, batch_size=128):
        self.capacity = capacity
        self.buffer = []
        self.batch_size = batch_size

    def push(self, data):
        for experience in data:
            self.buffer.append(experience)

    def sample(self):
        return random.sample(self.buffer, self.batch_size)

    def __len__(self):
        return len(self.buffer)


class DistributedPDQN(DQN):
    def __init__(self, gamma, memory, batch_size, replay_start_size, exploration_min, exploration_decay, tau, device="cuda" if torch.cuda.is_available() else "cpu", beta=0.4, alpha=0.6, beta_increment_per_sampling=0.001,
                 num_workers=5, max_grad_steps=5e4, push_data_size=20, use_priority=True):
        self.env = gym.make('LunarLander-v2')
        self.use_priority = use_priority
        action_dim, state_dim = self.env.action_space.n, self.env.observation_space.shape[0]
        super(DistributedPDQN, self).__init__(gamma, memory, batch_size, replay_start_size, exploration_min, exploration_decay, action_dim, state_dim, tau, device)
        if self.use_priority:
            self.memory = PrioritizedReplayBuffer.remote(memory, alpha, beta, beta_increment_per_sampling, batch_size)
        else:
            self.memory = ReplayBuffer.remote(memory, batch_size)
        self.param_server = ParameterServer.remote()
        params = dict(self.net.named_parameters())
        with torch.no_grad():
            for name in params:
                params[name] = params[name].cpu().numpy()
        ray.get(self.param_server.define_param_list.remote(params))
        self.actors = [Actor.remote(state_dim, action_dim, exploration_decay, exploration_min, i + 1, self.memory, self.param_server, push_data_size, max_grad_steps) for i in range(num_workers)]
        self.max_grad_steps = max_grad_steps

        self.num_actions = action_dim
        self.test_scores = []

    def learn(self):
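        # burn-in: fill the central replay buffer with random-policy transitions before any gradient updates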
        data = []
        burnin_ep = 0
        while len(data) < 50000:
            burnin_ep += 1
            state = self.env.reset()
            score = 0
            done = False
            while not done:
                action = np.random.randint(0, self.num_actions)
                next_state, reward, done, _ = self.env.step(action)
                data.append((state.flatten().tolist(), action, reward, next_state.flatten().tolist(), done))
                state = next_state
                score += reward
            print(f"Burn-in episode {burnin_ep}")
        ray.get(self.memory.push.remote(data))
        del data

        # start the actors; they run in the background until max_grad_steps is reached
        [a.run.remote() for a in self.actors]

        # start learning
        while self.grad_steps < self.max_grad_steps:
            self.experience_replay()
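            # push the latest online-network weights (as numpy arrays) to the parameter server so the actors can sync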
            with torch.no_grad():
                params = dict(self.net.named_parameters())
                for name in params:
                    params[name] = params[name].cpu().numpy()
            ray.get(self.param_server.update_params.remote(params, self.grad_steps))
            print(f"{self.grad_steps} grad steps taken")
            if (self.grad_steps + 1) % 500 == 0:
                self.run_test_episode()
                self.save_model()

    def run_test_episode(self):
        state = self.env.reset()
        score = 0
        done = False
        while not done:
            action = self.greedy_act(state)
            next_state, reward, done, _ = self.env.step(action)
            state = next_state
            score += reward
        self.test_scores.append(score)

    def get_batch(self):
        if self.use_priority:
            batch, idx, weights = ray.get(self.memory.sample.remote())
        else:
            batch = ray.get(self.memory.sample.remote())
            idx = []
            weights = []

        states = []
        actions = []
        rewards = []
        next_states = []
        dones = []
        for state, action, reward, next_state, done in batch:
            states.append(state)
            actions.append([action])
            rewards.append([reward])
            next_states.append(next_state)
            dones.append([done])

        return torch.FloatTensor(states).to(self.device), torch.LongTensor(actions).to(self.device), torch.FloatTensor(rewards).to(self.device), torch.FloatTensor(next_states).to(self.device), torch.FloatTensor(dones).to(self.device), idx, \
               torch.FloatTensor(weights).to(self.device)

    def experience_replay(self):
        states, actions, rewards, next_states, dones, sample_idx, weights = self.get_batch()

        # make forward pass of the network
        q_vals = self.net(states).gather(1, actions).squeeze(1)
        with torch.no_grad():
            targets = self.target_net(next_states)
            targets, _ = torch.max(targets, dim=1)
            targets = rewards + self.gamma * (1 - dones) * targets.view(self.batch_size, 1)
            targets = targets.squeeze()

        td_error = q_vals - targets
        if self.use_priority:
            # proportional prioritisation: use |TD error| (plus a small constant) as the new priority
            priorities = td_error.abs() + 1e-5
            ray.get(self.memory.update_priorities.remote(sample_idx, priorities.data.cpu().numpy()))

        self.optimiser.zero_grad()
        if self.use_priority:
            # importance-sampling weights correct the bias introduced by prioritised sampling
            loss = (td_error.pow(2) * weights).mean()
        else:
            loss = td_error.pow(2).mean()
        loss.backward()
        self.optimiser.step()
        self.grad_steps += 1
        self.update_target()


@ray.remote
class ParameterServer(object):
    def __init__(self):
        self.grad_steps = 0
        self.actor_params = None

    def define_param_list(self, actor_param_dict):
        self.actor_params = {}
        for name in actor_param_dict:
            self.actor_params[name] = actor_param_dict[name]

    def update_params(self, actor_params, grad_steps):
        for name in actor_params:
            self.actor_params[name] = actor_params[name]
        self.grad_steps = grad_steps

    def return_params(self):
        return self.actor_params

    def return_grad_steps(self):
        return self.grad_steps


@ray.remote
class Actor(object):

    def __init__(self, state_dim, action_dim, exploration_decay, exploration_min, worker_id=None, replay_buffer=None, param_server=None, push_size=20, num_grad_steps=1e6):
        self.worker_id = worker_id

        self.env = gym.make('LunarLander-v2')
        self.net = Net(state_dim, action_dim)

        # get ray_remote objects; centralized buffer and parameter server
        self.replay_buffer = replay_buffer
        self.param_server = param_server
        self.push_size = push_size  # this is how much data we need until we push to the centralised buffer
        self.num_grad_steps = num_grad_steps
        self.epsilon = 1
        self.exploration_decay = exploration_decay
        self.exploration_min = exploration_min
        self.action_dim = action_dim

    def act(self, state):
        if np.random.uniform() < self.epsilon:
            self.epsilon = max(self.epsilon * self.exploration_decay, self.exploration_min)
            return np.random.randint(0, self.action_dim)
        else:
            state = torch.FloatTensor(state)
            with torch.no_grad():
                values = self.net(state)
            action = torch.argmax(values)
            return int(action)

    def sync_with_param_server(self):
        new_actor_params = ray.get(self.param_server.return_params.remote())
        for param in new_actor_params:
            new_actor_params[param] = torch.from_numpy(new_actor_params[param]).float()

        self.net.load_state_dict(new_actor_params)

    def run(self):
        state = self.env.reset()
        episode_reward = 0
        episode = 0
        ep_length = 0
        grad_steps = 0
        intermediate_memory = []  # this is what we will push to the buffer at once

        while grad_steps < self.num_grad_steps:
            ep_length += 1
            action = self.act(state)
            next_state, reward, done, _ = self.env.step(action)
            intermediate_memory.append((state, action, reward, next_state, done))
            if len(intermediate_memory) >= self.push_size:
                ray.get(self.replay_buffer.push.remote(intermediate_memory))
                intermediate_memory = []
                self.sync_with_param_server()
                grad_steps = ray.get(self.param_server.return_grad_steps.remote())
            episode_reward += reward

            if done:
                # print results locally
                # print(f"Episode {episode}: {episode_reward}")
                # print_status(self.env, time_step)

                # prepare new rollout
                episode += 1
                episode_reward = 0
                ep_length = 0
                next_state = self.env.reset()

            state = next_state

if __name__ == "__main__":
    ddqn = DistributedPDQN(0.99, 50000, 256, 10000, 0.01, 0.9995, tau=0.0025, num_workers=4, push_data_size=20, use_priority=False, max_grad_steps=1e6)
    ddqn.learn()

I’ve attached a screenshot of the memory usage for this process. Note that the issue also happens on Linux. It looks like the main process is the one consuming the most memory.
I’m very confused about what could be causing this, so any help would be greatly appreciated! Thanks.
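
A minimal way to confirm which process is growing is to log the driver’s resident memory from inside the learning loop, e.g. with psutil (psutil is an extra dependency, not used in the script above; this snippet is just a sketch):

import os
import psutil

_driver = psutil.Process(os.getpid())

def log_driver_rss(step):
    # resident set size (RSS) of the driver process, in MiB
    rss_mb = _driver.memory_info().rss / 1024 ** 2
    print(f"step {step}: driver RSS = {rss_mb:.1f} MiB")

Calling log_driver_rss(self.grad_steps) every few hundred iterations of the while loop in learn() shows whether the growth tracks the gradient steps.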

cc @kourosh, what’s the best way to debug an OOM from RLlib?


Hi @direland3, I would recommend taking a look at test_memory_leaks.py to see if you can use that code to pin down where the issue is coming from.
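
If that doesn’t narrow it down, a plain tracemalloc comparison (standard library, independent of Ray’s utilities) can also point at the code locations whose allocations keep growing, roughly like this:

import tracemalloc

tracemalloc.start()
before = tracemalloc.take_snapshot()

# ... run a fixed number of training iterations here, e.g. 1000 calls to experience_replay() ...

after = tracemalloc.take_snapshot()
# print the ten source lines whose allocated memory grew the most
for stat in after.compare_to(before, "lineno")[:10]:
    print(stat)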


I think I have found the error. Stupidly, when I made the vanilla replay buffer I copied the prioritised buffer and made the adjustments there. However, the buffer was left as a plain list when it should have been a deque with a maximum length. This means the list just keeps growing rather than staying at a fixed size, and that is what is causing the memory leak; it has nothing to do with Ray.
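
For anyone who runs into the same thing, the fix is just to make the buffer bounded, along the lines of:

from collections import deque
import random
import ray


@ray.remote
class ReplayBuffer:
    def __init__(self, capacity, batch_size=128):
        # bounded deque: once capacity is reached, the oldest experience is dropped automatically
        self.buffer = deque(maxlen=capacity)
        self.batch_size = batch_size

    def push(self, data):
        self.buffer.extend(data)

    def sample(self):
        return random.sample(self.buffer, self.batch_size)

    def __len__(self):
        return len(self.buffer)

(Equivalently, the list could be kept and capped with the same self.pos wrap-around logic used in the prioritised buffer.)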

I don’t know how to delete the post, otherwise I would.

Thank you for taking a look anyway though!
