I’m getting an error when I try to continue training with a trainer that was restored from a checkpoint.
How to reproduce the error:
# Minimal reproduction: train a DDPG trainer, save a checkpoint, restore it
# into the same trainer, then call train() again.
import os

# Workaround for "OMP: Error #15"; kept ahead of the ray/gym imports, in the
# same position as in the original report.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

from ray.rllib.agents.ddpg import DDPGTrainer
import gym

basic_config = {
    "env": "MountainCarContinuous-v0",
    "framework": "torch",
}

# Train policy
trainer = DDPGTrainer(config=basic_config)
iters = 3
for i in range(iters):
    output = trainer.train()

# Save a checkpoint once training is done.
# NOTE(review): the pasted snippet lost its indentation, so it is unclear
# whether save() was originally inside the loop; either way the checkpoint
# restored below is the one written after the third iteration.
trainer.save("test_model")

# Restore the checkpoint that was just written.
trainer.restore("test_model/checkpoint_000003/checkpoint-3")

# This is the call that raises the TypeError shown in the traceback.
trainer.train()
Error message:
Traceback (most recent call last):
File "c:\Users\ereikvi\Documents\augment_expert_policy_rl\ddpg_test.py", line 77, in <module>
trainer.train()
File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\tune\trainable.py", line 315, in train
result = self.step()
File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\rllib\agents\trainer.py", line 982, in step
raise e
File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\rllib\agents\trainer.py", line 963, in step
step_attempt_results = self.step_attempt()
File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\rllib\agents\trainer.py", line 1042, in step_attempt
step_results = self._exec_plan_or_training_iteration_fn()
File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\rllib\agents\trainer.py", line 1962, in _exec_plan_or_training_iteration_fn
results = next(self.train_exec_impl)
File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\util\iter.py", line 756, in __next__
return next(self.built_iterator)
File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\util\iter.py", line 783, in apply_foreach
for item in it:
File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\util\iter.py", line 843, in apply_filter
for item in it:
File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\util\iter.py", line 843, in apply_filter
for item in it:
File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\util\iter.py", line 783, in apply_foreach
for item in it:
File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\util\iter.py", line 843, in apply_filter
for item in it:
File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\util\iter.py", line 1075, in build_union
item = next(it)
File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\util\iter.py", line 756, in __next__
return next(self.built_iterator)
File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\util\iter.py", line 783, in apply_foreach
for item in it:
File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\util\iter.py", line 783, in apply_foreach
for item in it:
File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\util\iter.py", line 783, in apply_foreach
for item in it:
File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\rllib\execution\rollout_ops.py", line 90, in sampler
yield workers.local_worker().sample()
File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 761, in sample
batches = [self.input_reader.next()]
File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\rllib\evaluation\sampler.py", line 104, in next
batches = [self.get_data()]
File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\rllib\evaluation\sampler.py", line 266, in get_data
item = next(self._env_runner)
File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\rllib\evaluation\sampler.py", line 657, in _env_runner
eval_results = _do_policy_eval(
File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\rllib\evaluation\sampler.py", line 1077, in _do_policy_eval
policy.compute_actions_from_input_dict(
File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\rllib\policy\torch_policy.py", line 294, in compute_actions_from_input_dict
return self._compute_action_helper(input_dict, state_batches,
File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\rllib\utils\threading.py", line 21, in wrapper
return func(self, *a, **k)
File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\rllib\policy\torch_policy.py", line 948, in _compute_action_helper
self.exploration.get_exploration_action(
File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\rllib\utils\exploration\gaussian_noise.py", line 98, in get_exploration_action
return self._get_torch_exploration_action(action_distribution,
File "C:\Users\ereikvi\Miniconda3\envs\test\lib\site-packages\ray\rllib\utils\exploration\ornstein_uhlenbeck_noise.py", line 176, in _get_torch_exploration_action
mean=torch.zeros(self.ou_state.size()), std=1.0) \
TypeError: 'int' object is not callable
Does anyone know how to solve this issue? Thanks in advance.
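Edit: I noticed that self.ou_state is a NumPy array (np.ndarray), whereas the failing line apparently expects a torch tensor. On a NumPy array, .size is an integer attribute rather than a method, so calling self.ou_state.size() is what raises "TypeError: 'int' object is not callable". It seems to me that self.ou_state is assigned the wrong type for some reason when loading from the checkpoint and resuming training.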