Resuming experiment checkpoint hangs

I am trying to restore an unfinished experiment with this script:

import os
from ray import tune
storage_path = "/ray_results/"
exp_name = "lstm_JPZM_20231024_120553"
experiment_path = os.path.join(storage_path, exp_name)

restored_tuner = tune.Tuner.restore(

But I get the below error message, after which the script hangs.

Failed to read the results for 4 trials:
- /ray_results/lstm_JPZM_20231024_120553/PPO_ssa_env_38b62_00000_0_fcnet_activation=tanh,fcnet_hiddens=100,lstm_state_size=32_2023-10-24_12-05-59
- /ray_results/lstm_JPZM_20231024_120553/PPO_ssa_env_38b62_00001_1_fcnet_activation=tanh,fcnet_hiddens=200,lstm_state_size=32_2023-10-24_12-05-59
- /ray_results/lstm_JPZM_20231024_120553/PPO_ssa_env_38b62_00002_2_fcnet_activation=tanh,fcnet_hiddens=400,lstm_state_size=32_2023-10-24_12-05-59
- /ray_results/lstm_JPZM_20231024_120553/PPO_ssa_env_38b62_00003_3_fcnet_activation=tanh,fcnet_hiddens=800,lstm_state_size=32_2023-10-24_12-05-59
2023-10-24 12:22:42,220	INFO -- Started a local Ray instance. View the dashboard at 127.0.0.1:8265

The experiment I am attempting to restore was manually cancelled with 1 trial in progress and the remaining 3 trials pending.
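
Because the warning lists all four trial directories, one quick check is whether each of them contains any readable results. The sketch below assumes the usual Tune layout, in which each trial directory holds a result.json file with one JSON record per reported iteration. The helper name summarize_trials is hypothetical, and a trial that was still pending when the run was cancelled may simply have no rows yet.

```python
import json
from pathlib import Path


def summarize_trials(experiment_path):
    """Map each trial directory name to its number of readable result rows.

    Assumes each trial directory stores results in `result.json`, one JSON
    object per line. A missing file yields zero rows; unparsable lines are
    skipped.
    """
    summary = {}
    for trial_dir in sorted(Path(experiment_path).iterdir()):
        if not trial_dir.is_dir():
            continue
        rows = 0
        result_file = trial_dir / "result.json"
        if result_file.is_file():
            for line in result_file.read_text().splitlines():
                try:
                    json.loads(line)
                    rows += 1
                except json.JSONDecodeError:
                    pass  # Partially written line, e.g. from an interrupted trial.
        summary[trial_dir.name] = rows
    return summary
```

Calling summarize_trials("/ray_results/lstm_JPZM_20231024_120553") would show which of the four trials reported anything before the run was cancelled.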

There are a few similar questions on the forum (1, 2, 3), but none that describe the same issue I am having.

@dylan906 Do you have a minimal script that I could use to reproduce this? (The original training script.) I can help investigate this today if so!

@justinvyu The bad news is that the original script to train the checkpoint involves so many specific and complex things from my repo that it is impracticable to make a minimum working example. The good news is that your question got me to try making a MWE, and I was able to make something functional that at least rules out some potential problems.

I am using a custom model and environment. The custom environment is the complicated thing that is impracticable to reproduce. The MWE below uses my custom model and a toy custom environment. The MWE works; I am able to resume an interrupted experiment.

The question now is, why does this not work with the more complex case? I realize that without a full working example of the problem, this question is probably directly unanswerable. But if you can provide any breadcrumbs, that would be super helpful. Do you have any suggestions on where to look that might be causing problems? Presumably because the MWE works, the problem lies in the custom environment. Is there something about the way a trial is prematurely terminated that could affect the restore process?

Script to create checkpoint:

"""Create a checkpoint."""
import random
import string
from ray import air, tune
from ray.rllib.models import ModelCatalog
from mask_repeat_after_me import MaskRepeatAfterMe
from lstm_mask import MaskedLSTM

# %% Script
ModelCatalog.register_custom_model("MaskedLSTM", MaskedLSTM)

param_space = {
    "framework": "torch",
    "env": MaskRepeatAfterMe,
    "model": {
        "custom_model": "MaskedLSTM",
        "custom_model_config": {
            "fcnet_hiddens": [6, 6],
            "fcnet_activation": "relu",
            "lstm_state_size": tune.grid_search([2, 4, 6, 8]),
        },
    },
}

rand_str = "".join(random.choices(string.ascii_uppercase, k=3))
exp_name = "training_run_" + rand_str

tuner = tune.Tuner(
            "training_iteration": 10,
results =

Restore script:

"""Test to resume experiment level checkpoint."""
import os
from ray import tune
from ray.rllib.models import ModelCatalog
from lstm_mask import MaskedLSTM

# %% Script
ModelCatalog.register_custom_model("MaskedLSTM", MaskedLSTM)

# Replace these 2 lines with checkpoint location
storage_path = "/ray_results/"
exp_name = "training_run_AAA"

experiment_path = os.path.join(storage_path, exp_name)

restored_tuner = tune.Tuner.restore(

Custom model (action mask, based on Ray example):

"""Custom LSTM + Action Mask model."""
from typing import Dict, List, Tuple
from warnings import warn
import gymnasium as gym
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork as TorchRNN
from ray.rllib.policy.rnn_sequencing import add_time_dimension
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_utils import FLOAT_MIN
from ray.rllib.utils.typing import TensorType
from torch import all as tensorall
torch, nn = try_import_torch()

# %% Class
class MaskedLSTM(TorchRNN, nn.Module):
    """Fully-connected layers feed into an LSTM layer."""

    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config: dict = None,
        name: str = None,
        **custom_model_kwargs,
    ):
        """Initialize MaskedLSTM model.

        Args:
            obs_space (gym.spaces.Space): Environment observation space.
            action_space (gym.spaces.Space): Environment action space.
            num_outputs (int): Number of outputs of model. Should be equal to size
                of flattened action_space.
            model_config (dict, optional): Used for Ray defaults. Defaults to {}.
            name (str, optional): Used for inheritance. Defaults to "MaskedLSTM".
            custom_model_kwargs: Configure size of FC net and LSTM layer. Required.

        Expected items in custom_model_kwargs:
            fcnet_hiddens (list[int]): Number and size of FC layers.
            fcnet_activation (str): Activation function for FC layers. See Ray
                SlimFC documentation for recognized args.
            lstm_state_size (int): Size of LSTM layer.
        """
        # Convert space to proper gym space if handed in as a different type
        orig_space = getattr(obs_space, "original_space", obs_space)
        # Size of observations must include only "observations", not "action_mask".
        # Action mask must be 1d and same len as num_outputs.
        # custom_model_kwargs must include "lstm_state_size", "fcnet_hiddens",
        # and "fcnet_activation".
        assert "observations" in orig_space.spaces
        assert "action_mask" in orig_space.spaces
        assert len(orig_space.spaces) == 2
        assert len(orig_space["action_mask"].shape) == 1
        assert (
            orig_space["action_mask"].shape[0] == num_outputs
        ), f"""
        orig_space['action_mask'].shape[0] = {orig_space['action_mask'].shape[0]}\n
        num_outputs = {num_outputs}
        """
        assert "lstm_state_size" in custom_model_kwargs
        assert "fcnet_hiddens" in custom_model_kwargs
        assert "fcnet_activation" in custom_model_kwargs

        lstm_state_size = custom_model_kwargs.get("lstm_state_size")

        # Defaults
        if model_config is None:
            model_config = {}
        if name is None:
            name = "MaskedLSTM"

        nn.Module.__init__(self)
        super().__init__(
            obs_space, action_space, num_outputs, model_config, name
        )

        self.obs_size = orig_space["observations"].shape[0]
        # transition layer size: size of output of final hidden layer
        self.trans_layer_size = custom_model_kwargs["fcnet_hiddens"][-1]
        self.lstm_state_size = lstm_state_size

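        self.fc_layers = self.makeFCLayers(
            model_config=custom_model_kwargs, input_size=self.obs_size
        )

        self.lstm = nn.LSTM(
            self.trans_layer_size, self.lstm_state_size, batch_first=True
        )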
        self.action_branch = nn.Linear(self.lstm_state_size, num_outputs)
        self.value_branch = nn.Linear(self.lstm_state_size, 1)
        self._features = None

    def get_initial_state(self):
        """Initial states of hidden layers are initial states of final FC layer."""
        h = [
  , self.lstm_state_size)
  , self.lstm_state_size)
        return h

    def value_function(self):  # noqa
        assert self._features is not None, "must call forward() first"
        return torch.reshape(self.value_branch(self._features), [-1])

    # Override forward() to add an action mask step
    def forward(
        self,
        input_dict: Dict[str, TensorType],
        state: List[TensorType],
        seq_lens: TensorType,
    ) -> Tuple[TensorType, List[TensorType]]:
        """Adds time dimension to batch before sending inputs to forward_rnn().

        You should implement forward_rnn() in your subclass.
        """
        # When training, input_dict is handed in with an extra nested level from
        # the environment (input_dict["obs"]).
        # Get observations from obs; not observations+action_mask
        flat_inputs = input_dict["obs"]["observations"].float()
        action_mask = input_dict["obs"]["action_mask"]

        # Note that max_seq_len != input_dict.max_seq_len != seq_lens.max()
        # as input_dict may have extra zero-padding beyond seq_lens.max().
        # Use add_time_dimension to handle this
        self.time_major = self.model_config.get("_time_major", False)
        inputs = add_time_dimension(
            flat_inputs,
            seq_lens=seq_lens,
            framework="torch",
            time_major=self.time_major,
        )
        output, new_state = self.forward_rnn(inputs, state, seq_lens)
        output = torch.reshape(output, [-1, self.num_outputs])
        # Mask raw logits here! Then return masked values
        output = self.maskLogits(logits=output, mask=action_mask)
        return output, new_state

    def maskLogits(self, logits: TensorType, mask: TensorType):
        """Apply mask over raw logits."""
        # Resolve edge case where can pass in mask values <0 and
        # non-integers. Clamp values < 0  to 0, and values > 0 to 1.
        mask_binary = torch.clamp(mask, min=0, max=1)
        mask_binary[mask_binary > 0] = 1

        # check for binary action mask so error doesn't happen in action distribution
        # creation.
        assert all(
            [i in [0, 1] for i in mask_binary.detach().numpy().flatten()]
        )

        if tensorall(mask_binary == 0):
            # check if bad mask passed in
            warn("All actions masked")

        # Mask logits
        inf_mask = torch.clamp(torch.log(mask_binary), min=FLOAT_MIN)
        masked_logits = logits + inf_mask
        return masked_logits

    def forward_rnn(self, inputs, state, seq_lens):
        """Feeds `inputs` (B x T x ..) through the LSTM layer.

        Returns the resulting outputs as a sequence (B x T x ...).
        Values are stored in self._cur_value in simple (B) shape (where B
        contains both the B and T dims!).

        Returns:
            NN Outputs (B x T x ...) as sequence.
            The state batches as a List of two items (c- and h-states).
        """
        x = nn.functional.relu(self.fc_layers(inputs))
        self._features, [h, c] = self.lstm(
            x, [torch.unsqueeze(state[0], 0), torch.unsqueeze(state[1], 0)]
        )
        action_out = self.action_branch(self._features)
        return action_out, [torch.squeeze(h, 0), torch.squeeze(c, 0)]

    def makeFCLayers(
        self, model_config: dict, input_size: int
    ) -> nn.Sequential:
        """Make fully-connected layers.

        See Ray SlimFC for details.

        Args:
            model_config (dict): {
                "fcnet_hiddens": (list[int]) Number of hidden layers is number of
                    entries; size of hidden layers is values of entries,
                "fcnet_activation": (str) Recognized activation function
                }
            input_size (int): Input layer size.

        Returns:
            nn.Sequential: Has N layers, where N = len(model_config["fcnet_hiddens"]).
        """
        hiddens = list(model_config.get("fcnet_hiddens", []))
        activation = model_config.get("fcnet_activation")

        self.fc_hiddens = hiddens
        self.fc_activation = activation

        layers = []
        prev_layer_size = input_size

        # Create hidden layers.
        for size in hiddens:
            layers.append(
                SlimFC(
                    in_size=prev_layer_size,
                    out_size=size,
                    activation_fn=activation,
                )
            )
            prev_layer_size = size

        fc_layers = nn.Sequential(*layers)

        return fc_layers
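
To make the masking step in maskLogits concrete, here is a dependency-free sketch of the same arithmetic on plain Python lists. FLOAT_MIN_STANDIN is an assumed stand-in for RLlib's FLOAT_MIN (taken here to be about -3.4e38), and mask_logits is a hypothetical helper for illustration, not part of the model above.

```python
import math

# Stand-in for RLlib's FLOAT_MIN (assumed here to be about -3.4e38).
FLOAT_MIN_STANDIN = -3.4e38


def mask_logits(logits, mask):
    """Add log(mask) to each logit, clamping log(0) = -inf to a finite floor."""
    masked = []
    for logit, m in zip(logits, mask):
        m_binary = 1.0 if m > 0 else 0.0  # Same binarization as maskLogits.
        log_m = math.log(m_binary) if m_binary > 0 else -math.inf
        masked.append(logit + max(log_m, FLOAT_MIN_STANDIN))
    return masked
```

Allowed actions keep their original logits because log(1) = 0, while masked actions receive a very large negative but finite logit, so their probability after a softmax is effectively zero.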

Toy environment (RepeatAfterMe with mask):

"""Action Mask Repeat After Me Env."""
from gymnasium import Env
from gymnasium.spaces import Dict
from gymnasium.spaces.utils import flatten, flatten_space
from numpy import ones
from ray.rllib.examples.env.repeat_after_me_env import RepeatAfterMeEnv

class MaskRepeatAfterMe(Env):
    def __init__(self, config={}):
        self.internal_env = RepeatAfterMeEnv()
        self.observation_space = Dict(
            {
                "observations": flatten_space(
                    self.internal_env.observation_space
                ),
                "action_mask": flatten_space(self.internal_env.action_space),
            }
        )
        self.action_space = self.internal_env.action_space

        self.mask_config = config.get("mask_config", "viable_random")

    def reset(self, *, seed=None, options=None):
        obs, info = self.internal_env.reset()
        new_obs = self._wrapObs(obs)
        self.last_obs = new_obs
        return new_obs, info

    def step(self, action):
        trunc = self._checkMaskViolation(action)
        obs, reward, done, _, info = self.internal_env.step(action)
        new_obs = self._wrapObs(obs)
        self.last_obs = new_obs
        return new_obs, reward, done, trunc, info

    def _wrapObs(self, unwrapped_obs):
        if self.mask_config in ["viable_random"]:
            mask = self.observation_space.spaces["action_mask"].sample()
            mask[0] = 1
        elif self.mask_config == "full_random":
            mask = self.observation_space.spaces["action_mask"].sample()
        elif self.mask_config == "off":
            mask = ones(
                self.observation_space.spaces["action_mask"].shape, dtype=int
            )

        wrapped_obs = {
            "observations": flatten(
                self.internal_env.observation_space, unwrapped_obs
            ),
            "action_mask": mask,
        }
        return wrapped_obs

    def _checkMaskViolation(self, action):
        flat_action = flatten(self.action_space, action)
        diff = self.last_obs["action_mask"] - flat_action
        if any([i < 0 for i in diff]):
            truncate = True
            print("mask violation")
        else:
            truncate = False

        return truncate

Hmm, nothing really stands out to me in the script except for the restoration path:

Is /ray_results/ the correct path to be restoring from? The default directory should be at ~/ray_results, so you’d want to use storage_path = os.path.expanduser("~/ray_results")
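
A quick way to rule the path in or out is to check which candidate directory actually holds the experiment before calling Tuner.restore. The sketch below uses only the standard library. find_experiment_dir is a hypothetical helper, and the idea that a restorable experiment directory contains a tuner.pkl file is an assumption about the Ray version in use.

```python
import os


def find_experiment_dir(exp_name, storage_paths, marker="tuner.pkl"):
    """Return the first <storage_path>/<exp_name> directory containing `marker`.

    `marker` defaults to "tuner.pkl", which is an assumption about what the
    Ray version in use writes into an experiment directory. Returns None if
    no candidate matches.
    """
    for storage_path in storage_paths:
        candidate = os.path.join(os.path.expanduser(storage_path), exp_name)
        if os.path.isfile(os.path.join(candidate, marker)):
            return candidate
    return None
```

For example, find_experiment_dir("training_run_AAA", ["/ray_results", "~/ray_results"]) returns whichever of the two locations holds the experiment, or None if neither does.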

I was able to get it to run by setting the TUNE_MAX_PENDING_TRIALS_PG environment variable to num_cpus - 1 in the restore script. I found this line in an old script I had written, but unfortunately I didn’t write down why I added it, so I’ll just have to go with the “if it works, it works” reasoning for now.

    os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = str(num_cpus - 1)
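
A minimal, self-contained sketch of that workaround follows. It assumes num_cpus is simply the CPU count of the local machine (the original script does not show how it was computed) and sets the variable before ray is imported, so that it is already in place whenever Tune reads it.

```python
import os

# Assumption: num_cpus is the number of CPUs on this machine; the original
# script does not show how it was computed.
num_cpus = os.cpu_count() or 1

# Keep the value at 1 or more, even on a single-CPU machine.
max_pending = max(1, num_cpus - 1)

# Set the variable before importing ray.tune so it is in place when Tune reads it.
os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = str(max_pending)
```

The restore script's imports and the Tuner.restore call would then follow these lines unchanged.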
