How severely does this issue affect your experience of using Ray?
- High: It blocks me from completing my task.
Hi,
I’ve created a custom LSTM model shown below, but it didn’t work very well in my environment, a personal project.
So, I tested out the model in easier environments (i.e. some classical control environments).
I tried CartPole-v1, Stateless CartPole, Pendulum, and Stateless Pendulum environments.
However, the model didn’t work well in any of these environments, even though some of them are fully observable…
I have no idea why the model performs poorly.
I’ve tried different model sizes, learning rates, vf_loss_coeff values, max_seq_len values, etc.
Note that the model consists of the actor and critic networks without sharing any layers.
(Each network has the structure obs → fc1 → LSTM → fc2 → output, where the output is either the logits or the value.)
I’m using PPO as the RL algorithm.
You can see the custom model and runnable code below.
Custom LSTM Model:
import numpy as np
from ray.rllib.models.torch.misc import SlimFC, normc_initializer
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork as TorchRNN
from ray.rllib.utils.framework import try_import_torch
torch, nn = try_import_torch()
class MinimalLSTMShorter(TorchRNN, nn.Module):
    """Recurrent actor-critic model with two fully separate network stacks.

    Both the actor and the critic follow the same layout,
    ``obs -> Linear -> ReLU -> LSTM -> Linear -> ReLU -> Linear``, and share
    no layers. The actor's last layer emits ``num_outputs`` values; the
    critic's last layer emits a single value per step.

    The recurrent state is a list of four tensors in this order:
    actor ``h``, actor ``c``, critic ``h``, critic ``c``.
    """

    def __init__(
        self,
        obs_space,
        action_space,
        num_outputs,
        model_config,
        name,
        fc_sizes=128,
        lstm_state_size=128,
        post_fc_sizes=128,
        value_fc_sizes=128,
        value_lstm_state_size=128,
        value_post_fc_sizes=128,
    ):
        """Build the actor and critic layers.

        Args:
            obs_space: Observation space; its preprocessed size is the input
                width of both stacks.
            action_space: Action space, forwarded to the base class.
            num_outputs: Width of the actor's output layer.
            model_config: Model config, forwarded to the base class.
            name: Model name, forwarded to the base class.
            fc_sizes: Width of the actor's first linear layer.
            lstm_state_size: Hidden size of the actor LSTM.
            post_fc_sizes: Width of the actor's linear layer after the LSTM.
            value_fc_sizes: Width of the critic's first linear layer.
            value_lstm_state_size: Hidden size of the critic LSTM.
            value_post_fc_sizes: Width of the critic's linear layer after
                the LSTM.
        """
        nn.Module.__init__(self)
        super().__init__(obs_space, action_space, num_outputs, model_config, name)

        # Keep the layer widths around; get_initial_state() reads the LSTM sizes.
        self.fc_sizes = fc_sizes
        self.lstm_state_size = lstm_state_size
        self.post_fc_sizes = post_fc_sizes
        self.value_fc_sizes = value_fc_sizes
        self.value_lstm_state_size = value_lstm_state_size
        self.value_post_fc_sizes = value_post_fc_sizes

        # Flattened observation width as reported by the space's preprocessor.
        preprocessor_cls = get_preprocessor(obs_space)
        self.obs_size = preprocessor_cls(obs_space).size

        # Activations cached by forward_rnn(); _values feeds value_function().
        self._features = None
        self._values = None

        # Actor stack.
        self.actor_fc1 = nn.Linear(self.obs_size, self.fc_sizes)
        self.actor_lstm = nn.LSTM(
            self.fc_sizes, self.lstm_state_size, batch_first=True
        )
        self.actor_fc2 = nn.Linear(self.lstm_state_size, self.post_fc_sizes)
        self.action_branch = nn.Linear(self.post_fc_sizes, num_outputs)

        # Critic stack (independent weights, same layout).
        self.value_fc1 = nn.Linear(self.obs_size, self.value_fc_sizes)
        self.value_lstm = nn.LSTM(
            self.value_fc_sizes, self.value_lstm_state_size, batch_first=True
        )
        self.value_fc2 = nn.Linear(
            self.value_lstm_state_size, self.value_post_fc_sizes
        )
        self.value_branch = nn.Linear(self.value_post_fc_sizes, 1)

    @override(ModelV2)
    def get_initial_state(self):
        """Return zeroed ``[actor_h, actor_c, critic_h, critic_c]`` states.

        Each tensor is 1-D (one entry per LSTM unit) and is created from a
        layer weight so that it lands on the same device and dtype as the
        model parameters.
        """
        actor_ref = self.actor_fc1.weight
        critic_ref = self.value_fc1.weight
        return [
            actor_ref.new_zeros(self.lstm_state_size),
            actor_ref.new_zeros(self.lstm_state_size),
            critic_ref.new_zeros(self.value_lstm_state_size),
            critic_ref.new_zeros(self.value_lstm_state_size),
        ]

    @override(ModelV2)
    def value_function(self):
        """Return the critic's value estimates as a flat 1-D tensor.

        Uses the critic activations cached by the most recent
        ``forward_rnn()`` call, so that call must have happened first.
        """
        assert self._values is not None, "must call forward() first"
        value_out = self.value_branch(self._values)
        return value_out.reshape(-1)

    @override(TorchRNN)
    def forward_rnn(self, inputs, state, seq_lens):
        """Run both stacks over a batch of sequences.

        Args:
            inputs: Observation sequences; assumed to be laid out as
                (batch, time, obs_size) since both LSTMs use
                ``batch_first=True`` -- TODO confirm against the caller.
            state: ``[actor_h, actor_c, critic_h, critic_c]``; each entry
                gets a leading dimension added before entering its LSTM.
            seq_lens: Unused here.

        Returns:
            A pair ``(action_out, new_state)`` where ``new_state`` has the
            same ordering as ``state`` with the leading dimension removed
            again.
        """
        actor_h_in, actor_c_in, critic_h_in, critic_c_in = (
            state[0],
            state[1],
            state[2],
            state[3],
        )

        # Actor: fc1 -> ReLU -> LSTM -> fc2 -> ReLU -> action head.
        actor_hidden = nn.functional.relu(self.actor_fc1(inputs))
        actor_hidden, (actor_h, actor_c) = self.actor_lstm(
            actor_hidden,
            (actor_h_in.unsqueeze(0), actor_c_in.unsqueeze(0)),
        )
        self._features = nn.functional.relu(self.actor_fc2(actor_hidden))
        action_out = self.action_branch(self._features)

        # Critic: same pipeline with its own weights; the value head itself
        # is applied lazily in value_function().
        critic_hidden = nn.functional.relu(self.value_fc1(inputs))
        critic_hidden, (critic_h, critic_c) = self.value_lstm(
            critic_hidden,
            (critic_h_in.unsqueeze(0), critic_c_in.unsqueeze(0)),
        )
        self._values = nn.functional.relu(self.value_fc2(critic_hidden))

        new_state = [
            actor_h.squeeze(0),
            actor_c.squeeze(0),
            critic_h.squeeze(0),
            critic_c.squeeze(0),
        ]
        return action_out, new_state
Runnable Code:
import ray
from ray import tune
from ray.tune.registry import register_env
# Envs
# Please get environments you are interested in
# Models
from minimal_custom_lstm import MinimalLSTMShorter # the custom model shown above
from ray.rllib.models import ModelCatalog
if __name__ == "__main__":
    # Bring up Ray before registering anything or launching the trial.
    ray.init()

    # Make the custom recurrent model available under a string key.
    target_model = "RNNModel2"
    ModelCatalog.register_custom_model(target_model, MinimalLSTMShorter)

    # To use a custom environment, register it first, e.g.:
    # register_env("StatelessCartPole", lambda _: StatelessCartPole())
    target_env = "CartPole-v1"

    # Model section of the PPO config.
    model_settings = {
        # Max sequence length used when training the LSTM (the original
        # note says the default is 20).
        "max_seq_len": 4,
        "custom_model": target_model,
    }

    ppo_settings = {
        "env": target_env,
        "framework": "torch",
        "num_gpus": 1,
        "num_workers": 4,
        # Value-loss coefficient; the original note says this must be
        # fine-tuned when policy and value layers are shared.
        "vf_loss_coeff": 0.01,
        "lr": 5e-4,
        "model": model_settings,
    }

    tune.run("PPO", config=ppo_settings)