How severely does this issue affect your experience of using Ray?
- None: Just asking a question out of curiosity
Hello all,
Forgive me if most of this information is documented somewhere, but I have a general question for the RLlib team about contributing.
I have developed an Epistemic Neural Network (ENN) wrapper and a Mixture of Gaussians (MoG) module for continuous action spaces that I would like to submit as a PR. First, I wanted to see whether there is any interest in the concepts, so I will briefly go over what each of them accomplishes. I have also built centralized critics, stacked state models, and diffusion models for state planning, and I am currently working on multi-task models that use hierarchical decision making as an insight into some of the efforts I work on. All are based on PPO / continuous-action algorithms.
MOG module:
This one is pretty straightforward and stems from the idea here - DeepMind paper. It allows the network (preferably the value network) to express multimodal behavior with multiple Gaussians. The user can specify the number of Gaussians, the hidden-layer width, the number of layers, and the activation function. It uses the negative log-likelihood of the TD targets as its loss function.
Tested on: several Gym environments, custom fixed-wing / quadcopter environments, and PyFlyt environments.
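For context, the snippets below assume roughly the following imports and an `activation_functions` name-to-module mapping (a minimal sketch; the exact mapping in my code may differ):

# Minimal sketch of the imports / helpers the snippets below assume.
import math

import torch
import torch.nn as nn
from torch.distributions import Normal

from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.policy.sample_batch import SampleBatch

# Simple name -> activation-module mapping (illustrative; extend as needed).
activation_functions = {
    'LeakyReLU': nn.LeakyReLU,
    'ReLU': nn.ReLU,
    'ELU': nn.ELU,
    'Tanh': nn.Tanh,
}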
class MOG(nn.Module):
    def __init__(self, obs_space, num_gaussians, hidden_layer_dims=None, num_layers=None, activation=None):
        super().__init__()
        self.num_gaussians = num_gaussians
        self.activation_fn = activation if activation is not None else 'LeakyReLU'
        self.num_layers = num_layers if num_layers is not None else 2
        self.hidden_layer_dims = hidden_layer_dims if hidden_layer_dims is not None else 128
        if self.activation_fn in activation_functions:
            self.activation_fn = activation_functions[self.activation_fn]()
        else:
            raise ValueError(f"Unsupported activation function: {self.activation_fn}")
        layers = []
        for i in range(self.num_layers):
            in_features = obs_space.shape[0] if i == 0 else self.hidden_layer_dims
            layers.append(nn.Linear(in_features, self.hidden_layer_dims))
            layers.append(self.activation_fn)
        self.hidden_layers = nn.Sequential(*layers)
        # one mean, one sigma, and one mixture weight (alpha) per Gaussian
        self.output_layer = nn.Linear(self.hidden_layer_dims, self.num_gaussians * 3)
def forward(self, input_dict, state, seq_lens):
obs_raw = input_dict['obs_flat'].float()
obs = obs_raw.reshape(obs_raw.shape[0], -1)
logits = self.hidden_layers(obs)
value_output = self.output_layer(logits)
        # get the Gaussian components and cache them for value_function()
means = value_output[:, :self.num_gaussians]
self._u = means
sigmas_prev = value_output[:, self.num_gaussians:self.num_gaussians*2]
sigmas = torch.nn.functional.softplus(sigmas_prev) + 1e-6
self._sigmas = sigmas
alphas = value_output[:, self.num_gaussians*2:]
alphas = torch.clamp(torch.nn.functional.softmax(alphas, dim=-1), 1e-6, None)
self._alphas = alphas
return value_output, state
    def value_function(self, means=None, alphas=None):
        # the value is simply the Gaussian means weighted by their respective alphas;
        # means and alphas can be passed in explicitly so users can graph them, etc.
        if means is not None and alphas is not None:
            value = torch.sum(means * alphas, dim=1)
        else:
            value = torch.sum(self._u * self._alphas, dim=1)
        return value
    def predict_gmm_params(self, obs):
        logits = self.hidden_layers(obs)
        value_output = self.output_layer(logits)
        # get the Gaussian components
        means = value_output[:, :self.num_gaussians]
        sigmas_prev = value_output[:, self.num_gaussians:self.num_gaussians*2]
        sigmas = torch.nn.functional.softplus(sigmas_prev) + 1e-6
        alphas = value_output[:, self.num_gaussians*2:]
        # alphas are returned as logits; softmax / log_softmax is applied downstream
        return means, sigmas, alphas
    def compute_log_likelihood(self, td_targets, mu_current, sigma_current, alpha_current):
        # per-sample negative log-likelihood of the TD targets under the mixture:
        # -log sum_k alpha_k * N(target | mu_k, sigma_k), via logsumexp for stability
        td_targets_expanded = td_targets.unsqueeze(1)
        log_2_pi = torch.log(2 * torch.tensor(math.pi))
        factor = -torch.log(sigma_current) - 0.5 * log_2_pi
        diffs = td_targets_expanded - mu_current
        logp = torch.clamp(factor - torch.square(diffs) / (2 * torch.square(sigma_current)), -1e10, 10)
        # log_softmax yields log-probabilities (<= 0), so clamp from below at
        # log(1e-6); clamping at 1e-6 would collapse them all to a constant
        loga = torch.clamp(torch.nn.functional.log_softmax(alpha_current, dim=-1), min=math.log(1e-6))
        nll = -torch.logsumexp(logp + loga, dim=-1)
        return nll
    def custom_loss(self, sample_batch, gamma=None):
        gamma = gamma if gamma is not None else 0.99
        cur_obs = sample_batch[SampleBatch.CUR_OBS]
        next_obs = sample_batch[SampleBatch.NEXT_OBS]
        rewards = sample_batch[SampleBatch.REWARDS]
        dones = sample_batch[SampleBatch.DONES]
        mu_current, sigma_current, alpha_current = self.predict_gmm_params(cur_obs)
        mu_next, sigma_next, alpha_next = self.predict_gmm_params(next_obs)
        alpha_next = torch.clamp(torch.nn.functional.softmax(alpha_next, dim=-1), 1e-6, None)
        # bootstrapped TD(0) targets; detach so the loss only flows through
        # the current-state predictions
        next_state_values = torch.sum(mu_next * alpha_next, dim=1).detach()
        td_targets = rewards + gamma * next_state_values * (1 - dones.float())
        nll = self.compute_log_likelihood(td_targets, mu_current, sigma_current, alpha_current)
        # clamp so outliers do not dominate the loss
        nll = torch.clamp(nll, -10, 80)
        return torch.mean(nll)
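Here is a minimal usage sketch (dummy data; the observation space, batch contents, and sizes are all hypothetical):

# Hypothetical smoke test for the MOG value head.
import numpy as np
from gymnasium.spaces import Box  # or gym.spaces on older stacks

obs_space = Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float32)
mog = MOG(obs_space, num_gaussians=3)

obs = torch.randn(8, 4)  # batch of 8 observations
_, _ = mog({'obs_flat': obs}, state=[], seq_lens=None)
values = mog.value_function()  # per-sample mixture-weighted value, shape (8,)

# custom_loss only indexes the batch, so a plain dict keyed with SampleBatch
# constants is enough for a quick test.
batch = {
    SampleBatch.CUR_OBS: obs,
    SampleBatch.NEXT_OBS: torch.randn(8, 4),
    SampleBatch.REWARDS: torch.randn(8),
    SampleBatch.DONES: torch.zeros(8),
}
loss = mog.custom_loss(batch)
loss.backward()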
ENN:
This one took a little while to flesh out. Inspired by Osband's ENN paper and its Thompson-sampling application, it is a wrapper around a base network that injects uncertainty at each state to encourage exploration. Very briefly, it perturbs the network's prior with samples from a z-dimensional reference distribution (if z_dim is 5, think of it as 5 different priors) to probe what the model knows. If the model is particularly certain about a state, it will predict nearly the same output for every z sample; if it is uncertain, the outputs will vary.
Tested on: a dogfighting scenario using PyFlyt. Reward increased by 3,150% and convergence time (to a stable reward) was cut in half. This was run over 100 times at 100M timesteps each, against MOG networks, plain networks (vanilla PPO), and with Agent 1 and Agent 2 swapped every other iteration. It was also tested on Gym environments like HalfCheetah, where it performs 5-15% worse than vanilla PPO; I presume this is because those environments have little epistemic uncertainty compared to a dogfighting scenario.
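To make the z-sampling intuition concrete before the class itself, here is a toy, standalone sketch (sizes are arbitrary and hypothetical): the same state features are paired with several z draws, and the spread of the resulting outputs is the uncertainty signal. The wrapper below instead combines the per-z outputs into a single value via a z-weighted sum (the bmm calls).

# Toy sketch of the z-sampling intuition (not the wrapper itself).
feat_dim, z_dim = 8, 5
head = nn.Sequential(nn.Linear(feat_dim + 1, 16), nn.LeakyReLU(), nn.Linear(16, 1))

features = torch.randn(1, feat_dim)            # one state's feature vector
z = torch.randn(z_dim, 1)                      # z_dim different "priors"
inputs = torch.cat([z, features.expand(z_dim, -1)], dim=1)
outputs = head(inputs)                         # one output per z draw, shape (z_dim, 1)
spread = outputs.std()                         # large spread => high epistemic uncertainty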
class ENNWrapper(nn.Module):
    def __init__(self, base_network, z_dim, enn_layer, activation=None, initializer=None):
        """
        Args:
            base_network: network that is wrapped with the ENN
            z_dim: number of dimensions for the multivariate Gaussian distribution
                -- This can be seen as the number of models (mimicking the ensemble approach with noise)
            enn_layer: layer size for the ENN
            activation: activation function to use for the base and ENN networks
            initializer: network initializer to use
                -- Recommended to leave as the default per https://arxiv.org/abs/2302.09205
        """
        super().__init__()
        self.std = 1.0
        self.mean = 0.0
        self.z_dim = z_dim
        self.step_number = 0
        self.z_indices = None
        self.step_cut_off = 100
        self.activation_fn = activation if activation is not None else 'LeakyReLU'
        if self.activation_fn in activation_functions:
            # resolve the name to an activation-module class (SlimFC instantiates it)
            self.activation_fn = activation_functions[self.activation_fn]
        else:
            raise ValueError(f"Unsupported activation function: {self.activation_fn}")
        self.initializer = initializer if initializer is not None else torch.nn.init.xavier_normal_
        # reference distribution the z samples are drawn from
        self.distribution = Normal(torch.full((self.z_dim,), self.mean), torch.full((self.z_dim,), self.std))
def collect_layers(module):
layers = []
for m in module.children():
if isinstance(m, SlimFC):
layers.extend(list(m._model.children()))
elif isinstance(m, nn.Sequential):
layers.extend(collect_layers(m))
else:
layers.append(m)
return layers
def get_last_layer_input_features(layers):
for layer in reversed(layers):
if isinstance(layer, nn.Linear):
return layer.in_features
return None
# collect the layers from the base network
hidden_layers = collect_layers(base_network._hidden_layers)
hidden_layer_size = get_last_layer_input_features(hidden_layers)
if base_network._logits:
last_layer = list(base_network._logits.children())
else:
last_layer = []
# create a new sequential model with the hidden layers followed by the last layer
self.base_network = nn.Sequential(*hidden_layers)
self.last_layer = nn.Sequential(*last_layer)
self.learnable_layers = nn.Sequential(
SlimFC(hidden_layer_size + 1, enn_layer, initializer=self.initializer,
activation_fn=self.activation_fn),
SlimFC(enn_layer, enn_layer, initializer=self.initializer, activation_fn=self.activation_fn),
SlimFC(enn_layer, 1, initializer=self.initializer, activation_fn=self.activation_fn)
)
self.prior_layers = nn.Sequential(
SlimFC(hidden_layer_size + 1, enn_layer, initializer=self.initializer,
activation_fn=self.activation_fn),
SlimFC(enn_layer, enn_layer, initializer=self.initializer, activation_fn=self.activation_fn),
SlimFC(enn_layer, 1, initializer=self.initializer, activation_fn=self.activation_fn)
)
    def forward(self, input_dict, state, seq_lens):
        obs_raw = input_dict['obs_flat'].float()
        obs = obs_raw.reshape(obs_raw.shape[0], -1)
        base_output, enn_out = self.pass_through_layers(obs)
        self.total_output = enn_out + base_output
        # count forward passes so the prior can be frozen after step_cut_off
        self.step_number += 1
        return self.total_output, state
def pass_through_layers(self, obs):
with torch.no_grad():
intermediate = self.base_network(obs)
base_output = self.last_layer(intermediate)
intermediate_unsqueeze = torch.unsqueeze(intermediate, 1)
# draw sample from distribution and cat to logits
self.z_samples = self.distribution.sample((obs.shape[0],)).unsqueeze(-1).to(obs.device)
enn_input = torch.cat((self.z_samples, intermediate_unsqueeze.expand(-1, self.z_dim, -1)), dim=2)
        # ENN, prior, and base network pass
        if self.step_number < self.step_cut_off:
            # the prior stays trainable only for the first step_cut_off passes
            prior_out = self.prior_layers(enn_input)
        else:
            with torch.no_grad():
                # the frozen prior now encapsulates the uncertainty injected at each timestep
                prior_out = self.prior_layers(enn_input)
prior_bmm = torch.bmm(torch.transpose(prior_out, 1, 2), self.z_samples)
prior = prior_bmm.squeeze(-1)
# pass through learnable part of the ENN
learnable_out = self.learnable_layers(enn_input)
learnable_bmm = torch.bmm(torch.transpose(learnable_out, 1, 2), self.z_samples)
learnable = learnable_bmm.squeeze(-1)
enn_output = learnable + prior
return base_output, enn_output
def enn_loss(self, sample_batch, handle_loss, gamma = None):
cur_obs = sample_batch[SampleBatch.CUR_OBS]
next_obs = sample_batch[SampleBatch.NEXT_OBS]
rewards = sample_batch[SampleBatch.REWARDS]
dones = sample_batch[SampleBatch.DONES]
gamma = gamma if gamma is not None else 0.99
next_base_output, next_enn_output = self.pass_through_layers(next_obs)
next_values = next_base_output + next_enn_output
next_values = next_values.squeeze(-1) if next_values.shape[-1] == 1 else next_values
target = rewards + gamma * next_values.clone().detach() * (1 - dones.float())
        enn_loss = torch.nn.functional.mse_loss(self.total_output.squeeze(-1), target)
        if handle_loss:
            # also fit the base critic directly; unlike pass_through_layers,
            # this pass is not under no_grad, so gradients reach the base network
            intermediate = self.base_network(cur_obs)
            base_output = self.last_layer(intermediate).squeeze(-1)
            base_target = rewards + gamma * next_base_output.squeeze(-1) * (1 - dones.float())
            critic_loss = torch.nn.functional.mse_loss(base_output, base_target)
            total_loss = enn_loss + critic_loss
        else:
            total_loss = enn_loss
return total_loss
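And a minimal usage sketch for the wrapper. DummyBase is a hypothetical stand-in for an RLlib torch model that exposes `_hidden_layers` and `_logits` (as the built-in fully connected network does); all sizes are illustrative:

# Hypothetical stand-in for an RLlib torch model with _hidden_layers / _logits.
class DummyBase(nn.Module):
    def __init__(self, obs_dim, hidden=32):
        super().__init__()
        self._hidden_layers = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.LeakyReLU(),
            nn.Linear(hidden, hidden), nn.LeakyReLU(),
        )
        self._logits = nn.Sequential(nn.Linear(hidden, 1))

base = DummyBase(obs_dim=4)
enn = ENNWrapper(base, z_dim=5, enn_layer=16)

obs = torch.randn(8, 4)
out, _ = enn({'obs_flat': obs}, state=[], seq_lens=None)  # base + ENN output

batch = {
    SampleBatch.CUR_OBS: obs,
    SampleBatch.NEXT_OBS: torch.randn(8, 4),
    SampleBatch.REWARDS: torch.randn(8),
    SampleBatch.DONES: torch.zeros(8),
}
loss = enn.enn_loss(batch, handle_loss=True)
loss.backward()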
If either of these is of interest, let me know and I can work on a PR.
All the best,
Tyler