How severe does this issue affect your experience of using Ray?
- Medium: It contributes to significant difficulty to complete my task, but I can work around it.
I need strict control over how many times reset() is called, because my reset() function is stateful.
As an example, imagine that each reset() pops a number from the shared set {0, 1, 2, 3}. I want to run exactly 4 episodes in parallel, so I cannot allow more than 4 calls to reset() on my env.
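To make the constraint concrete, here is a minimal sketch of such an env (not my real code; the names and spaces are made up, only the reset() accounting matters):

import gym

# Hypothetical shared pool of episode ids; every reset() consumes one,
# so a 5th reset() raises a KeyError.
EPISODE_POOL = {0, 1, 2, 3}

class PoolEnv(gym.Env):
    def __init__(self, config=None):
        self.observation_space = gym.spaces.Discrete(4)
        self.action_space = gym.spaces.Discrete(2)
        self.episode_id = None

    def reset(self):
        self.episode_id = EPISODE_POOL.pop()
        return self.episode_id

    def step(self, action):
        # Dummy transition; the point is only how often reset() may be called.
        return self.episode_id, 0.0, False, {}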
So I use disable_env_checking: True, and I add the stopping condition training_iteration: 1 with timesteps_per_iteration: 10_000, where 10_000 is the length of one episode.
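Roughly, the relevant part of my setup looks like this (simplified sketch; "my_env" is a placeholder for my registered env):

config = {
    "env": "my_env",                     # registered through tune.registry
    "disable_env_checking": True,        # no extra reset() from the env checker
    "timesteps_per_iteration": 10_000,   # one full episode per iteration
    "num_workers": 0,
}
stop = {"training_iteration": 1}         # stop the trial after a single iteration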
It works fine until the stopping condition is reached. Then RLlib does a final reset() before actually stopping, and I get race condition bugs because of it.
How do I prevent RLlib from doing this final reset() before stopping the trial?
Below you can see the stack trace of the problematic reset (not an error, just a trace I log):
(EvaluatorTrainable pid=101029) [CRITICAL][misc] File "/home/ncarrara/anaconda3/envs/myrill/lib/python3.8/site-packages/ray/workers/default_worker.py", line 238, in <module>
(EvaluatorTrainable pid=101029) ray.worker.global_worker.main_loop()
(EvaluatorTrainable pid=101029) File "/home/ncarrara/anaconda3/envs/myrill/lib/python3.8/site-packages/ray/worker.py", line 451, in main_loop
(EvaluatorTrainable pid=101029) self.core_worker.run_task_loop()
(EvaluatorTrainable pid=101029) File "/home/ncarrara/anaconda3/envs/myrill/lib/python3.8/site-packages/ray/_private/function_manager.py", line 675, in actor_method_executor
(EvaluatorTrainable pid=101029) return method(__ray_actor, *args, **kwargs)
(EvaluatorTrainable pid=101029) File "/home/ncarrara/anaconda3/envs/myrill/lib/python3.8/site-packages/ray/util/tracing/tracing_helper.py", line 462, in _resume_span
(EvaluatorTrainable pid=101029) return method(self, *_args, **_kwargs)
(EvaluatorTrainable pid=101029) File "/home/ncarrara/anaconda3/envs/myrill/lib/python3.8/site-packages/ray/tune/trainable.py", line 360, in train
(EvaluatorTrainable pid=101029) result = self.step()
(EvaluatorTrainable pid=101029) File "/home/ncarrara/anaconda3/envs/myrill/lib/python3.8/site-packages/ray/util/tracing/tracing_helper.py", line 462, in _resume_span
(EvaluatorTrainable pid=101029) return method(self, *_args, **_kwargs)
(EvaluatorTrainable pid=101029) File "/home/ncarrara/work/myrill/core/myrillcore/ray/evaluator_trainable.py", line 63, in step
(EvaluatorTrainable pid=101029) self.local_worker.sample()
(EvaluatorTrainable pid=101029) File "/home/ncarrara/anaconda3/envs/myrill/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 845, in sample
(EvaluatorTrainable pid=101029) batches = [self.input_reader.next()]
(EvaluatorTrainable pid=101029) File "/home/ncarrara/anaconda3/envs/myrill/lib/python3.8/site-packages/ray/rllib/evaluation/sampler.py", line 115, in next
(EvaluatorTrainable pid=101029) batches = [self.get_data()]
(EvaluatorTrainable pid=101029) File "/home/ncarrara/anaconda3/envs/myrill/lib/python3.8/site-packages/ray/rllib/evaluation/sampler.py", line 288, in get_data
(EvaluatorTrainable pid=101029) item = next(self._env_runner)
(EvaluatorTrainable pid=101029) File "/home/ncarrara/anaconda3/envs/myrill/lib/python3.8/site-packages/ray/rllib/evaluation/sampler.py", line 671, in _env_runner
(EvaluatorTrainable pid=101029) active_envs, to_eval, outputs = _process_observations(
(EvaluatorTrainable pid=101029) File "/home/ncarrara/anaconda3/envs/myrill/lib/python3.8/site-packages/ray/rllib/evaluation/sampler.py", line 1044, in _process_observations
(EvaluatorTrainable pid=101029) ] = base_env.try_reset(env_id)
(EvaluatorTrainable pid=101029) File "/home/ncarrara/anaconda3/envs/myrill/lib/python3.8/site-packages/ray/rllib/env/vector_env.py", line 325, in try_reset
(EvaluatorTrainable pid=101029) else 0: {_DUMMY_AGENT_ID: self.vector_env.reset_at(env_id)}
(EvaluatorTrainable pid=101029) File "/home/ncarrara/anaconda3/envs/myrill/lib/python3.8/site-packages/ray/rllib/env/vector_env.py", line 233, in reset_at
(EvaluatorTrainable pid=101029) return self.envs[index].reset()
And this is my trainer code:
import ray
from ray.tune import Trainable, registry
from ray.tune.registry import ENV_CREATOR
from ray.rllib.evaluation import RolloutWorker
from ray.rllib.evaluation.metrics import collect_metrics

# MyrillCallbacks and PolicyWrapper are project-specific classes (imports not shown).


class EvaluatorTrainable(Trainable):
    def __init__(self, config=None, logger_creator=None):
        super().__init__(config, logger_creator)

    def setup(self, config):
        # Look up the env creator that was registered with Tune.
        env_creator = registry._global_registry.get(ENV_CREATOR, config["env"])
        env_config = config["env_config"]
        env_config["nodes"]["env"]["json"]["name"] = "EvaluatorTrainableEnv"
        env_config["nodes"]["env"]["json"]["unremovable_log_color"] = "grey"
        # callbacks = config["callbacks"]
        callbacks = lambda: MyrillCallbacks(
            legacy_callbacks_dict=None,
            report_on_disk_episode_reset_frequency=1,
            episodes_path="/eval_episodes")
        if "num_workers" not in config:
            config["num_workers"] = 0
        self.use_local_worker = int(config["num_workers"]) == 0
        params = dict(
            disable_env_checking=config["disable_env_checking"],
            policy_spec=PolicyWrapper,
            env_creator=env_creator,
            rollout_fragment_length=0,  # batch_mode will force complete episodes anyway
            batch_mode="complete_episodes",
            callbacks=callbacks,
            env_config=env_config,
            policy_config={
                "env_creator": env_creator,
                "env_config": env_config,
                "config": config,
            })
        if self.use_local_worker:
            self.local_worker = RolloutWorker(**params)
        else:
            raise Exception("it won't work sorry bro")
            self.remote_workers = [
                RolloutWorker.as_remote().remote(**params)
                for _ in range(int(config["num_workers"]))
            ]

    def step(self):
        if self.use_local_worker:
            # Sample one batch of complete episodes, then report metrics.
            self.local_worker.sample()
            return collect_metrics(local_worker=self.local_worker)
        else:
            ray.get([w.sample.remote() for w in self.remote_workers])
            return collect_metrics(remote_workers=self.remote_workers)
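And this is roughly how I launch it (simplified; config and stop are the ones sketched above):

from ray import tune

tune.run(
    EvaluatorTrainable,
    config=config,
    stop=stop,  # {"training_iteration": 1}
)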