Hello,
Thank you for your great work!
I am encountering a CUDA serialization issue while using PBT.
I set up my tune.Tuner with multiple workers (e.g., 4) and 1 GPU.
When I call .fit(), checkpoints are generated for the different workers, but fairly early on, seemingly right after all trials have gone through their first time_attr interval, I get the following error:
(RolloutWorker pid=73953) Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
(RolloutWorker pid=73953) Traceback (most recent call last):
(RolloutWorker pid=73953)   File "/usr/local/lib/python3.10/site-packages/ray/_private/serialization.py", line 404, in deserialize_objects
(RolloutWorker pid=73953)     obj = self._deserialize_object(data, metadata, object_ref)
(RolloutWorker pid=73953)   File "/usr/local/lib/python3.10/site-packages/ray/_private/serialization.py", line 270, in _deserialize_object
(RolloutWorker pid=73953)     return self._deserialize_msgpack_data(data, metadata_fields)
(RolloutWorker pid=73953)   File "/usr/local/lib/python3.10/site-packages/ray/_private/serialization.py", line 225, in _deserialize_msgpack_data
(RolloutWorker pid=73953)     python_objects = self._deserialize_pickle5_data(pickle5_data)
(RolloutWorker pid=73953)   File "/usr/local/lib/python3.10/site-packages/ray/_private/serialization.py", line 215, in _deserialize_pickle5_data
(RolloutWorker pid=73953)     obj = pickle.loads(in_band)
(RolloutWorker pid=73953)   File "/usr/local/lib/python3.10/site-packages/torch/storage.py", line 337, in _load_from_bytes
(RolloutWorker pid=73953)     return torch.load(io.BytesIO(b))
(RolloutWorker pid=73953)   File "/usr/local/lib/python3.10/site-packages/torch/serialization.py", line 1028, in load
(RolloutWorker pid=73953)     return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
(RolloutWorker pid=73953)   File "/usr/local/lib/python3.10/site-packages/torch/serialization.py", line 1256, in _legacy_load
(RolloutWorker pid=73953)     result = unpickler.load()
(RolloutWorker pid=73953)   File "/usr/local/lib/python3.10/site-packages/torch/serialization.py", line 1193, in persistent_load
(RolloutWorker pid=73953)     wrap_storage=restore_location(obj, location),
(RolloutWorker pid=73953)   File "/usr/local/lib/python3.10/site-packages/torch/serialization.py", line 381, in default_restore_location
(RolloutWorker pid=73953)     result = fn(storage, location)
(RolloutWorker pid=73953)   File "/usr/local/lib/python3.10/site-packages/torch/serialization.py", line 274, in _cuda_deserialize
(RolloutWorker pid=73953)     device = validate_cuda_device(location)
(RolloutWorker pid=73953)   File "/usr/local/lib/python3.10/site-packages/torch/serialization.py", line 258, in validate_cuda_device
(RolloutWorker pid=73953)     raise RuntimeError('Attempting to deserialize object on a CUDA '
(RolloutWorker pid=73953) RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
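The error message suggests mapping storages to the CPU with map_location, but in the traceback the torch.load call happens inside torch.storage._load_from_bytes, deep in Ray's own deserialization, so I don't see where I could pass it. Just to illustrate what I mean, a minimal standalone sketch (not my actual code):

import io
import torch

# Simulate what a GPU trial produces: a tensor pickled with a CUDA storage tag.
# (This part only runs on a machine where CUDA is available.)
buf = io.BytesIO()
torch.save(torch.zeros(3, device="cuda"), buf)

# This is essentially the failing path from the traceback: torch.load with no
# map_location on a CPU-only process raises the RuntimeError above.
# buf.seek(0); torch.load(buf)

# And this is the fix the error message suggests, which I cannot inject into
# Ray's internal pickle path:
buf.seek(0)
cpu_tensor = torch.load(buf, map_location=torch.device("cpu"))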
This is the snippet for the tuner:
tuner = tune.Tuner(
    trainable=add_mixins(config_module.algorithm, [SerializableTorchTrainerMixin]),
    param_space=config_module.config,
    tune_config=tune.TuneConfig(
        metric="episode_reward_mean",
        scheduler=scheduler,
        num_samples=config_module.hyperparameters_samples,
    ),
    run_config=train.RunConfig(
        stop={"timesteps_total": config_module.total_timestep},
        progress_reporter=improved_reporter,
        name=experiment_name,
        callbacks=[TBXLoggerCallback(), CSVLoggerCallback(), JsonLoggerCallback()],
        checkpoint_config=CheckpointConfig(
            checkpoint_frequency=config_module.checkpoint_freq,
            checkpoint_at_end=True,
        ),
    ),
)
results = tuner.fit()
and the scheduler:
scheduler = sched.PopulationBasedTraining(
    time_attr=config_module.time_attr_pbt,  # training result attribute used to measure progress between perturbations
    mode='max',
    log_config=True,
    resample_probability=0.25,
    hyperparam_mutations=hyperparameters_mutations,
    synch=True,
    perturbation_interval=config_module.perturbation_interval,
)
It's quite hard to debug this properly, since I can't just drop breakpoints into the worker processes, and the error does not occur when num_gpus is set to 0. If you have any hints about what is going on and how I can solve it, I would really appreciate it!
Using Python 3.10, Ray 2.9.0, and PyTorch with CUDA 11.8.
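In case it's relevant, one workaround I have been considering (but have not verified) is registering a custom Ray serializer that forces tensors through a CPU/NumPy round trip, something along these lines:

import numpy as np
import torch
import ray.util


def tensor_to_numpy(t: torch.Tensor) -> np.ndarray:
    # Detach and move to CPU so no CUDA storage tag ends up in the pickled bytes.
    return t.detach().cpu().numpy()


def numpy_to_tensor(arr: np.ndarray) -> torch.Tensor:
    # Always restore on CPU; each trial can move the tensor back to its own device.
    return torch.from_numpy(arr)


# Hypothetical: intercept torch.Tensor in Ray's serialization. I am not sure this
# reaches the policy state that PBT ships between trials, or tensor subclasses
# such as torch.nn.Parameter.
ray.util.register_serializer(
    torch.Tensor,
    serializer=tensor_to_numpy,
    deserializer=numpy_to_tensor,
)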
Stefano