Hi @mannyv,
Thanks for the reply!
Here are the contents of the file with both the vanilla and prioritised replay buffers:
```python
from collections import deque
import gym
import copy
from networks import Net
import numpy as np
import random
import ray
import torch
from torch.nn import MSELoss
import time

class DQN:
    def __init__(self, gamma, memory, batch_size, replay_start_size, exploration_min, exploration_decay,
                 action_dim, state_dim, tau, device="cuda" if torch.cuda.is_available() else "cpu"):
        self.gamma = gamma
        self.action_dim = action_dim
        self.state_dim = state_dim
        self.memory = deque(maxlen=memory)
        self.batch_size = batch_size
        self.replay_start_size = replay_start_size
        self.exploration_min = exploration_min
        self.epsilon = 1
        self.exploration_decay = exploration_decay
        self.device = device
        self.net = Net(self.state_dim, self.action_dim).to(self.device)
        self.loss_fn = MSELoss()
        self.optimiser = torch.optim.Adam(self.net.parameters(), lr=0.0001)
        self.target_net = copy.deepcopy(self.net).to(self.device)
        self.loss = []
        self.count = 0
        self.tau = tau
        self.grad_steps = 0
    def act(self, state):
        if len(self.memory) < self.replay_start_size:
            return np.random.randint(0, self.action_dim)
        elif np.random.uniform() < self.epsilon:
            self.epsilon = max(self.epsilon * self.exploration_decay, self.exploration_min)
            return np.random.randint(0, self.action_dim)
        else:
            state = torch.FloatTensor(state).to(self.device)
            with torch.no_grad():
                values = self.net(state)
            action = torch.argmax(values)
            return int(action)

    def greedy_act(self, state):
        state = torch.FloatTensor(state).to(self.device)
        with torch.no_grad():
            values = self.net(state)
        return values.argmax().item()

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state.flatten().tolist(), action, reward, next_state.flatten().tolist(), done))
    def experience_replay(self):
        if len(self.memory) < self.replay_start_size:
            return
        states, actions, rewards, next_states, dones = self.get_batch()
        # make forward pass of the network
        q_vals = self.net.forward(states).gather(1, actions).squeeze(1)
        with torch.no_grad():
            targets = self.target_net(next_states)
            targets, _ = torch.max(targets, dim=1)
            targets = rewards + self.gamma * (1 - dones) * targets.view(self.batch_size, 1)
            targets = targets.squeeze()
        self.optimiser.zero_grad()
        loss = (q_vals - targets).pow(2).mean()
        loss.backward()
        self.optimiser.step()
        self.grad_steps += 1
        self.update_target()
    def get_batch(self):
        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = [], [], [], [], []
        for state, action, reward, next_state, done in batch:
            states.append(state)
            actions.append([action])
            rewards.append([reward])
            next_states.append(next_state)
            dones.append([done])
        return (torch.FloatTensor(states).to(self.device),
                torch.LongTensor(actions).to(self.device),
                torch.FloatTensor(rewards).to(self.device),
                torch.FloatTensor(next_states).to(self.device),
                torch.FloatTensor(dones).to(self.device))

    def update_target(self):
        for real, target in zip(self.net.parameters(), self.target_net.parameters()):
            target.data.copy_(real.data * self.tau + target.data * (1 - self.tau))

    def save_model(self):
        torch.save(self.net, "dqn")

@ray.remote
class PrioritizedReplayBuffer:
    def __init__(self, capacity, alpha=0.6, beta=0.4, beta_increment_per_sampling=0.001, batch_size=128):
        self.capacity = capacity
        self.alpha = alpha
        self.beta = beta
        self.beta_increment_per_sampling = beta_increment_per_sampling
        self.buffer = []
        self.pos = 0
        self.priorities = []
        self.batch_size = batch_size

    def push(self, data):
        for experience in data:
            max_priority = max(self.priorities) if self.buffer else 1.0
            if len(self.buffer) < self.capacity:
                self.buffer.append(experience)
                self.priorities.append(max_priority)
            else:
                self.buffer[self.pos] = experience
                self.priorities[self.pos] = max_priority
                self.pos = (self.pos + 1) % self.capacity

    def sample(self):
        N = len(self.buffer)
        # self.priorities always holds exactly one entry per stored experience,
        # so it can be used directly whether or not the buffer is full
        priorities = np.array(self.priorities)
        self.beta = min(1.0, self.beta + self.beta_increment_per_sampling)
        sampling_probabilities = priorities ** self.alpha
        sampling_probabilities = sampling_probabilities / sampling_probabilities.sum()
        indices = random.choices(range(N), k=self.batch_size, weights=sampling_probabilities)
        experiences = [self.buffer[idx] for idx in indices]
        # importance-sampling weights: (N * P(i)) ** -beta, normalised by the maximum weight
        weights = np.array([(N * sampling_probabilities[i]) ** -self.beta for i in indices])
        weights = weights / weights.max()
        return experiences, np.array(indices), weights

    def update_priorities(self, indices, priorities):
        for idx, priority in zip(indices, priorities):
            self.priorities[idx] = priority

    def __len__(self):
        return len(self.buffer)

@ray.remote
class ReplayBuffer:
    def __init__(self, capacity, batch_size=128):
        self.capacity = capacity
        # deque with maxlen so the vanilla buffer actually respects its capacity
        self.buffer = deque(maxlen=capacity)
        self.batch_size = batch_size

    def push(self, data):
        for experience in data:
            self.buffer.append(experience)

    def sample(self):
        return random.sample(self.buffer, self.batch_size)

    def __len__(self):
        return len(self.buffer)

class DistributedPDQN(DQN):
    def __init__(self, gamma, memory, batch_size, replay_start_size, exploration_min, exploration_decay, tau,
                 device="cuda" if torch.cuda.is_available() else "cpu", beta=0.4, alpha=0.6,
                 beta_increment_per_sampling=0.001, num_workers=5, max_grad_steps=5e4, push_data_size=20,
                 use_priority=True):
        self.env = gym.make('LunarLander-v2')
        self.use_priority = use_priority
        action_dim, state_dim = self.env.action_space.n, self.env.observation_space.shape[0]
        super(DistributedPDQN, self).__init__(gamma, memory, batch_size, replay_start_size, exploration_min,
                                              exploration_decay, action_dim, state_dim, tau, device)
        if self.use_priority:
            self.memory = PrioritizedReplayBuffer.remote(memory, alpha, beta, beta_increment_per_sampling, batch_size)
        else:
            self.memory = ReplayBuffer.remote(memory, batch_size)
        self.param_server = ParameterServer.remote()
        params = dict(self.net.named_parameters())
        with torch.no_grad():
            for name in params:
                params[name] = params[name].cpu().numpy()
        self.param_server.define_param_list.remote(params)
        self.actors = [Actor.remote(state_dim, action_dim, exploration_decay, exploration_min, i + 1, self.memory,
                                    self.param_server, push_data_size, max_grad_steps) for i in range(num_workers)]
        self.max_grad_steps = max_grad_steps
        self.num_actions = action_dim
        self.test_scores = []
    def learn(self):
        data = []
        burnin_ep = 0
        while len(data) < 50000:
            burnin_ep += 1
            state = self.env.reset()
            score = 0
            done = False
            while not done:
                action = np.random.randint(0, self.num_actions)
                next_state, reward, done, _ = self.env.step(action)
                data.append((state.flatten().tolist(), action, reward, next_state.flatten().tolist(), done))
                state = next_state
                score += reward
            print(f"Burn-in episode {burnin_ep}")
        self.memory.push.remote(data)
        del data
        # start the actors
        [a.run.remote() for a in self.actors]
        # start learning
        while self.grad_steps < self.max_grad_steps:
            self.experience_replay()
            with torch.no_grad():
                params = dict(self.net.named_parameters())
                for name in params:
                    params[name] = params[name].cpu().numpy()
            self.param_server.update_params.remote(params, self.grad_steps)
            print(f"{self.grad_steps} grad steps taken")
            if (self.grad_steps + 1) % 500 == 0:
                self.run_test_episode()
                self.save_model()
    def run_test_episode(self):
        state = self.env.reset()
        score = 0
        done = False
        while not done:
            action = self.greedy_act(state)
            next_state, reward, done, _ = self.env.step(action)
            state = next_state
            score += reward
        self.test_scores.append(score)

    def get_batch(self):
        start = time.time()
        if self.use_priority:
            batch, idx, weights = ray.get(self.memory.sample.remote())
        else:
            batch = ray.get(self.memory.sample.remote())
            idx = []
            weights = []
        states, actions, rewards, next_states, dones = [], [], [], [], []
        for state, action, reward, next_state, done in batch:
            states.append(state)
            actions.append([action])
            rewards.append([reward])
            next_states.append(next_state)
            dones.append([done])
        end = time.time()
        print(f"getting batch took {(end - start) / 60} minutes")
        return (torch.FloatTensor(states).to(self.device),
                torch.LongTensor(actions).to(self.device),
                torch.FloatTensor(rewards).to(self.device),
                torch.FloatTensor(next_states).to(self.device),
                torch.FloatTensor(dones).to(self.device),
                idx,
                torch.FloatTensor(weights).to(self.device))
    def experience_replay(self):
        start = time.time()
        states, actions, rewards, next_states, dones, sample_idx, weights = self.get_batch()
        # make forward pass of the network
        q_vals = self.net.forward(states).gather(1, actions).squeeze(1)
        with torch.no_grad():
            targets = self.target_net(next_states)
            targets, _ = torch.max(targets, dim=1)
            targets = rewards + self.gamma * (1 - dones) * targets.view(self.batch_size, 1)
            targets = targets.squeeze()
        td_error = (q_vals - targets).pow(2)
        if self.use_priority:
            priorities = td_error + 1e-5
            self.memory.update_priorities.remote(sample_idx, priorities.data.cpu().numpy())
        self.optimiser.zero_grad()
        if self.use_priority:
            loss = (td_error * weights).mean()
        else:
            loss = td_error.mean()
        loss.backward()
        self.optimiser.step()
        self.grad_steps += 1
        self.update_target()
        end = time.time()
        print(f"experience replay took {(end - start) / 60} minutes")

@ray.remote
class ParameterServer(object):
    def __init__(self):
        self.grad_steps = 0
        self.actor_params = None

    def define_param_list(self, actor_param_dict):
        self.actor_params = {}
        for name in actor_param_dict:
            self.actor_params[name] = actor_param_dict[name]

    def update_params(self, actor_params, grad_steps):
        for name in actor_params:
            self.actor_params[name] = actor_params[name]
        self.grad_steps = grad_steps

    def return_params(self):
        return self.actor_params

    def return_grad_steps(self):
        return self.grad_steps

@ray.remote
class Actor(object):
    def __init__(self, state_dim, action_dim, exploration_decay, exploration_min, worker_id=None,
                 replay_buffer=None, param_server=None, push_size=20, num_grad_steps=1e6):
        self.worker_id = worker_id
        self.env = gym.make('LunarLander-v2')
        self.net = Net(state_dim, action_dim)
        # get ray_remote objects; centralized buffer and parameter server
        self.replay_buffer = replay_buffer
        self.param_server = param_server
        self.push_size = push_size  # this is how much data we need until we push to the centralised buffer
        self.num_grad_steps = num_grad_steps
        self.epsilon = 1
        self.exploration_decay = exploration_decay
        self.exploration_min = exploration_min
        self.action_dim = action_dim

    def act(self, state):
        if np.random.uniform() < self.epsilon:
            self.epsilon = max(self.epsilon * self.exploration_decay, self.exploration_min)
            return np.random.randint(0, self.action_dim)
        else:
            state = torch.FloatTensor(state)
            with torch.no_grad():
                values = self.net(state)
            action = torch.argmax(values)
            return int(action)

    def sync_with_param_server(self):
        new_actor_params = ray.get(self.param_server.return_params.remote())
        for param in new_actor_params:
            new_actor_params[param] = torch.from_numpy(new_actor_params[param]).float()
        self.net.load_state_dict(new_actor_params)

    def run(self):
        state = self.env.reset()
        episode_reward = 0
        episode = 0
        ep_length = 0
        grad_steps = 0
        intermediate_memory = []  # this is what we will push to the buffer at once
        while grad_steps < self.num_grad_steps:
            ep_length += 1
            action = self.act(state)
            next_state, reward, done, _ = self.env.step(action)
            intermediate_memory.append((state, action, reward, next_state, done))
            if len(intermediate_memory) >= self.push_size:
                self.replay_buffer.push.remote(intermediate_memory)
                intermediate_memory = []
                self.sync_with_param_server()
                grad_steps = ray.get(self.param_server.return_grad_steps.remote())
                # time.sleep(60 * 5)
            episode_reward += reward
            if done:
                # print results locally
                # print(f"Episode {episode}: {episode_reward}")
                # print_status(self.env, time_step)
                # prepare new rollout
                episode += 1
                episode_reward = 0
                ep_length = 0
                next_state = self.env.reset()
            state = next_state
```
You can then run an instance of this with:
```python
if __name__ == "__main__":
    ddqn = DistributedPDQN(0.99, 50000, 256, 10000, 0.01, 0.9995, tau=0.0025, num_workers=5,
                           push_data_size=20, use_priority=True, max_grad_steps=2)
    ddqn.learn()
```
Switching the `use_priority` flag to `False` will use the vanilla replay buffer instead of the prioritised one.
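For example, the equivalent vanilla-buffer run is the same call with only that flag flipped:

```python
if __name__ == "__main__":
    # identical hyperparameters, only use_priority switched to False
    ddqn = DistributedPDQN(0.99, 50000, 256, 10000, 0.01, 0.9995, tau=0.0025, num_workers=5,
                           push_data_size=20, use_priority=False, max_grad_steps=2)
    ddqn.learn()
```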
From what you have said, I’m not sure why this would cause a problem with the prioritised buffer and not the vanilla buffer: both push experience in a similar manner, so it’s not clear to me why the prioritised buffer would cause such a big slowdown.
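In case it helps, here is a minimal sketch of how I'd time just the sampling path of the two buffer actors in isolation, with the learner and actors left out. It assumes the snippet is pasted at the bottom of the same file in place of the training `__main__` block (so the two buffer classes and the `time` import are in scope); the `time_sampling` helper and the dummy 8-dimensional LunarLander-style transitions are just for illustration:

```python
def time_sampling(buffer_handle, n_samples=100):
    # fill the remote buffer with dummy LunarLander-style transitions (8-dim observations)
    dummy = [([0.0] * 8, 0, 0.0, [0.0] * 8, False) for _ in range(50000)]
    ray.get(buffer_handle.push.remote(dummy))
    # time n_samples round-trips through the remote sample() call
    start = time.time()
    for _ in range(n_samples):
        ray.get(buffer_handle.sample.remote())
    return (time.time() - start) / n_samples


ray.init(ignore_reinit_error=True)
per_buffer = PrioritizedReplayBuffer.remote(50000, 0.6, 0.4, 0.001, 256)
vanilla_buffer = ReplayBuffer.remote(50000, 256)
print(f"prioritised sample: {time_sampling(per_buffer):.4f}s per batch")
print(f"vanilla sample: {time_sampling(vanilla_buffer):.4f}s per batch")
```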