Cannot checkpoint a simple model

1. Severity of the issue:
Medium: Significantly affects my productivity but can find a workaround.

2. Environment:

  • Ray version: 2.45.0
  • Python version: 3.12.10
  • OS: Linux
  • Cloud/Infrastructure:
  • Other libs/tools (if relevant): torch 2.8.0.dev20250414+cu128

3. What happened vs. what you expected:

  • Expected: Checkpoints are saved.
  • Actual: Ray crashes when saving a checkpoint.

Hi!

I have been experimenting with a toy example and I am struggling to save checkpoints. Everything works when the learner runs on the CPU, but Ray crashes when it runs on a GPU. The script below reproduces the crash every time.

Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
Traceback (most recent call last):
  File "/home/xxxx/anaconda3/envs/nn/lib/python3.12/site-packages/ray/_private/serialization.py", line 460, in deserialize_objects
    obj = self._deserialize_object(data, metadata, object_ref)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xxxx/anaconda3/envs/nn/lib/python3.12/site-packages/ray/_private/serialization.py", line 317, in _deserialize_object
    return self._deserialize_msgpack_data(data, metadata_fields)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xxxx/anaconda3/envs/nn/lib/python3.12/site-packages/ray/_private/serialization.py", line 272, in _deserialize_msgpack_data
    python_objects = self._deserialize_pickle5_data(pickle5_data)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xxxx/anaconda3/envs/nn/lib/python3.12/site-packages/ray/_private/serialization.py", line 260, in _deserialize_pickle5_data
    obj = pickle.loads(in_band, buffers=buffers)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xxxx/anaconda3/envs/nn/lib/python3.12/site-packages/torch/storage.py", line 530, in _load_from_bytes
    return torch.load(io.BytesIO(b), weights_only=False)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xxxx/anaconda3/envs/nn/lib/python3.12/site-packages/torch/serialization.py", line 1549, in load
    return _legacy_load(
           ^^^^^^^^^^^^^
  File "/home/xxxx/anaconda3/envs/nn/lib/python3.12/site-packages/torch/serialization.py", line 1807, in _legacy_load
    result = unpickler.load()
             ^^^^^^^^^^^^^^^^
  File "/home/xxxx/anaconda3/envs/nn/lib/python3.12/site-packages/torch/serialization.py", line 1742, in persistent_load
    obj = restore_location(obj, location)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xxxx/anaconda3/envs/nn/lib/python3.12/site-packages/torch/serialization.py", line 698, in default_restore_location
    result = fn(storage, location)
             ^^^^^^^^^^^^^^^^^^^^^
  File "/home/xxxx/anaconda3/envs/nn/lib/python3.12/site-packages/torch/serialization.py", line 636, in _deserialize
    device = _validate_device(location, backend_name)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xxxx/anaconda3/envs/nn/lib/python3.12/site-packages/torch/serialization.py", line 605, in _validate_device
    raise RuntimeError(
RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
2025-05-07 13:23:10,305	ERROR actor_manager.py:873 -- Ray error (System error: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
), taking actor 0 out of service.
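For context, this RuntimeError is raised by torch's unpickler when it tries to restore a storage onto the device it was saved from (here a CUDA device) in a process where CUDA is not available. The message suggests passing map_location to torch.load. In this traceback, however, torch.load is called from torch/storage.py's _load_from_bytes during Ray's deserialization without map_location, so it is not something the script can set directly. The standalone sketch below only illustrates what map_location does; load_on_cpu is a hypothetical helper name. On a CPU-only machine the round trip is trivial, but the same argument remaps storages saved from CUDA onto the CPU.

```python
import io

import torch


def load_on_cpu(payload: bytes):
    """Deserialize a torch.save() payload, remapping every storage to the CPU."""
    return torch.load(io.BytesIO(payload), map_location=torch.device("cpu"))


# Serialize a small state dict the way torch.save would write a checkpoint.
buffer = io.BytesIO()
torch.save({"weight": torch.arange(4.0)}, buffer)

# map_location overrides the device recorded in the payload (e.g. "cuda:0").
state = load_on_cpu(buffer.getvalue())
print(state["weight"].device)  # cpu
```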

The full script is:

#!/usr/bin/env python3

import torch
import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import override

class CustomTorchModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        self.fc = nn.Sequential(
            nn.Linear(obs_space.shape[0], 64),
            nn.ReLU(),
            nn.Linear(64, num_outputs)
        )
        self._value_branch = nn.Sequential(
            nn.Linear(obs_space.shape[0], 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )
        self._value_out = None

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        x = self.fc(input_dict["obs"])
        self._value_out = self._value_branch(input_dict["obs"]).squeeze(1)
        return x, state

    @override(ModelV2)
    def value_function(self):
        return self._value_out

if __name__ == '__main__':
    from ray.rllib.algorithms.ppo import PPO, PPOConfig
    from ray.rllib.models import ModelCatalog
    from ray.tune import CheckpointConfig, Tuner, TuneConfig, RunConfig
    ModelCatalog.register_custom_model("my_torch_model", CustomTorchModel)
    config = (
        PPOConfig()
        .api_stack(enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True)
        .environment("CartPole-v1")
        .env_runners(
            num_env_runners=8,
            num_envs_per_env_runner=1,
        )
        #.learners(num_learners=1, num_gpus_per_learner=0) # This works
        .learners(num_learners=1, num_gpus_per_learner=1) # This does not work
        .rl_module(model_config={"custom_model": "my_torch_model"})
    )
    stopping_criteria = {"training_iteration": 100, "env_runners/episode_return_mean": 50}
    tuner = Tuner(
        PPO,
        tune_config=TuneConfig(num_samples=1),
        param_space=config,
        run_config=RunConfig(
            stop=stopping_criteria,
            checkpoint_config=CheckpointConfig(
                checkpoint_frequency=1,
                checkpoint_at_end=True,
            ),
        )
    )
    results = tuner.fit()

I have checked that torch.cuda.is_available() is True in the same environment, so I am not sure what is wrong here. Do you have an idea of what is going on?
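One plausible explanation for torch.cuda.is_available() being True in a shell but False in the failing process: Ray generally sets CUDA_VISIBLE_DEVICES for each worker according to the GPUs that worker requested, so a process that was assigned no GPU sees no CUDA device. This thread does not confirm which Ray process receives the learner's state here, so treat this as an assumption. The sketch below, built around a hypothetical helper cuda_available_with, reproduces the effect outside Ray by launching a fresh interpreter with a given CUDA_VISIBLE_DEVICES value.

```python
import os
import subprocess
import sys


def cuda_available_with(visible_devices: str) -> str:
    """Start a fresh interpreter with CUDA_VISIBLE_DEVICES set and report what torch sees."""
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=visible_devices)
    result = subprocess.run(
        [sys.executable, "-c", "import torch; print(torch.cuda.is_available())"],
        env=env,
        capture_output=True,
        text=True,
        check=True,
    )
    return result.stdout.strip()


if __name__ == "__main__":
    # An empty value hides every GPU from CUDA, as for a worker assigned no GPU.
    print("no devices visible:", cuda_available_with(""))
    # Exposing device 0 should report True on a machine with a working GPU 0.
    print("device 0 visible:", cuda_available_with("0"))
```

Printing torch.cuda.is_available() and os.environ.get("CUDA_VISIBLE_DEVICES") from inside the Ray processes involved would show what each of them actually sees.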

Note that I am using a nightly build of PyTorch (2.8.0.dev) to support the latest GPUs.

Depending on your config, Ray workers may not have a GPU allocated to them. Can you try adding .resources(num_gpus=0.1) to your PPOConfig? I think that will allocate GPUs to the workers, which may stop the CUDA error; you can then try enabling .learners(num_learners=1, num_gpus_per_learner=1) again.

Hi Christina,

Thanks for the suggestion, but adding .resources(num_gpus=0.1) did not change anything; I still have the same issue. I also tried num_gpus_per_env_runner=0.1 and lowered the other settings so everything fits on my CPU/GPU, with the same result. Any other ideas?
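A direction that might be worth trying, untested against RLlib and offered only as a sketch: make sure any tensors that leave the GPU learner are moved to the CPU before Ray serializes them, so the receiving process never has to restore CUDA storages. The helper to_cpu below is hypothetical, and where it would need to be called inside RLlib's learner or checkpointing code is not established in this thread.

```python
import torch


def to_cpu(obj):
    """Recursively return CPU versions of tensors in nested dicts, lists, and tuples.

    Tensors already on the CPU are not copied. Non-tensor leaves are returned
    unchanged. Dict subclasses such as OrderedDict come back as plain dicts,
    and namedtuples come back as plain tuples.
    """
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu()
    if isinstance(obj, dict):
        return {key: to_cpu(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [to_cpu(value) for value in obj]
    if isinstance(obj, tuple):
        return tuple(to_cpu(value) for value in obj)
    return obj


# Build a state-like structure on the GPU if one is visible, else on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
state = {
    "module": {"fc.weight": torch.ones(2, 2, device=device)},
    "step": 3,
    "buffers": [torch.zeros(1, device=device)],
}
cpu_state = to_cpu(state)
```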

I have the same problem with a very similar training configuration on Ray 2.46, Python 3.13, and CUDA 12.0.