Crash when calling .train() after loading from checkpoint

I get a TypeError when I call .train() on a DDPGTrainer (torch framework) that was restored from a checkpoint.

How to reproduce the error:

import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"  # Workaround for "OMP: Error #15" (duplicate OpenMP runtime)

from ray.rllib.agents.ddpg import DDPGTrainer

basic_config = {
    "env": "MountainCarContinuous-v0",
    "framework": "torch",
}

# Train the policy for a few iterations, saving a checkpoint after each one
trainer = DDPGTrainer(config=basic_config)

iters = 3
for _ in range(iters):
    trainer.train()
    checkpoint_path = trainer.save("test_model")

# Restore from the last checkpoint
# (here: test_model/checkpoint_000003/checkpoint-3)
trainer.restore(checkpoint_path)

# Resuming training now crashes with the error below
trainer.train()

Error message:

Traceback (most recent call last):
  File "c:\Users\ereikvi\Documents\augment_expert_policy_rl\ddpg_test.py", line 77, in <module>
    trainer.train()
  File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\tune\trainable.py", line 315, in train
    result = self.step()
  File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\rllib\agents\trainer.py", line 982, in step
    raise e
  File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\rllib\agents\trainer.py", line 963, in step
    step_attempt_results = self.step_attempt()
  File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\rllib\agents\trainer.py", line 1042, in step_attempt
    step_results = self._exec_plan_or_training_iteration_fn()
  File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\rllib\agents\trainer.py", line 1962, in _exec_plan_or_training_iteration_fn
    results = next(self.train_exec_impl)
  File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\util\iter.py", line 756, in __next__
    return next(self.built_iterator)
  File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\util\iter.py", line 783, in apply_foreach
    for item in it:
  File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\util\iter.py", line 843, in apply_filter
    for item in it:
  File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\util\iter.py", line 843, in apply_filter
    for item in it:
  File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\util\iter.py", line 783, in apply_foreach
    for item in it:
  File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\util\iter.py", line 843, in apply_filter
    for item in it:
  File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\util\iter.py", line 1075, in build_union
    item = next(it)
  File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\util\iter.py", line 756, in __next__
    return next(self.built_iterator)
  File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\util\iter.py", line 783, in apply_foreach
    for item in it:
  File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\util\iter.py", line 783, in apply_foreach
    for item in it:
  File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\util\iter.py", line 783, in apply_foreach
    for item in it:
  File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\rllib\execution\rollout_ops.py", line 90, in sampler
    yield workers.local_worker().sample()
  File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 761, in sample
    batches = [self.input_reader.next()]
  File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\rllib\evaluation\sampler.py", line 104, in next
    batches = [self.get_data()]
  File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\rllib\evaluation\sampler.py", line 266, in get_data
    item = next(self._env_runner)
  File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\rllib\evaluation\sampler.py", line 657, in _env_runner
    eval_results = _do_policy_eval(
  File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\rllib\evaluation\sampler.py", line 1077, in _do_policy_eval
    policy.compute_actions_from_input_dict(
  File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\rllib\policy\torch_policy.py", line 294, in compute_actions_from_input_dict
    return self._compute_action_helper(input_dict, state_batches,
  File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\rllib\utils\threading.py", line 21, in wrapper
    return func(self, *a, **k)
  File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\rllib\policy\torch_policy.py", line 948, in _compute_action_helper
    self.exploration.get_exploration_action(
  File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\rllib\utils\exploration\gaussian_noise.py", line 98, in get_exploration_action
    return self._get_torch_exploration_action(action_distribution,
  File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\rllib\utils\exploration\ornstein_uhlenbeck_noise.py", line 176, in _get_torch_exploration_action
    mean=torch.zeros(self.ou_state.size()), std=1.0) \
TypeError: 'int' object is not callable

Does anyone know how to solve this issue? Thanks in advance!

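Edit: I noticed that after restoring, self.ou_state is a NumPy array (np.ndarray) instead of a torch tensor. On a NumPy array, .size is an int attribute (the number of elements), not a method, so self.ou_state.size() tries to call an int, which raises "TypeError: 'int' object is not callable". It looks like self.ou_state is restored with the wrong type when loading from the checkpoint and resuming training.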
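A minimal NumPy-only sketch of the mismatch described in the edit: NumPy exposes the element count as the int attribute .size, whereas a torch tensor exposes .size() as a method returning its shape. The array below is a stand-in for the restored state; its shape (1,) is an assumption chosen to match the one-dimensional action space of MountainCarContinuous-v0.

```python
import numpy as np

# Stand-in for the restored exploration state: a NumPy array instead of a
# torch.Tensor. Shape (1,) is an assumption (1-D action space).
ou_state = np.zeros(1, dtype=np.float32)

# On a NumPy array, .size is an int attribute (number of elements), not a method.
print(type(ou_state.size).__name__)  # int

error_message = None
try:
    # Mirrors the self.ou_state.size() call in the traceback.
    ou_state.size()
except TypeError as err:
    error_message = str(err)
print(error_message)  # 'int' object is not callable

# On NumPy the shape tuple is .shape; a torch.Tensor offers both .size() and .shape.
print(ou_state.shape)  # (1,)
```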


Hey @viktor_em, thanks for raising this issue. There is indeed a bug in RLlib’s Ornstein-Uhlenbeck exploration component: its internal state is not restored correctly from a previously saved checkpoint.
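To illustrate what restoring state with the wrong type means, here is a hypothetical, self-contained sketch. It is not the code from the PR and not RLlib's API: the class HypotheticalOUState, its methods, and the assumption that the checkpoint stores the state as a NumPy array are all illustrative. Only the torch.normal(mean=torch.zeros(...size()), std=1.0) call pattern is taken from the traceback.

```python
import numpy as np
import torch


class HypotheticalOUState:
    """Hypothetical stand-in for an OU exploration object's state handling."""

    def __init__(self, action_dim: int):
        # The torch code path expects the live state to be a torch.Tensor.
        self.ou_state = torch.zeros(action_dim)

    def get_state(self) -> dict:
        # Assumption: the checkpoint stores the state as a plain NumPy array.
        return {"ou_state": self.ou_state.detach().cpu().numpy()}

    def set_state_buggy(self, state: dict) -> None:
        # Assigning the array directly silently changes the attribute's type.
        self.ou_state = state["ou_state"]

    def set_state_fixed(self, state: dict) -> None:
        # Convert back to a tensor so tensor methods such as .size() keep working.
        self.ou_state = torch.as_tensor(np.asarray(state["ou_state"]), dtype=torch.float32)

    def sample_noise(self) -> torch.Tensor:
        # Same call pattern as the failing line in the traceback.
        return torch.normal(mean=torch.zeros(self.ou_state.size()), std=1.0)


saved = HypotheticalOUState(action_dim=1).get_state()

buggy = HypotheticalOUState(action_dim=1)
buggy.set_state_buggy(saved)
buggy_error = None
try:
    buggy.sample_noise()
except TypeError as err:
    buggy_error = str(err)
print("buggy restore:", buggy_error)  # buggy restore: 'int' object is not callable

fixed = HypotheticalOUState(action_dim=1)
fixed.set_state_fixed(saved)
noise = fixed.sample_noise()
print("fixed restore:", noise.shape)  # fixed restore: torch.Size([1])
```

The design choice shown here is simply to convert the saved array back into the type the rest of the code expects at restore time, rather than changing every call site that uses the state.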

Here is a PR with the fix: https://github.com/ray-project/ray/pull/22245


That seems to solve it, thank you!