Hello all! I created a GTrXL model based on the wrapper model in `ray/rllib/models/torch/attention_net.py` (ray-project/ray on GitHub, master branch).
However, I've noticed that my GPU memory usage increases after every backward pass until the GPU eventually runs out of memory and the whole run crashes.
Below is my model code. Does anything in it look like a memory leak that could cause this? It is almost entirely based on the existing code in the repo, with a few slight adjustments.
import gymnasium as gym
import numpy as np
import torch
from gymnasium.spaces import Box
from ray.rllib.models import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.modules import GRUGate, RelativeMultiHeadAttention, SkipConnection
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import List, ModelConfigDict, TensorType
from torch import nn
class GTrXLNet(RecurrentNetwork, nn.Module):
    """A GTrXL net Model described in [2].

    This is still in an experimental phase.
    Can be used as a drop-in replacement for LSTMs in PPO and IMPALA.

    To use this network as a replacement for an RNN, configure your Algorithm
    as follows:

    Examples:
        >> config["model"]["custom_model"] = GTrXLNet
        >> config["model"]["max_seq_len"] = 10
        >> config["model"]["custom_model_config"] = {
        >>     "num_transformer_units": 1,
        >>     "attention_dim": 32,
        >>     "num_heads": 2,
        >>     "memory_inference": 50,
        >>     "memory_training": 50,
        >>     # etc.
        >> }

    References:
        [2] E. Parisotto et al., "Stabilizing Transformers for Reinforcement
            Learning", 2019.
        [3] Z. Dai et al., "Transformer-XL: Attentive Language Models Beyond
            a Fixed-Length Context", 2019.
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str,
        *,
        num_transformer_units: int = 1,
        attention_dim: int = 256,
        num_heads: int = 4,
        memory_inference: int = 100,
        memory_training: int = 100,
        head_dim: int = 128,
        position_wise_mlp_dim: int = 128,
        init_gru_gate_bias: float = 2.0,
    ):
        """Initializes a GTrXLNet.

        Args:
            num_outputs: Size of the logits output. If None, no logits/value
                heads are built and forward() returns the raw attention
                output (size `attention_dim`) instead.
            num_transformer_units: The number of Transformer repeats to
                use (denoted L in [2]).
            attention_dim: The input and output dimensions of one
                Transformer unit.
            num_heads: The number of attention heads to use in parallel.
                Denoted as `H` in [3].
            memory_inference: The number of timesteps to concat (time
                axis) and feed into the next transformer unit as inference
                input. The first transformer unit will receive this number of
                past observations (plus the current one), instead.
            memory_training: The number of timesteps to concat (time
                axis) and feed into the next transformer unit as training
                input (plus the actual input sequence of len=max_seq_len).
                The first transformer unit will receive this number of
                past observations (plus the input sequence), instead.
                NOTE(review): this value is stored but not read anywhere in
                this class; the `state_in_*` views below use
                `memory_inference`.
            head_dim: The dimension of a single(!) attention head within
                a multi-head attention unit. Denoted as `d` in [3].
            position_wise_mlp_dim: The dimension of the hidden layer
                within the position-wise MLP (after the multi-head attention
                block within one Transformer unit). This is the size of the
                first of the two layers within the PositionwiseFeedforward. The
                second layer always has size=`attention_dim`.
            init_gru_gate_bias: Initial bias values for the GRU gates
                (two GRUs per Transformer unit, one after the MHA, one after
                the position-wise MLP).
        """
        super().__init__(observation_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        self.num_transformer_units = num_transformer_units
        self.attention_dim = attention_dim
        self.num_heads = num_heads
        self.memory_inference = memory_inference
        self.memory_training = memory_training
        self.head_dim = head_dim
        self.max_seq_len = model_config["max_seq_len"]
        # forward() consumes input_dict["obs_flat"], i.e. the *flattened*
        # observation, so the input layer must accept the product of all
        # observation dims (identical to shape[0] for 1-D spaces).
        self.obs_dim = int(np.prod(observation_space.shape))

        # 1) Pre-process with an FC layer (obs -> attention_dim).
        self.linear_layer = SlimFC(in_size=self.obs_dim, out_size=self.attention_dim)
        # Ordered list of the layers run by forward(): the input layer followed
        # by one (MHA, MLP) pair per transformer unit. This plain list does not
        # register parameters itself; that happens through `self.linear_layer`
        # and `self.attention_layers`.
        self.layers = [self.linear_layer]

        attention_layers = []
        # 2) Create L Transformer blocks according to [2].
        for _ in range(self.num_transformer_units):
            # RelativeMultiHeadAttention part.
            mha_layer = SkipConnection(
                RelativeMultiHeadAttention(
                    in_dim=self.attention_dim,
                    out_dim=self.attention_dim,
                    num_heads=num_heads,
                    head_dim=head_dim,
                    input_layernorm=True,
                    output_activation=nn.ReLU,
                ),
                fan_in_layer=GRUGate(self.attention_dim, init_gru_gate_bias),
            )
            # Position-wise MultiLayerPerceptron part.
            mlp_layer = SkipConnection(
                nn.Sequential(
                    torch.nn.LayerNorm(self.attention_dim),
                    SlimFC(
                        in_size=self.attention_dim,
                        out_size=position_wise_mlp_dim,
                        use_bias=False,
                        activation_fn=nn.ReLU,
                    ),
                    SlimFC(
                        in_size=position_wise_mlp_dim,
                        out_size=self.attention_dim,
                        use_bias=False,
                        activation_fn=nn.ReLU,
                    ),
                ),
                fan_in_layer=GRUGate(self.attention_dim, init_gru_gate_bias),
            )
            # Build a list of all attention layers in order.
            attention_layers.extend([mha_layer, mlp_layer])

        # Create a Sequential such that all parameters inside the attention
        # layers are automatically registered with this top-level model.
        self.attention_layers = nn.Sequential(*attention_layers)
        self.layers.extend(attention_layers)

        # 3) Post-process with FC layers (logits and value heads).
        self.logits = None
        self.values_out = None
        # Last value output, written by forward() and read by value_function().
        self._value_out = None
        if num_outputs is not None:
            self.num_outputs = num_outputs
            self.logits = SlimFC(
                in_size=self.attention_dim,
                out_size=self.num_outputs,
                activation_fn=None,
                initializer=torch.nn.init.xavier_uniform_,
            )
            self.values_out = SlimFC(
                in_size=self.attention_dim,
                out_size=1,
                activation_fn=None,
                initializer=torch.nn.init.xavier_uniform_,
            )
        else:
            # No heads: forward() returns the raw attention output.
            self.num_outputs = self.attention_dim

        # Setup trajectory views (`memory_inference` x past memory outs).
        # Each state_in_i is a window over previously recorded state_out_i
        # values, not a live tensor from the previous forward pass.
        for i in range(self.num_transformer_units):
            space = Box(-1.0, 1.0, shape=(self.attention_dim,))
            self.view_requirements[f"state_in_{i}"] = ViewRequirement(
                f"state_out_{i}",
                shift=f"-{self.memory_inference}:-1",
                # Repeat the incoming state every max-seq-len times.
                batch_repeat_value=self.max_seq_len,
                space=space,
            )
            self.view_requirements[f"state_out_{i}"] = ViewRequirement(space=space, used_for_training=False)

    @override(ModelV2)
    def forward(self, input_dict, state: List[TensorType], seq_lens: TensorType) -> (TensorType, List[TensorType]):
        """Runs the input layer and all transformer units over one batch.

        Args:
            input_dict: Must contain "obs_flat". Its leading dim is folded
                into [B, T] with B = len(seq_lens) and T = N // B.
            state: One memory tensor per transformer unit; state[k] is passed
                as `memory` to the k-th MHA layer.
            seq_lens: Sequence lengths; only its length (B) is used here.

        Returns:
            Tuple of (output reshaped to [-1, out_dim], list of
            `num_transformer_units` memory tensors each reshaped to
            [-1, attention_dim] and detached from the autograd graph).
        """
        assert seq_lens is not None
        observations = input_dict["obs_flat"]
        # Add the time dim to observations: [B*T, ...] -> [B, T, ...].
        B = len(seq_lens)
        T = observations.shape[0] // B
        observations = torch.reshape(observations, [-1, T, *observations.shape[1:]])

        all_out = observations
        memory_outs = []
        for i, layer in enumerate(self.layers):
            if i % 2 == 1:
                # MHA layers (odd indices) attend over the memory of unit i // 2.
                all_out = layer(all_out, memory=state[i // 2])
            else:
                # Either self.linear_layer (initial obs -> attn. dim layer) or
                # an MLP. The output of these layers is always the memory for
                # the next forward pass.
                all_out = layer(all_out)
                memory_outs.append(all_out)

        # Discard the last output (not needed as a memory since it's the last
        # layer).
        memory_outs = memory_outs[:-1]

        if self.logits is not None:
            out = self.logits(all_out)
            self._value_out = self.values_out(all_out)
            out_dim = self.num_outputs
        else:
            out = all_out
            out_dim = self.attention_dim

        # The memories come back in as recorded state_out_i data (see the
        # view requirements built in __init__), so gradients through these
        # outputs are never used. Detaching them keeps whoever holds the
        # returned state from pinning this pass's autograd graph (and its
        # GPU activations) in memory.
        return torch.reshape(out, [-1, out_dim]), [
            torch.reshape(m, [-1, self.attention_dim]).detach() for m in memory_outs
        ]

    @override(ModelV2)
    def get_initial_state(self) -> List[np.ndarray] | List[TensorType]:
        """Returns one zero tensor of shape (attention_dim,) per transformer unit."""
        return [torch.zeros(self.view_requirements[f"state_in_{i}"].space.shape) for i in range(self.num_transformer_units)]

    @override(ModelV2)
    def value_function(self) -> TensorType:
        """Returns the value head output of the last forward() call, flattened to 1-D."""
        assert self._value_out is not None, "Must call forward first AND must have value branch!"
        return torch.reshape(self._value_out, [-1])
# Register the model class with RLlib's ModelCatalog under the name "GTrXLNet".
ModelCatalog.register_custom_model("GTrXLNet", GTrXLNet)