Custom model with LSTM crashes PPO sampler.py

How severely does this issue affect your experience of using Ray?

  • High: It blocks me from completing my task.

Hello everyone!
I am trying to use PPO with a custom model that implements an LSTM and action masking. Without the LSTM everything works (see the no-LSTM implementation below), but when I try to add the LSTM components (see the LSTM implementation below) I get an internal RLlib error from sampler.py:

(PPO pid=29379) 2023-11-24 13:09:26,043 ERROR actor_manager.py:486 -- Ray error, taking actor 1 out of service. ray::RolloutWorker.apply() (pid=29524, ip=192.168.1.53, repr=<ray.rllib.evaluation.rollout_worker.RolloutWorker object at 0x7f836c82aca0>)
(PPO pid=29379)   File "/home/fedetask/Desktop/vtl/venv/lib/python3.9/site-packages/ray/rllib/utils/actor_manager.py", line 183, in apply
(PPO pid=29379)     raise e
(PPO pid=29379)   File "/home/fedetask/Desktop/vtl/venv/lib/python3.9/site-packages/ray/rllib/utils/actor_manager.py", line 174, in apply
(PPO pid=29379)     return func(self, *args, **kwargs)
(PPO pid=29379)   File "/home/fedetask/Desktop/vtl/venv/lib/python3.9/site-packages/ray/rllib/execution/rollout_ops.py", line 86, in <lambda>
(PPO pid=29379)     lambda w: w.sample(), local_worker=False, healthy_only=True
(PPO pid=29379)   File "/home/fedetask/Desktop/vtl/venv/lib/python3.9/site-packages/ray/rllib/evaluation/rollout_worker.py", line 900, in sample
(PPO pid=29379)     batches = [self.input_reader.next()]
(PPO pid=29379)   File "/home/fedetask/Desktop/vtl/venv/lib/python3.9/site-packages/ray/rllib/evaluation/sampler.py", line 92, in next
(PPO pid=29379)     batches = [self.get_data()]
(PPO pid=29379)   File "/home/fedetask/Desktop/vtl/venv/lib/python3.9/site-packages/ray/rllib/evaluation/sampler.py", line 285, in get_data
(PPO pid=29379)     item = next(self._env_runner)
(PPO pid=29379)   File "/home/fedetask/Desktop/vtl/venv/lib/python3.9/site-packages/ray/rllib/evaluation/sampler.py", line 706, in _env_runner
(PPO pid=29379)     ] = _process_policy_eval_results(
(PPO pid=29379)   File "/home/fedetask/Desktop/vtl/venv/lib/python3.9/site-packages/ray/rllib/evaluation/sampler.py", line 1297, in _process_policy_eval_results
(PPO pid=29379)     env_id: int = eval_data[i].env_id

The crash happens because, in _process_policy_eval_results() of sampler.py, eval_data has 3 elements while actions (computed by the following sequence of operations in that same function) has 9 elements:

actions: TensorStructType = eval_results[policy_id][0]
actions = convert_to_numpy(actions) 
if isinstance(actions, list):
    actions = np.array(actions)
actions: List[EnvActionType] = unbatch(actions)
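
To make sure I am reading this right, here is a tiny standalone sketch of the mismatch (not the actual RLlib call site; the sizes and placeholder items are made up): unbatch() turns the action tensor into one entry per row, so the loop in _process_policy_eval_results() ends up indexing past the end of eval_data.

import numpy as np
from ray.rllib.utils.spaces.space_utils import unbatch

eval_data = ["eval_item_0", "eval_item_1", "eval_item_2"]  # 3 policy-eval items (placeholders)
actions = np.zeros((9, 5), dtype=np.float32)               # but 9 action rows come back
actions = unbatch(actions)                                 # -> list of 9 arrays

for i, action in enumerate(actions):
    env_id = eval_data[i]                                  # IndexError once i reaches 3
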
I am running the code with the following configuration:
{
    'num_rollout_workers': 10,  # Can be increased for faster convergence
    'num_cpus_per_worker': 1,
    # ... other PPO params
    "model": {
        "custom_model": "my_model",
        "max_seq_lens": 50,
        "use_lstm": False
    }
}
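
In case it matters, this is roughly how I register the custom model and build the algorithm (a sketch; "my_model" and "my_env" are placeholder names, and the model dict is the one above):

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.models import ModelCatalog

ModelCatalog.register_custom_model("my_model", PPOModel)

config = (
    PPOConfig()
    .environment("my_env")  # custom env whose obs is {"observation": ..., "mask": ...}
    .framework("tf")
    .rollouts(num_rollout_workers=10)
    .resources(num_cpus_per_worker=1)
    .training(
        model={
            "custom_model": "my_model",
            "max_seq_lens": 50,
            "use_lstm": False,
        },
    )
)
algo = config.build()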

Am I doing something wrong?

NO-LSTM IMPLEMENTATION

from typing import Dict, List, Tuple

from gym.spaces import Dict as GymDict, Discrete
from gym.spaces.utils import flatten_space
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import ModelConfigDict, TensorType

tf1, tf, tfv = try_import_tf()

name = "ppo_model"



class PPOModel(TFModelV2):

    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 **kw):

        super().__init__(obs_space, action_space, num_outputs, model_config, name)
        
        # Get original observation space (not flattened)
        orig_space_with_mask = getattr(obs_space, "original_space", obs_space)

        # Save original observation space (without action_mask)
        self.orig_space = orig_space_with_mask['observation']

        # Build internal model
        flatten_obs_space = flatten_space(self.orig_space)  # Flatten obs
        
        self.internal_model = FullyConnectedNetwork(
            flatten_obs_space,
            action_space,
            num_outputs,
            model_config,
            name + "_internal",
        )

    def forward(self, input_dict, state, seq_lens):
        
        # 1) Compute the true observation
        actual_obs_list = self._collect_true_observation(self.orig_space, input_dict["obs"]["observation"])
        actual_obs_flat = tf.concat(actual_obs_list, -1)

        # 2) Compute the model output
        model_out, _ = self.internal_model({'obs': actual_obs_flat})
                
        # 3) Save mask
        action_mask = None
        if 'mask' in input_dict['obs']:
            action_mask = input_dict["obs"]["mask"]
        if action_mask is not None:
            inf_mask = tf.maximum(tf.math.log(action_mask), tf.float32.min)
            model_out = model_out + inf_mask
        
        return model_out, state

    def value_function(self):
        return self.internal_model.value_function()
    
    def _collect_true_observation(self, observation_space: dict, observation: dict) -> List[TensorType]:
        flattened_obs = []
        for key in observation_space.keys():
            if isinstance(observation_space[key], GymDict):
                flattened_obs.extend(self._collect_true_observation(observation_space[key], observation[key]))
            else:
                flattened_obs.append(observation[key])
        return flattened_obs
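
For context, the observation space the model expects looks roughly like this (the top-level keys "observation" and "mask" are the ones the model reads; the inner fields and sizes are only illustrative, not my exact env):

import numpy as np
from gym.spaces import Box, Dict as GymDict, Discrete

n_actions = 5
action_space = Discrete(n_actions)
obs_space = GymDict({
    # The "true" observation that the model actually learns from
    "observation": GymDict({
        "own_state": Box(-1.0, 1.0, shape=(8,), dtype=np.float32),
        "neighbors": Box(-1.0, 1.0, shape=(16,), dtype=np.float32),
    }),
    # 1.0 = action available, 0.0 = action masked out
    "mask": Box(0.0, 1.0, shape=(n_actions,), dtype=np.float32),
})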

LSTM IMPLEMENTATION

import numpy as np
import tree
from typing import Dict, List, Tuple

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import override
from gym.spaces import Dict as GymDict, Discrete, flatten_space, MultiDiscrete
from ray.rllib.policy.rnn_sequencing import add_time_dimension
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.models.tf.recurrent_net import RecurrentNetwork
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
from ray.rllib.utils.tf_utils import one_hot
from ray.rllib.policy.view_requirement import ViewRequirement

from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import ModelConfigDict, TensorType

tf1, tf, tfv = try_import_tf()

name = "ppo_model"


class PPOModel(RecurrentNetwork):
    """
    Custom LSTM model that implements action masking by setting the logits values of unavailable
    actions with tf.float32.min.

    Can be used for PPO, APPO, IMPALA.

    Data flow (!):
    obs -> forward(forward_rnn()) -> logits
    """

    def __init__(
            self,
            obs_space: GymDict,
            action_space: Discrete,
            num_outputs: int,
            model_config: ModelConfigDict,
            name: str,
            q_hiddens=(256,),
    ):
        """
        :param obs_space: Observation space of the target gym env. This may have an `original_space`
            attribute that specifies how to unflatten the tensor into a ragged tensor.
        :param action_space: Action space of the target gym env.
        :param num_outputs: Number of output units of the model. An RLlib requirement, not used here.
        :param model_config: Config for the model, as specified in the trainer config.
        :param name: Name (scope) for the model.
        """

        super().__init__(obs_space, action_space, num_outputs, model_config, name)

        # 1) Pre-process observation space to remove action mask
        orig_obs_space = getattr(obs_space, "original_space", obs_space)
        self.orig_space_wo_mask = orig_obs_space['observation']
        self.obs_dim = int(np.product(flatten_space(self.orig_space_wo_mask).shape))

        # 2) Prepare input: Add prev_action and prev_reward if needed (copy from LSTMModel)
        input_size = 0
        input_size += self.obs_dim

        self.use_prev_action = model_config.get("lstm_use_prev_action", False)
        self.use_prev_reward = model_config.get("lstm_use_prev_reward", False)

        self.action_dim = 0
        self.action_space_struct = get_base_struct_from_space(self.action_space)
        for space in tree.flatten(self.action_space_struct):
            if isinstance(space, Discrete):
                self.action_dim += space.n
            elif isinstance(space, MultiDiscrete):
                self.action_dim += np.sum(space.nvec)
            elif space.shape is not None:
                self.action_dim += int(np.product(space.shape))
            else:
                self.action_dim += int(len(space))

        if self.use_prev_action:
            input_size += self.action_dim

        if self.use_prev_reward:
            input_size += 1

        input_layer = tf.keras.layers.Input(shape=(None, input_size), name="inputs")

        # 3) Preprocess observation with a hidden layer and send to LSTM cell
        hiddens = model_config.get("fcnet_hiddens", [])
        layers = input_layer
        for i in range(len(hiddens)):
            layers = tf.keras.layers.Dense(
                units=hiddens[i],
                activation='relu',
                name="pre_process_hidden_%d" % i)(layers)

        # 4) Build LSTM layer
        self.cell_size = model_config["lstm_cell_size"]

        state_in_h = tf.keras.layers.Input(shape=(self.cell_size,), name="h")
        state_in_c = tf.keras.layers.Input(shape=(self.cell_size,), name="c")
        seq_in = tf.keras.layers.Input(shape=(), name="seq_in", dtype=tf.int32)

        lstm_out, state_h, state_c = tf.keras.layers.LSTM(
            self.cell_size, return_sequences=True, return_state=True, name="lstm"
        )(
            inputs=layers,
            mask=tf.sequence_mask(seq_in),
            initial_state=[state_in_h, state_in_c],
        )

        # 5) Post-process LSTM output with another hidden layer and compute values
        logits = tf.keras.layers.Dense(
            self.action_dim, activation=tf.keras.activations.linear, name="logits"
        )(lstm_out)
        values = tf.keras.layers.Dense(1, activation=None, name="values")(lstm_out)

        # 6) Create the RNN model
        self._rnn_model = tf.keras.Model(
            inputs=[input_layer, seq_in, state_in_h, state_in_c],
            outputs=[logits, values, state_h, state_c],
        )

        # Add prev-a/r to this model's view, if required.
        if self.use_prev_action:
            self.view_requirements[SampleBatch.PREV_ACTIONS] = ViewRequirement(
                SampleBatch.ACTIONS, space=self.action_space, shift=-1
            )
        if self.use_prev_reward:
            self.view_requirements[SampleBatch.PREV_REWARDS] = ViewRequirement(
                SampleBatch.REWARDS, shift=-1
            )

    @override(RecurrentNetwork)
    def forward(self,
                input_dict: Dict[str, TensorType],
                state: List[TensorType],
                seq_lens: TensorType) -> Tuple[TensorType, List[TensorType]]:
        """
        Compute the model logits via an LSTM.

        Step 1) Compute the LSTM input.
        For the observation, do not use input_dict['obs_flat'] here, because it contains the action_mask, which is
        not used in the model. Instead, compute the flattened observation from the original observation dictionary.
        Step 2) Compute the LSTM output.
        Step 3) Apply mask

        :param input_dict:  Dictionary with keys 'obs' (the original observation dictionary) and
                            'obs_flat' (the flattened observation)
        :param state:       LSTM state (h and c)
        :param seq_lens:    Input sequence lengths

        :return model_out: Tensor with the model output
        :return state:     New LSTM state
        """
        assert seq_lens is not None

        # 1) LSTM input

        # True observation (avoid mask being also passed as an observation)
        actual_obs_list = self._collect_true_observation(self.orig_space_wo_mask, input_dict["obs"]["observation"])
        lstm_input = tf.reshape(tf.concat(actual_obs_list, -1), [-1, self.obs_dim])  # [B*T, obs_dim]

        # Concat. prev-action/reward if required.
        prev_a_r = []
        if self.use_prev_action:
            prev_a = input_dict[SampleBatch.PREV_ACTIONS]
            prev_a = one_hot(prev_a, self.action_space)  # [B*T, A]
            prev_a = tf.reshape(tf.cast(prev_a, tf.float32), [-1, self.action_dim])  # [B*T, A]
            prev_a_r.append(prev_a)

        if self.use_prev_reward:
            prev_r = input_dict[SampleBatch.PREV_REWARDS]
            prev_r = tf.reshape(tf.cast(prev_r, tf.float32), [-1, 1])  # [B*T, 1]
            prev_a_r.append(prev_r)

        # Concat prev. actions + rewards to the "main" input.
        if prev_a_r:
            lstm_input = tf.concat([lstm_input] + prev_a_r, axis=1)  # [B*T, obs_dim+A+1]

        inputs = add_time_dimension(padded_inputs=lstm_input, seq_lens=seq_lens, framework="tf")  # [B, T, obs_dim+A+1]

        # 2) Output
        logits, new_state = self.forward_rnn(inputs, state, seq_lens)

        # 3) Apply mask
        action_mask = input_dict["obs"]["mask"]
        inf_mask = tf.maximum(tf.math.log(action_mask), tf.float32.min)
        masked_logits = logits + inf_mask

        return tf.reshape(masked_logits, [-1, self.action_dim]), new_state

    @override(RecurrentNetwork)
    def forward_rnn(
        self, inputs: TensorType, state: List[TensorType], seq_lens: TensorType
    ) -> Tuple[TensorType, List[TensorType]]:
        model_out, self._value_out, h, c = self._rnn_model([inputs, seq_lens] + state)
        return model_out, [h, c]

    @override(RecurrentNetwork)
    def get_initial_state(self) -> List[np.ndarray]:
        return [
            np.zeros(self.cell_size, np.float32),
            np.zeros(self.cell_size, np.float32),
        ]

    @override(ModelV2)
    def value_function(self):
        return tf.reshape(self._value_out, [-1])

    @override(ModelV2)
    def import_from_h5(self, h5_file: str) -> None:
        raise NotImplementedError('Import from h5 not implemented.')

    def _collect_true_observation(self, observation_space: dict, observation: dict) -> List[TensorType]:
        flattened_obs = []
        for key in observation_space.keys():
            if isinstance(observation_space[key], GymDict):
                flattened_obs.extend(self._collect_true_observation(observation_space[key], observation[key]))
            else:
                flattened_obs.append(observation[key])
        return flattened_obs
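
For completeness, the masking trick used in both forward() implementations (adding tf.maximum(tf.math.log(action_mask), tf.float32.min) to the logits) can be checked in isolation with toy values:

import tensorflow as tf

mask = tf.constant([[1.0, 0.0, 1.0]])     # 1 = available, 0 = unavailable
logits = tf.constant([[0.2, 5.0, -0.3]])

inf_mask = tf.maximum(tf.math.log(mask), tf.float32.min)  # [0, tf.float32.min, 0]
masked_logits = logits + inf_mask
print(masked_logits.numpy())              # the unavailable action ends up at ~ -3.4e38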