I appear to have a memory leak when using distributed training for a DQN. Below is the code:
from collections import deque
import gym
import copy
import numpy as np
import random
import ray
import torch
import torch.nn as nn
from torch.nn import MSELoss

class Net(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Net, self).__init__()
        self.input_layer = nn.Linear(state_dim, 128)
        self.h1 = nn.Linear(128, 128)
        self.h2 = nn.Linear(128, 128)
        self.output_layer = nn.Linear(128, action_dim)

    def forward(self, h):
        h = torch.relu(self.input_layer(h))
        h = torch.relu(self.h1(h))
        h = torch.relu(self.h2(h))
        h = self.output_layer(h)
        return h

class DQN:
    def __init__(self, gamma, memory, batch_size, replay_start_size, exploration_min, exploration_decay,
                 action_dim, state_dim, tau, device="cuda" if torch.cuda.is_available() else "cpu"):
        self.gamma = gamma
        self.action_dim = action_dim
        self.state_dim = state_dim
        self.memory = deque(maxlen=memory)
        self.batch_size = batch_size
        self.replay_start_size = replay_start_size
        self.exploration_min = exploration_min
        self.epsilon = 1
        self.exploration_decay = exploration_decay
        self.device = device
        self.net = Net(self.state_dim, self.action_dim).to(self.device)
        self.loss_fn = MSELoss()
        self.optimiser = torch.optim.Adam(self.net.parameters(), lr=0.0001)
        self.target_net = copy.deepcopy(self.net).to(self.device)
        self.loss = []
        self.count = 0
        self.tau = tau
        self.grad_steps = 0

    def act(self, state):
        if len(self.memory) < self.replay_start_size:
            return np.random.randint(0, self.action_dim)
        elif np.random.uniform() < self.epsilon:
            self.epsilon = max(self.epsilon * self.exploration_decay, self.exploration_min)
            return np.random.randint(0, self.action_dim)
        else:
            state = torch.FloatTensor(state).to(self.device)
            with torch.no_grad():
                values = self.net(state)
            action = torch.argmax(values)
            return int(action)

    def greedy_act(self, state):
        state = torch.FloatTensor(state).to(self.device)
        with torch.no_grad():
            values = self.net(state)
        return values.argmax().item()

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state.flatten().tolist(), action, reward, next_state.flatten().tolist(), done))

    def experience_replay(self):
        if len(self.memory) < self.replay_start_size:
            return
        states, actions, rewards, next_states, dones = self.get_batch()
        # make forward pass of the network
        q_vals = self.net.forward(states).gather(1, actions).squeeze(1)
        with torch.no_grad():
            targets = self.target_net(next_states)
            targets, _ = torch.max(targets, dim=1)
            targets = rewards + self.gamma * (1 - dones) * targets.view(self.batch_size, 1)
            targets = targets.squeeze()
        self.optimiser.zero_grad()
        loss = (q_vals - targets).pow(2).mean()
        loss.backward()
        self.optimiser.step()
        self.grad_steps += 1
        self.update_target()

    def get_batch(self):
        batch = random.sample(self.memory, self.batch_size)
        states = []
        actions = []
        rewards = []
        next_states = []
        dones = []
        for state, action, reward, next_state, done in batch:
            states.append(state)
            actions.append([action])
            rewards.append([reward])
            next_states.append(next_state)
            dones.append([done])
        return torch.FloatTensor(states).to(self.device), torch.LongTensor(actions).to(self.device), \
            torch.FloatTensor(rewards).to(self.device), torch.FloatTensor(next_states).to(self.device), \
            torch.FloatTensor(dones).to(self.device)

    def update_target(self):
        for real, target in zip(self.net.parameters(), self.target_net.parameters()):
            target.data.copy_(real.data * self.tau + target.data * (1 - self.tau))

    def save_model(self):
        torch.save(self.net, "dqn")

@ray.remote
class PrioritizedReplayBuffer:
    def __init__(self, capacity, alpha=0.6, beta=0.4, beta_increment_per_sampling=0.001, batch_size=128):
        self.capacity = capacity
        self.alpha = alpha
        self.beta = beta
        self.beta_increment_per_sampling = beta_increment_per_sampling
        self.buffer = []
        self.pos = 0
        self.priorities = []
        self.batch_size = batch_size

    def push(self, data):
        max_priority = max(self.priorities) if self.buffer else 1.0
        for experience in data:
            if len(self.buffer) < self.capacity:
                self.buffer.append(experience)
                self.priorities.append(max_priority)
            else:
                self.buffer[self.pos] = experience
                self.priorities[self.pos] = max_priority
            self.pos = (self.pos + 1) % self.capacity

    def sample(self):
        N = len(self.buffer)
        if N == self.capacity:
            priorities = np.array(self.priorities)
        else:
            priorities = np.array(self.priorities[:self.pos])
        self.beta = min(1.0, self.beta + self.beta_increment_per_sampling)
        sampling_probabilities = priorities ** self.alpha
        sampling_probabilities = sampling_probabilities / sampling_probabilities.sum()
        indices = random.choices(range(N), k=self.batch_size, weights=sampling_probabilities)
        experiences = [self.buffer[idx] for idx in indices]
        weights = np.array([(self.capacity * priorities[i]) ** -self.beta for i in indices])
        weights = weights / weights.max()
        return experiences, np.array(indices), weights

    def update_priorities(self, indices, priorities):
        for idx, priority in zip(indices, priorities):
            self.priorities[idx] = priority

    def __len__(self):
        return len(self.buffer)

@ray.remote
class ReplayBuffer:
    def __init__(self, capacity, batch_size=128):
        self.capacity = capacity
        self.buffer = []
        self.batch_size = batch_size

    def push(self, data):
        for experience in data:
            self.buffer.append(experience)

    def sample(self):
        return random.sample(self.buffer, self.batch_size)

    def __len__(self):
        return len(self.buffer)

class DistributedPDQN(DQN):
    def __init__(self, gamma, memory, batch_size, replay_start_size, exploration_min, exploration_decay, tau,
                 device="cuda" if torch.cuda.is_available() else "cpu", beta=0.4, alpha=0.6,
                 beta_increment_per_sampling=0.001, num_workers=5, max_grad_steps=5e4, push_data_size=20,
                 use_priority=True):
        self.env = gym.make('LunarLander-v2')
        self.use_priority = use_priority
        action_dim, state_dim = self.env.action_space.n, self.env.observation_space.shape[0]
        super(DistributedPDQN, self).__init__(gamma, memory, batch_size, replay_start_size, exploration_min,
                                              exploration_decay, action_dim, state_dim, tau, device)
        if self.use_priority:
            self.memory = PrioritizedReplayBuffer.remote(memory, alpha, beta, beta_increment_per_sampling, batch_size)
        else:
            self.memory = ReplayBuffer.remote(memory, batch_size)
        self.param_server = ParameterServer.remote()
        params = dict(self.net.named_parameters())
        with torch.no_grad():
            for name in params:
                params[name] = params[name].cpu().numpy()
        ray.get(self.param_server.define_param_list.remote(params))
        self.actors = [Actor.remote(state_dim, action_dim, exploration_decay, exploration_min, i + 1, self.memory,
                                    self.param_server, push_data_size, max_grad_steps)
                       for i in range(num_workers)]
        self.max_grad_steps = max_grad_steps
        self.num_actions = action_dim
        self.test_scores = []

    def learn(self):
        data = []
        burnin_ep = 0
        while len(data) < 50000:
            burnin_ep += 1
            state = self.env.reset()
            score = 0
            done = False
            while not done:
                action = np.random.randint(0, self.num_actions)
                next_state, reward, done, _ = self.env.step(action)
                data.append((state.flatten().tolist(), action, reward, next_state.flatten().tolist(), done))
                state = next_state
                score += reward
            print(f"Burn-in episode {burnin_ep}")
        ray.get(self.memory.push.remote(data))
        del data
        # start the actors
        [a.run.remote() for a in self.actors]
        # start learning
        while self.grad_steps < self.max_grad_steps:
            self.experience_replay()
            with torch.no_grad():
                params = dict(self.net.named_parameters())
                for name in params:
                    params[name] = params[name].cpu().numpy()
            ray.get(self.param_server.update_params.remote(params, self.grad_steps))
            print(f"{self.grad_steps} grad steps taken")
            if (self.grad_steps + 1) % 500 == 0:
                self.run_test_episode()
                self.save_model()

    def run_test_episode(self):
        state = self.env.reset()
        score = 0
        done = False
        while not done:
            action = self.greedy_act(state)
            next_state, reward, done, _ = self.env.step(action)
            state = next_state
            score += reward
        self.test_scores.append(score)

    def get_batch(self):
        if self.use_priority:
            batch, idx, weights = ray.get(self.memory.sample.remote())
        else:
            batch = ray.get(self.memory.sample.remote())
            idx = []
            weights = []
        states = []
        actions = []
        rewards = []
        next_states = []
        dones = []
        for state, action, reward, next_state, done in batch:
            states.append(state)
            actions.append([action])
            rewards.append([reward])
            next_states.append(next_state)
            dones.append([done])
        return torch.FloatTensor(states).to(self.device), torch.LongTensor(actions).to(self.device), \
            torch.FloatTensor(rewards).to(self.device), torch.FloatTensor(next_states).to(self.device), \
            torch.FloatTensor(dones).to(self.device), idx, torch.FloatTensor(weights).to(self.device)

    def experience_replay(self):
        states, actions, rewards, next_states, dones, sample_idx, weights = self.get_batch()
        # make forward pass of the network
        q_vals = self.net.forward(states).gather(1, actions).squeeze(1)
        with torch.no_grad():
            targets = self.target_net(next_states)
            targets, _ = torch.max(targets, dim=1)
            targets = rewards + self.gamma * (1 - dones) * targets.view(self.batch_size, 1)
            targets = targets.squeeze()
        td_error = (q_vals - targets).pow(2)
        if self.use_priority:
            priorities = td_error + 1e-5
            ray.get(self.memory.update_priorities.remote(sample_idx, priorities.data.cpu().numpy()))
        self.optimiser.zero_grad()
        if self.use_priority:
            loss = (td_error * weights).mean()
        else:
            loss = td_error.mean()
        loss.backward()
        self.optimiser.step()
        self.grad_steps += 1
        self.update_target()

@ray.remote
class ParameterServer(object):
    def __init__(self):
        self.grad_steps = 0
        self.actor_params = None

    def define_param_list(self, actor_param_dict):
        self.actor_params = {}
        for name in actor_param_dict:
            self.actor_params[name] = actor_param_dict[name]

    def update_params(self, actor_params, grad_steps):
        for name in actor_params:
            self.actor_params[name] = actor_params[name]
        self.grad_steps = grad_steps

    def return_params(self):
        return self.actor_params

    def return_grad_steps(self):
        return self.grad_steps

@ray.remote
class Actor(object):
    def __init__(self, state_dim, action_dim, exploration_decay, exploration_min, worker_id=None,
                 replay_buffer=None, param_server=None, push_size=20, num_grad_steps=1e6):
        self.worker_id = worker_id
        self.env = gym.make('LunarLander-v2')
        self.net = Net(state_dim, action_dim)
        # get ray remote objects; centralised buffer and parameter server
        self.replay_buffer = replay_buffer
        self.param_server = param_server
        self.push_size = push_size  # this is how much data we need until we push to the centralised buffer
        self.num_grad_steps = num_grad_steps
        self.epsilon = 1
        self.exploration_decay = exploration_decay
        self.exploration_min = exploration_min
        self.action_dim = action_dim

    def act(self, state):
        if np.random.uniform() < self.epsilon:
            self.epsilon = max(self.epsilon * self.exploration_decay, self.exploration_min)
            return np.random.randint(0, self.action_dim)
        else:
            state = torch.FloatTensor(state)
            with torch.no_grad():
                values = self.net(state)
            action = torch.argmax(values)
            return int(action)

    def sync_with_param_server(self):
        new_actor_params = ray.get(self.param_server.return_params.remote())
        for param in new_actor_params:
            new_actor_params[param] = torch.from_numpy(new_actor_params[param]).float()
        self.net.load_state_dict(new_actor_params)

    def run(self):
        state = self.env.reset()
        episode_reward = 0
        episode = 0
        ep_length = 0
        grad_steps = 0
        intermediate_memory = []  # this is what we will push to the buffer at once
        while grad_steps < self.num_grad_steps:
            ep_length += 1
            action = self.act(state)
            next_state, reward, done, _ = self.env.step(action)
            intermediate_memory.append((state, action, reward, next_state, done))
            if len(intermediate_memory) >= self.push_size:
                ray.get(self.replay_buffer.push.remote(intermediate_memory))
                intermediate_memory = []
                self.sync_with_param_server()
                grad_steps = ray.get(self.param_server.return_grad_steps.remote())
            episode_reward += reward
            if done:
                # print results locally
                # print(f"Episode {episode}: {episode_reward}")
                # print_status(self.env, time_step)
                # prepare new rollout
                episode += 1
                episode_reward = 0
                ep_length = 0
                next_state = self.env.reset()
            state = next_state

if __name__ == "__main__":
    ddqn = DistributedPDQN(0.99, 50000, 256, 10000, 0.01, 0.9995, tau=0.0025, num_workers=4, push_data_size=20,
                           use_priority=False, max_grad_steps=1e6)
    ddqn.learn()
I’ve attached a screenshot of the memory usage for this process. The issue also occurs when running on a Linux system, and it looks like it is the main process that is consuming the most memory.
I’m very confused about what could be causing this, so any help would be greatly appreciated. Thanks!
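
In case it is useful for reproducing the observation, below is roughly how the per-process memory can be tracked from a separate script. It is only a minimal sketch, not part of the training code above: it assumes psutil is installed, and that the Ray worker processes show up as children of the driver process whose PID is passed on the command line (log_memory is just a hypothetical helper for this post).

import sys
import time

import psutil  # assumption: psutil is installed separately; the training code above does not use it


def log_memory(pid, interval_s=30):
    # Print the resident memory of the given process and of its child processes
    # (e.g. the Ray workers) every interval_s seconds.
    proc = psutil.Process(pid)
    while True:
        main_mb = proc.memory_info().rss / 1e6
        child_mb = sum(c.memory_info().rss for c in proc.children(recursive=True)) / 1e6
        print(f"main process: {main_mb:.1f} MB, children: {child_mb:.1f} MB")
        time.sleep(interval_s)


if __name__ == "__main__":
    log_memory(int(sys.argv[1]))

Running this against the PID of the training process while ddqn.learn() is executing is how I'm distinguishing growth in the driver from growth in the workers.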