Custom model with LSTM crashes PPO

How severe does this issue affect your experience of using Ray?

  • High: It blocks me to complete my task.

Hello everyone!
I am trying to use PPO with a custom model that implements an LSTM and action masking. Without the LSTM everything works (see the first model implementation below), but when I add the LSTM components (see the second implementation below) I get an internal RLlib error:

  1. (PPO pid=29379) 2023-11-24 13:09:26,043 ERROR – Ray error, taking actor 1 out of service. ray::RolloutWorker.apply() (pid=29524, ip=, repr=<ray.rllib.evaluation.rollout_worker.RolloutWorker object at 0x7f836c82aca0>)
  2. (PPO pid=29379) File "/home/fedetask/Desktop/vtl/venv/lib/python3.9/site-packages/ray/rllib/utils/actor_manager.py", line 183, in apply
  3. (PPO pid=29379) raise e
  4. (PPO pid=29379) File "/home/fedetask/Desktop/vtl/venv/lib/python3.9/site-packages/ray/rllib/utils/actor_manager.py", line 174, in apply
  5. (PPO pid=29379) return func(self, *args, **kwargs)
  6. (PPO pid=29379) File "/home/fedetask/Desktop/vtl/venv/lib/python3.9/site-packages/ray/rllib/execution/rollout_ops.py", line 86, in <lambda>
  7. (PPO pid=29379) lambda w: w.sample(), local_worker=False, healthy_only=True
  8. (PPO pid=29379) File "/home/fedetask/Desktop/vtl/venv/lib/python3.9/site-packages/ray/rllib/evaluation/rollout_worker.py", line 900, in sample
  9. (PPO pid=29379) batches = []
  10. (PPO pid=29379) File "/home/fedetask/Desktop/vtl/venv/lib/python3.9/site-packages/ray/rllib/evaluation/sampler.py", line 92, in next
  11. (PPO pid=29379) batches = [self.get_data()]
  12. (PPO pid=29379) File "/home/fedetask/Desktop/vtl/venv/lib/python3.9/site-packages/ray/rllib/evaluation/sampler.py", line 285, in get_data
  13. (PPO pid=29379) item = next(self._env_runner)
  14. (PPO pid=29379) File "/home/fedetask/Desktop/vtl/venv/lib/python3.9/site-packages/ray/rllib/evaluation/sampler.py", line 706, in _env_runner
  15. (PPO pid=29379) ] = _process_policy_eval_results(
  16. (PPO pid=29379) File "/home/fedetask/Desktop/vtl/venv/lib/python3.9/site-packages/ray/rllib/evaluation/sampler.py", line 1297, in _process_policy_eval_results
  17. (PPO pid=29379) env_id: int = eval_data[i].env_id

This happens because, inside `_process_policy_eval_results()`, `eval_data` has 3 elements while `actions` has 9 elements, so the lookup `eval_data[i]` runs past the end of `eval_data`. `actions` is produced by the following sequence of operations in `_process_policy_eval_results()`:

actions: TensorStructType = eval_results[policy_id][0]
actions = convert_to_numpy(actions) 
if isinstance(actions, list):
    actions = np.array(actions)
actions: List[EnvActionType] = unbatch(actions)
I am setting the following configuration for running the code:
    'num_rollout_workers': 10,  # Can be increased for faster convergence
    'num_cpus_per_worker': 1,
    # ... other PPO params
    "model": {
        "custom_model": "my_model",
        "max_seq_lens": 50,
        "use_lstm": False

Am I doing something wrong?


from typing import Dict, List, Tuple

from gym.spaces import Dict as GymDict, Discrete
from gym.spaces.utils import flatten_space
from import FullyConnectedNetwork
from import TFModelV2
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import ModelConfigDict, TensorType

# TensorFlow handles returned by RLlib's try_import_tf(); only `tf` is used in this snippet.
tf1, tf, tfv = try_import_tf()

# Module-level model name.
# NOTE(review): PPOModel.__init__ also uses an identifier `name`; confirm whether it is a
# parameter there (shadowing this global) or refers to this value.
name = "ppo_model"

class PPOModel(TFModelV2):
    """Feed-forward PPO model with action masking (no LSTM).

    The (unflattened) observation space is expected to be a Dict with an
    ``"observation"`` entry and an optional ``"mask"`` entry. Only
    ``"observation"`` is fed to an internal FullyConnectedNetwork; the mask is
    added to the resulting logits as ``log(mask)`` clipped to
    ``tf.float32.min``, so unavailable actions get (near) -inf logits.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        super().__init__(obs_space, action_space, num_outputs, model_config, name)
        # Get the original observation space (not flattened), if RLlib attached one.
        orig_space_with_mask = getattr(obs_space, "original_space", obs_space)

        # Save the original observation space (without action_mask).
        self.orig_space = orig_space_with_mask["observation"]

        # Build the internal model on the flattened observation space.
        flatten_obs_space = flatten_space(self.orig_space)
        self.internal_model = FullyConnectedNetwork(
            flatten_obs_space,
            action_space,
            num_outputs,
            model_config,
            name + "_internal",
        )

    def forward(self, input_dict, state, seq_lens):
        """Return masked logits for a batch of observations and the unchanged state."""
        # 1) Compute the true observation (the mask is not a model input).
        actual_obs_list = self._collect_true_observation(
            self.orig_space, input_dict["obs"]["observation"]
        )
        actual_obs_flat = tf.concat(actual_obs_list, -1)

        # 2) Compute the model output.
        model_out, _ = self.internal_model({"obs": actual_obs_flat})

        # 3) Apply the action mask, if present: log(1) = 0 keeps a logit,
        # log(0) = -inf (clipped to float32.min) disables it. The cast lets
        # integer/boolean masks go through tf.math.log.
        if "mask" in input_dict["obs"]:
            action_mask = tf.cast(input_dict["obs"]["mask"], tf.float32)
            inf_mask = tf.maximum(tf.math.log(action_mask), tf.float32.min)
            model_out = model_out + inf_mask
        return model_out, state

    def value_function(self):
        return self.internal_model.value_function()

    def _collect_true_observation(self, observation_space: dict, observation: dict) -> List[TensorType]:
        """Flatten a (nested) Dict observation into a list of 2-D float tensors.

        Sub-dicts are traversed recursively in the space's key order. Every leaf
        is cast to float32 and reshaped to ``[batch, flat_size]`` (``flat_size``
        taken from ``flatten_space``) so the list can be concatenated along the
        last axis.
        """
        flattened_obs = []
        for key in observation_space.keys():
            sub_space = observation_space[key]
            sub_obs = observation[key]
            if isinstance(sub_space, GymDict):
                flattened_obs.extend(self._collect_true_observation(sub_space, sub_obs))
            else:
                # NOTE(review): assumes Discrete leaves arrive one-hot encoded, so
                # their size matches flatten_space(Discrete(n)) == n — confirm.
                flat_size = int(flatten_space(sub_space).shape[0])
                flattened_obs.append(
                    tf.reshape(tf.cast(sub_obs, tf.float32), [-1, flat_size])
                )
        return flattened_obs


from typing import Dict, List, Tuple

import numpy as np
import tree
from gym.spaces import Dict as GymDict, Discrete, flatten_space, MultiDiscrete

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.recurrent_net import RecurrentNetwork
from ray.rllib.policy.rnn_sequencing import add_time_dimension
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
from ray.rllib.utils.tf_utils import one_hot
from ray.rllib.utils.typing import ModelConfigDict, TensorType

# TensorFlow handles returned by RLlib's try_import_tf(); only `tf` is used in this snippet.
tf1, tf, tfv = try_import_tf()

# Module-level model name; inside PPOModel.__init__ the `name` parameter shadows it.
name = "ppo_model"

class PPOModel(RecurrentNetwork):
    """Custom LSTM model that implements action masking.

    Logits of unavailable actions are pushed to ``tf.float32.min`` by adding
    ``log(mask)`` (clipped) to the raw logits.

    Can be used for PPO, APPO, IMPALA.

    Data flow:
        obs -> forward() -> forward_rnn() -> logits
    """

    def __init__(
        self,
        obs_space: GymDict,
        action_space: Discrete,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str,
    ):
        """Build the Keras LSTM model.

        :param obs_space: Observation space of the target gym env. This may have an
            `original_space` attribute that specifies how to unflatten the tensor.
        :param action_space: Action space of the target gym env.
        :param num_outputs: Number of output units of the model. RLlib requirement;
            the logits layer is sized by ``self.action_dim`` instead.
        :param model_config: Config for the model, as given in the trainer config.
        :param name: Name (scope) for the model.
        """
        super().__init__(obs_space, action_space, num_outputs, model_config, name)

        # 1) Pre-process observation space to remove the action mask.
        orig_obs_space = getattr(obs_space, "original_space", obs_space)
        self.orig_space_wo_mask = orig_obs_space["observation"]
        # np.prod: np.product is deprecated (removed in NumPy 2.0).
        self.obs_dim = int(np.prod(flatten_space(self.orig_space_wo_mask).shape))

        # 2) Prepare input: add prev_action and prev_reward if needed (as in LSTMWrapper).
        self.use_prev_action = model_config.get("lstm_use_prev_action", False)
        self.use_prev_reward = model_config.get("lstm_use_prev_reward", False)

        self.action_dim = 0
        self.action_space_struct = get_base_struct_from_space(self.action_space)
        for space in tree.flatten(self.action_space_struct):
            if isinstance(space, Discrete):
                self.action_dim += space.n
            elif isinstance(space, MultiDiscrete):
                self.action_dim += int(np.sum(space.nvec))
            elif space.shape is not None:
                self.action_dim += int(np.prod(space.shape))
            else:
                self.action_dim += int(len(space))

        input_size = self.obs_dim
        if self.use_prev_action:
            input_size += self.action_dim
        if self.use_prev_reward:
            input_size += 1

        input_layer = tf.keras.layers.Input(shape=(None, input_size), name="inputs")

        # 3) Pre-process the input with dense layers before the LSTM cell.
        # NOTE(review): activation taken from "fcnet_activation" (Keras accepts the
        # usual names such as "tanh"/"relu") — confirm this matches the intended model.
        hiddens = model_config.get("fcnet_hiddens", [])
        activation = model_config.get("fcnet_activation", "tanh")
        layers = input_layer
        for i, size in enumerate(hiddens):
            layers = tf.keras.layers.Dense(
                size, activation=activation, name="pre_process_hidden_%d" % i
            )(layers)

        # 4) Build the LSTM layer; the sequence mask makes it skip padded time steps.
        self.cell_size = model_config["lstm_cell_size"]

        state_in_h = tf.keras.layers.Input(shape=(self.cell_size,), name="h")
        state_in_c = tf.keras.layers.Input(shape=(self.cell_size,), name="c")
        seq_in = tf.keras.layers.Input(shape=(), name="seq_in", dtype=tf.int32)

        lstm_out, state_h, state_c = tf.keras.layers.LSTM(
            self.cell_size, return_sequences=True, return_state=True, name="lstm"
        )(
            inputs=layers,
            mask=tf.sequence_mask(seq_in),
            initial_state=[state_in_h, state_in_c],
        )

        # 5) Heads on the LSTM output: action logits and value estimate.
        logits = tf.keras.layers.Dense(
            self.action_dim, activation=tf.keras.activations.linear, name="logits"
        )(lstm_out)
        values = tf.keras.layers.Dense(1, activation=None, name="values")(lstm_out)

        # 6) Create the RNN model: [inputs, seq_lens, h, c] -> [logits, values, h, c].
        self._rnn_model = tf.keras.Model(
            inputs=[input_layer, seq_in, state_in_h, state_in_c],
            outputs=[logits, values, state_h, state_c],
        )

        # Add prev-a/r to this model's view, if required.
        if self.use_prev_action:
            self.view_requirements[SampleBatch.PREV_ACTIONS] = ViewRequirement(
                SampleBatch.ACTIONS, space=self.action_space, shift=-1
            )
        if self.use_prev_reward:
            self.view_requirements[SampleBatch.PREV_REWARDS] = ViewRequirement(
                SampleBatch.REWARDS, shift=-1
            )

    def forward(
        self,
        input_dict: Dict[str, TensorType],
        state: List[TensorType],
        seq_lens: TensorType,
    ) -> Tuple[TensorType, List[TensorType]]:
        """Compute the masked model logits via the LSTM.

        Step 1) Compute the LSTM input. ``input_dict['obs_flat']`` is not used
        because it also contains the action mask; the flat observation is rebuilt
        from the original observation dictionary instead.
        Step 2) Compute the LSTM output.
        Step 3) Fold time into the batch dimension and apply the mask.

        :param input_dict: Dictionary with keys 'obs' (the original observation
            dictionary) and 'obs_flat' (the flattened observation).
        :param state: LSTM state (h and c).
        :param seq_lens: Input sequence lengths.
        :return: Tuple of the masked logits, shape [B*T, action_dim], and the new
            LSTM state.
        """
        assert seq_lens is not None

        # 1) LSTM input: true observation only (the mask is not a model input).
        actual_obs_list = self._collect_true_observation(
            self.orig_space_wo_mask, input_dict["obs"]["observation"]
        )
        lstm_input = tf.reshape(tf.concat(actual_obs_list, -1), [-1, self.obs_dim])  # [B*T, obs_dim]

        # Concat. prev-action/reward if required (input_size in __init__ counts them).
        prev_a_r = []
        if self.use_prev_action:
            prev_a = input_dict[SampleBatch.PREV_ACTIONS]
            prev_a = one_hot(prev_a, self.action_space)  # [B*T, A]
            prev_a_r.append(tf.reshape(tf.cast(prev_a, tf.float32), [-1, self.action_dim]))
        if self.use_prev_reward:
            prev_r = input_dict[SampleBatch.PREV_REWARDS]
            prev_a_r.append(tf.reshape(tf.cast(prev_r, tf.float32), [-1, 1]))  # [B*T, 1]

        if prev_a_r:
            lstm_input = tf.concat([lstm_input] + prev_a_r, axis=1)  # [B*T, obs_dim+A+1]

        inputs = add_time_dimension(padded_inputs=lstm_input, seq_lens=seq_lens, framework="tf")  # [B, T, ...]

        # 2) LSTM output: logits are [B, T, action_dim].
        logits, new_state = self.forward_rnn(inputs, state, seq_lens)

        # 3) Fold time into the batch dimension BEFORE masking. The mask is
        # [B*T, action_dim]; adding it to [B, T, action_dim] logits would broadcast
        # to [B, B*T, action_dim] when T == 1 (3 envs -> 9 action rows, which
        # breaks RLlib's sampler) and fail for T > 1.
        logits = tf.reshape(logits, [-1, self.action_dim])
        if "mask" in input_dict["obs"]:
            action_mask = tf.cast(input_dict["obs"]["mask"], tf.float32)
            inf_mask = tf.maximum(tf.math.log(action_mask), tf.float32.min)
            logits = logits + inf_mask
        return logits, new_state

    def forward_rnn(
        self, inputs: TensorType, state: List[TensorType], seq_lens: TensorType
    ) -> Tuple[TensorType, List[TensorType]]:
        """Run the Keras model; caches the value output for value_function()."""
        model_out, self._value_out, h, c = self._rnn_model([inputs, seq_lens] + state)
        return model_out, [h, c]

    def get_initial_state(self) -> List[np.ndarray]:
        """Zero-initialized LSTM state [h, c], each of size lstm_cell_size."""
        return [
            np.zeros(self.cell_size, np.float32),
            np.zeros(self.cell_size, np.float32),
        ]

    def value_function(self):
        return tf.reshape(self._value_out, [-1])

    def import_from_h5(self, h5_file: str) -> None:
        raise NotImplementedError('Import from h5 not implemented.')

    def _collect_true_observation(self, observation_space: dict, observation: dict) -> List[TensorType]:
        """Flatten a (nested) Dict observation into a list of 2-D float tensors.

        Sub-dicts are traversed recursively in the space's key order. Every leaf
        is cast to float32 and reshaped to ``[batch, flat_size]`` (``flat_size``
        taken from ``flatten_space``) so the list can be concatenated along the
        last axis.
        """
        flattened_obs = []
        for key in observation_space.keys():
            sub_space = observation_space[key]
            sub_obs = observation[key]
            if isinstance(sub_space, GymDict):
                flattened_obs.extend(self._collect_true_observation(sub_space, sub_obs))
            else:
                # NOTE(review): assumes Discrete leaves arrive one-hot encoded, so
                # their size matches flatten_space(Discrete(n)) == n — confirm.
                flat_size = int(flatten_space(sub_space).shape[0])
                flattened_obs.append(
                    tf.reshape(tf.cast(sub_obs, tf.float32), [-1, flat_size])
                )
        return flattened_obs