Custom LSTM model doesn't perform well

How severely does this issue affect your experience of using Ray?

  • High: It blocks me from completing my task.


I’ve created the custom LSTM model shown below, but it didn’t work very well in my environment (a personal project).

So I tested the model in easier environments, namely some classic control environments: CartPole-v1, Stateless CartPole, Pendulum, and Stateless Pendulum.
However, the model didn’t work well in these environments either, even though some of them are fully observable…
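For readers reproducing these tests: the "stateless" variants typically remove the velocity components from the observation, so the policy has to infer them from the observation history, which is what the LSTM is for. A minimal NumPy sketch, assuming the standard Gym observation layouts for CartPole-v1 and Pendulum (the index lists and the helper name `make_stateless` are illustrative, not from the post):

```python
import numpy as np

# Assumed Gym observation layouts:
#   CartPole-v1: [cart position, cart velocity, pole angle, pole angular velocity]
#   Pendulum:    [cos(theta), sin(theta), angular velocity]
CARTPOLE_KEEP = [0, 2]  # keep the positions, drop both velocities
PENDULUM_KEEP = [0, 1]  # keep the angle encoding, drop the angular velocity


def make_stateless(obs, keep):
    """Hypothetical helper: return only the non-velocity components of obs."""
    return np.asarray(obs, dtype=np.float32)[keep]


if __name__ == "__main__":
    print(make_stateless([0.1, 0.5, -0.02, 1.3], CARTPOLE_KEEP))
    print(make_stateless([1.0, 0.0, -3.2], PENDULUM_KEEP))
```

In an actual environment this would be applied inside a `gym.ObservationWrapper`, with `observation_space` shrunk to match. RLlib's examples also ship a `StatelessCartPole` env (`ray.rllib.examples.env.stateless_cartpole`), which the commented-out `register_env` line in the script below appears to refer to.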

I have no idea why the model performs poorly.
I’ve tried different model sizes, learning rates, vf_loss_coeff values, max_seq_len values, and so on.
Note that the actor and critic networks do not share any layers. Each network is
obs -> fc1 -> LSTM -> fc2 -> output (logits for the actor, the value for the critic).

I’m using PPO as the RL algorithm.

You can see the custom model and a runnable code below.

Custom LSTM Model:

```python
import numpy as np
from ray.rllib.models.torch.misc import SlimFC, normc_initializer

from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork as TorchRNN
from ray.rllib.utils.framework import try_import_torch

torch, nn = try_import_torch()

class MinimalLSTMShorter(TorchRNN, nn.Module):
    def __init__(
        self,
        obs_space,
        action_space,
        num_outputs,
        model_config,
        name,
        # Layer sizes, passed in as keyword arguments via "custom_model_config"
        fc_sizes,
        lstm_state_size,
        post_fc_sizes,
        value_fc_sizes,
        value_lstm_state_size,
        value_post_fc_sizes,
    ):
        nn.Module.__init__(self)
        super().__init__(obs_space, action_space, num_outputs, model_config, name)

        # Get configuration for this custom model
        self.fc_sizes = fc_sizes
        self.lstm_state_size = lstm_state_size
        self.post_fc_sizes = post_fc_sizes
        self.value_fc_sizes = value_fc_sizes
        self.value_lstm_state_size = value_lstm_state_size
        self.value_post_fc_sizes = value_post_fc_sizes

        # Define observation size
        self.obs_size = get_preprocessor(obs_space)(obs_space).size

        # Base outputs before feeding into the last branches
        self._features = None
        self._values = None

        # Actor
        self.actor_fc1 = nn.Linear(self.obs_size, self.fc_sizes)
        self.actor_lstm = nn.LSTM(self.fc_sizes, self.lstm_state_size, batch_first=True)
        self.actor_fc2 = nn.Linear(self.lstm_state_size, self.post_fc_sizes)
        self.action_branch = nn.Linear(self.post_fc_sizes, num_outputs)

        # Critic
        self.value_fc1 = nn.Linear(self.obs_size, self.value_fc_sizes)
        self.value_lstm = nn.LSTM(self.value_fc_sizes, self.value_lstm_state_size, batch_first=True)
        self.value_fc2 = nn.Linear(self.value_lstm_state_size, self.value_post_fc_sizes)
        self.value_branch = nn.Linear(self.value_post_fc_sizes, 1)

    def get_initial_state(self):
        # Place hidden states on same device as model.
        # Order: actor (h, c), then critic (h, c).
        h = [
            self.actor_fc1.weight.new(1, self.lstm_state_size).zero_().squeeze(0),
            self.actor_fc1.weight.new(1, self.lstm_state_size).zero_().squeeze(0),
            self.value_fc1.weight.new(1, self.value_lstm_state_size).zero_().squeeze(0),
            self.value_fc1.weight.new(1, self.value_lstm_state_size).zero_().squeeze(0),
        ]
        return h

    def value_function(self):
        assert self._values is not None, "must call forward() first"
        return torch.reshape(self.value_branch(self._values), [-1])

    def forward_rnn(self, inputs, state, seq_lens):
        # Compute actor outputs
        x = nn.functional.relu(self.actor_fc1(inputs))
        x, [h1, c1] = self.actor_lstm(x, [torch.unsqueeze(state[0], 0), torch.unsqueeze(state[1], 0)])
        self._features = nn.functional.relu(self.actor_fc2(x))
        action_out = self.action_branch(self._features)

        # Compute critic outputs
        x2 = nn.functional.relu(self.value_fc1(inputs))
        x2, [h2, c2] = self.value_lstm(x2, [torch.unsqueeze(state[2], 0), torch.unsqueeze(state[3], 0)])
        self._values = nn.functional.relu(self.value_fc2(x2))

        return action_out, [torch.squeeze(h1, 0), torch.squeeze(c1, 0), torch.squeeze(h2, 0), torch.squeeze(c2, 0)]
```
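Why forward_rnn unsqueezes and squeezes the states: RLlib hands the model one state tensor per entry returned by `get_initial_state()`, batched as `[B, hidden]`, whereas `torch.nn.LSTM` expects `(h0, c0)` shaped `[num_layers, B, hidden]` even with `batch_first=True`. A standalone PyTorch sketch of that round trip (the sizes are arbitrary; only `torch` is required):

```python
import torch
from torch import nn

B, T, obs_size, hidden = 3, 5, 4, 8  # batch, time steps, input size, LSTM size

lstm = nn.LSTM(obs_size, hidden, batch_first=True)
inputs = torch.randn(B, T, obs_size)  # [B, T, features], as batch_first expects

# RLlib-style state: one [B, hidden] tensor each for h and c.
h_in = torch.zeros(B, hidden)
c_in = torch.zeros(B, hidden)

# nn.LSTM wants [num_layers, B, hidden], so add a leading layer dim of 1.
out, (h_out, c_out) = lstm(inputs, (h_in.unsqueeze(0), c_in.unsqueeze(0)))

# Drop the layer dim again before handing the state back to RLlib.
h_out, c_out = h_out.squeeze(0), c_out.squeeze(0)

print(out.shape, h_out.shape, c_out.shape)
```

The same pattern extends to the four-entry state above: entries 0 and 1 feed actor_lstm, entries 2 and 3 feed value_lstm.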

Runnable Code:

```python
import ray
from ray import tune
from ray.tune.registry import register_env
# Envs
# Please get environments you are interested in
# Models
from minimal_custom_lstm import MinimalLSTMShorter  # the custom model shown above
from ray.rllib.models import ModelCatalog

if __name__ == "__main__":
    # Initialize ray
    ray.init()

    # Register the model
    ModelCatalog.register_custom_model(
        "RNNModel2", MinimalLSTMShorter
    )
    target_model = "RNNModel2"

    # Register your environment
    # register_env("StatelessCartPole", lambda _: StatelessCartPole())
    target_env = "CartPole-v1"
    tune.run(
        "PPO",
        config={
            "env": target_env,
            "framework": "torch",
            "num_gpus": 1,
            "num_workers": 4,
            # Must be fine-tuned when sharing vf-policy layers
            "vf_loss_coeff": 0.01,
            "lr": 5e-4,
            "model": {
                # == LSTM ==
                # Max seq len for training the LSTM, defaults to 20.
                "max_seq_len": 4,  # 20
                "custom_model": target_model,
                # Keyword arguments for MinimalLSTMShorter.__init__.
                # Placeholder sizes: the actual values used in the post are not shown.
                "custom_model_config": {
                    "fc_sizes": 64,
                    "lstm_state_size": 64,
                    "post_fc_sizes": 64,
                    "value_fc_sizes": 64,
                    "value_lstm_state_size": 64,
                    "value_post_fc_sizes": 64,
                },
            },
        },
    )
```

Hi @JayCarrot,

There should not be a relu on self._values (the output of value_fc2 in forward_rnn), at least not if your environment returns negative rewards.

You also probably don’t want that extra linear layer (actor_fc2) between the LSTM and the output adapter layer.
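Taken together, these two suggestions amount to obs -> fc1 -> LSTM -> linear head in each network, with the LSTM output feeding the output layers directly. A minimal plain-PyTorch sketch of that layout (not an RLlib model; the class name `SeparateLSTMActorCritic` and all sizes are illustrative assumptions):

```python
import torch
from torch import nn
import torch.nn.functional as F


class SeparateLSTMActorCritic(nn.Module):
    """Hypothetical sketch: separate actor and critic, each obs -> fc1 -> LSTM -> head."""

    def __init__(self, obs_size, num_outputs, fc_size=64, lstm_size=64):
        super().__init__()
        self.actor_fc1 = nn.Linear(obs_size, fc_size)
        self.actor_lstm = nn.LSTM(fc_size, lstm_size, batch_first=True)
        self.action_branch = nn.Linear(lstm_size, num_outputs)  # logits, no fc2 in between

        self.value_fc1 = nn.Linear(obs_size, fc_size)
        self.value_lstm = nn.LSTM(fc_size, lstm_size, batch_first=True)
        self.value_branch = nn.Linear(lstm_size, 1)  # unbounded value, no ReLU before it

    def forward(self, inputs, state):
        # inputs: [B, T, obs_size]; state: four [B, lstm_size] tensors (h1, c1, h2, c2)
        h1, c1, h2, c2 = (s.unsqueeze(0) for s in state)
        x, (h1, c1) = self.actor_lstm(F.relu(self.actor_fc1(inputs)), (h1, c1))
        logits = self.action_branch(x)
        v, (h2, c2) = self.value_lstm(F.relu(self.value_fc1(inputs)), (h2, c2))
        values = self.value_branch(v).squeeze(-1)  # [B, T]
        new_state = [s.squeeze(0) for s in (h1, c1, h2, c2)]
        return logits, values, new_state


if __name__ == "__main__":
    model = SeparateLSTMActorCritic(obs_size=4, num_outputs=2)
    state = [torch.zeros(3, 64) for _ in range(4)]
    logits, values, new_state = model(torch.randn(3, 5, 4), state)
    print(logits.shape, values.shape)
```

In the RLlib model from the question, this corresponds to dropping actor_fc2/value_fc2 and the relu on self._values, and sizing action_branch and value_branch from the LSTM state sizes.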

A BPTT length of 4 is pretty short; it should be longer.
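For context on why a short max_seq_len can hurt: RLlib splits each trajectory into chunks of at most `max_seq_len` steps (zero-padding the last one) and backpropagates only within a chunk, so it effectively acts as the truncation length for BPTT. The sketch below illustrates that chunking in plain Python; `chop_sequences` is a hypothetical helper, not RLlib's internal implementation, which also handles episode boundaries and the stored recurrent states.

```python
def chop_sequences(trajectory, max_seq_len, pad_value=0.0):
    """Split one trajectory into zero-padded chunks of length max_seq_len.

    Returns (padded_chunks, seq_lens). In truncated BPTT, gradients flow only
    within a chunk; each chunk starts from the recurrent state that was
    recorded at its first step during the rollout.
    """
    chunks, seq_lens = [], []
    for start in range(0, len(trajectory), max_seq_len):
        chunk = list(trajectory[start:start + max_seq_len])
        seq_lens.append(len(chunk))
        chunk += [pad_value] * (max_seq_len - len(chunk))
        chunks.append(chunk)
    return chunks, seq_lens


if __name__ == "__main__":
    episode = [float(t) for t in range(10)]  # 10 time steps of a scalar feature
    for n in (4, 20):
        chunks, lens = chop_sequences(episode, n)
        print(n, lens)
```

With max_seq_len=4, no gradient signal ever spans more than 4 steps, even if the task needs a longer memory.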

The default number of SGD iterations (num_sgd_iter) is pretty large when using an LSTM.
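In RLlib's PPO config, the inner loop is controlled by `num_sgd_iter` (passes over each train batch), together with `sgd_minibatch_size` and `train_batch_size`. A config-fragment sketch using Ray 1.x-style keys; the values are placeholders for illustration, with `num_sgd_iter` taken from the 2-10 range the original poster later reports working:

```python
# Hypothetical PPO config overrides; values are illustrative, not tuned.
ppo_overrides = {
    "num_sgd_iter": 5,          # passes over each train batch (PPO's default is larger)
    "sgd_minibatch_size": 256,  # timesteps per SGD minibatch
    "train_batch_size": 4000,   # timesteps collected per training iteration
    "lr": 5e-4,
}

# Merge into the config dict that is passed to tune.run("PPO", config=...).
config = {"env": "CartPole-v1", "framework": "torch"}
config.update(ppo_overrides)
```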

I usually prefer A2C when using an LSTM because it does not have the inner SGD loop during optimization.
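A rough way to see the difference: per training iteration, PPO takes on the order of num_sgd_iter × (train_batch_size / sgd_minibatch_size) gradient steps on the same batch, while A2C takes a single step on it. A back-of-the-envelope sketch (the helper names are hypothetical, and it ignores details such as padding of recurrent sequences):

```python
import math


def ppo_grad_steps(train_batch_size, sgd_minibatch_size, num_sgd_iter):
    """Approximate number of SGD steps PPO takes on one train batch."""
    minibatches = math.ceil(train_batch_size / sgd_minibatch_size)
    return num_sgd_iter * minibatches


def a2c_grad_steps():
    """A2C (without microbatching) applies one gradient step per train batch."""
    return 1


if __name__ == "__main__":
    for iters in (30, 5):
        print(iters, ppo_grad_steps(4000, 500, iters))
    print("a2c", a2c_grad_steps())
```

Many repeated steps on the same data give the recurrent policy more room to overfit a batch, which is one plausible reason a smaller num_sgd_iter helps here.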

Hi @mannyv,

Thanks for your comments!

Putting a relu there was a mistake I made while simplifying the code for readability.
In the original model I created, self._values is computed like this:

```python
        # Get an output of the value network
        x_value_fc_net = self.value_fc_net(inputs)
        x_value_lstm, [h2, c2] = self.value_lstm(
            x_value_fc_net, [torch.unsqueeze(state[2], 0),
                             torch.unsqueeze(state[3], 0)])
        if self.value_post_fc_sizes:
            self._values = self.value_post_fc_net(x_value_lstm)
        else:
            self._values = x_value_lstm
```


By BPTT, did you mean max_seq_len?
I tried max_seq_len values from 4 to 10, and none of them worked, though…

I’ll try different num_sgd_iter values, and A2C for its simpler optimization step.

An update for people having similar issues:


num_sgd_iter

I tried different num_sgd_iter values, and that seemed to work: values of 2-10 were good choices. I had been using 40, the default value, though.
For example, the StatelessPendulum environment (angular velocity unobservable) worked with:
I learned that the performance was quite sensitive to the combination of learning rate, num_sgd_iter, and batch sizes. Also, learning rate annealing was required for better performance and stability in some environments.
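For the learning rate annealing mentioned above, RLlib's PPO accepts an `lr_schedule` list of `[timestep, lr]` pairs that is interpolated linearly between points. The function below is a plain-Python sketch of that interpolation (not RLlib's code), with made-up breakpoints, to show what a schedule such as `[[0, 5e-4], [1_000_000, 5e-5]]` means in practice:

```python
def piecewise_lr(schedule, t):
    """Linearly interpolate a list of [timestep, lr] pairs sorted by timestep.

    Returns the final lr for any t outside the range the schedule covers.
    """
    for (t0, lr0), (t1, lr1) in zip(schedule[:-1], schedule[1:]):
        if t0 <= t < t1:
            alpha = (t - t0) / (t1 - t0)
            return lr0 + alpha * (lr1 - lr0)
    return schedule[-1][1]


if __name__ == "__main__":
    schedule = [[0, 5e-4], [1_000_000, 5e-5]]  # illustrative breakpoints
    for t in (0, 500_000, 2_000_000):
        print(t, piecewise_lr(schedule, t))
```

In the training config, this would correspond to setting `"lr_schedule": [[0, 5e-4], [1_000_000, 5e-5]]` next to `"lr"`.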


max_seq_len

It was not critical for successful learning, but longer max_seq_len values in the range of 4-20 (with adequate optimization settings) performed better in many cases.

Different Algos

I didn’t test simpler RL algorithms such as A2C or REINFORCE-based ones, but I agree with @mannyv, given how fragile PPO’s performance was to changes in its optimization settings.