Trainer is calling reset() even though the trial should have stopped

How severely does this issue affect your experience of using Ray?

  • Medium: It contributes to significant difficulty in completing my task, but I can work around it.

I need strict control over the calls to reset(), because my reset() function is stateful.

As an example, imagine each reset() pops a number from the shared set {0, 1, 2, 3}. I want to run exactly 4 episodes in parallel, so I cannot allow more than 4 calls to reset() on my env.
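To make that concrete, here is a minimal, purely illustrative sketch of such an env (not my actual env; the class name, spaces, and gym API usage are placeholders): a fifth reset() would fail because the shared pool is empty.

import gym
from gym.spaces import Discrete

class PoolEnv(gym.Env):
    """Illustrative only: every reset() pops one slot from a shared pool."""

    slots = {0, 1, 2, 3}  # shared by the 4 parallel episodes

    def __init__(self, config=None):
        self.observation_space = Discrete(4)
        self.action_space = Discrete(2)
        self.slot = None

    def reset(self):
        # A 5th call raises KeyError because the pool is exhausted --
        # which is why an unexpected extra reset() breaks my setup.
        self.slot = PoolEnv.slots.pop()
        return self.slot

    def step(self, action):
        return self.slot, 0.0, False, {}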

So I use disable_env_checking: True, and I add the stopping condition training_iteration: 1 together with timesteps_per_iteration: 10_000, where 10_000 is the length of one episode.
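Roughly, the relevant part of my config and stop condition looks like this (a sketch only; the env name is a placeholder and the trainable is the EvaluatorTrainable shown further down):

config = {
    "env": "my_env",                    # placeholder name
    "env_config": {},                   # my env-specific config goes here
    "disable_env_checking": True,
    "timesteps_per_iteration": 10_000,  # one iteration == one full episode
}
stop = {"training_iteration": 1}
# launched via tune, e.g. tune.run(EvaluatorTrainable, config=config, stop=stop)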

It works fine until it reaches the stopping condition. Then RLlib does a final reset() before actually stopping, and I get race condition bugs because of it :frowning:

How do I prevent RLlib from doing this final reset() before stopping the trial?

Below you can see the stack trace (not a crash, just the trace) of the problematic reset:

(EvaluatorTrainable pid=101029) [CRITICAL][misc] File "/home/ncarrara/anaconda3/envs/myrill/lib/python3.8/site-packages/ray/workers/default_worker.py", line 238, in <module>
(EvaluatorTrainable pid=101029)     ray.worker.global_worker.main_loop()
(EvaluatorTrainable pid=101029) File "/home/ncarrara/anaconda3/envs/myrill/lib/python3.8/site-packages/ray/worker.py", line 451, in main_loop
(EvaluatorTrainable pid=101029)     self.core_worker.run_task_loop()
(EvaluatorTrainable pid=101029) File "/home/ncarrara/anaconda3/envs/myrill/lib/python3.8/site-packages/ray/_private/function_manager.py", line 675, in actor_method_executor
(EvaluatorTrainable pid=101029)     return method(__ray_actor, *args, **kwargs)
(EvaluatorTrainable pid=101029) File "/home/ncarrara/anaconda3/envs/myrill/lib/python3.8/site-packages/ray/util/tracing/tracing_helper.py", line 462, in _resume_span
(EvaluatorTrainable pid=101029)     return method(self, *_args, **_kwargs)
(EvaluatorTrainable pid=101029) File "/home/ncarrara/anaconda3/envs/myrill/lib/python3.8/site-packages/ray/tune/trainable.py", line 360, in train
(EvaluatorTrainable pid=101029)     result = self.step()
(EvaluatorTrainable pid=101029) File "/home/ncarrara/anaconda3/envs/myrill/lib/python3.8/site-packages/ray/util/tracing/tracing_helper.py", line 462, in _resume_span
(EvaluatorTrainable pid=101029)     return method(self, *_args, **_kwargs)
(EvaluatorTrainable pid=101029) File "/home/ncarrara/work/myrill/core/myrillcore/ray/evaluator_trainable.py", line 63, in step
(EvaluatorTrainable pid=101029)     self.local_worker.sample()
(EvaluatorTrainable pid=101029) File "/home/ncarrara/anaconda3/envs/myrill/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 845, in sample
(EvaluatorTrainable pid=101029)     batches = [self.input_reader.next()]
(EvaluatorTrainable pid=101029) File "/home/ncarrara/anaconda3/envs/myrill/lib/python3.8/site-packages/ray/rllib/evaluation/sampler.py", line 115, in next
(EvaluatorTrainable pid=101029)     batches = [self.get_data()]
(EvaluatorTrainable pid=101029) File "/home/ncarrara/anaconda3/envs/myrill/lib/python3.8/site-packages/ray/rllib/evaluation/sampler.py", line 288, in get_data
(EvaluatorTrainable pid=101029)     item = next(self._env_runner)
(EvaluatorTrainable pid=101029) File "/home/ncarrara/anaconda3/envs/myrill/lib/python3.8/site-packages/ray/rllib/evaluation/sampler.py", line 671, in _env_runner
(EvaluatorTrainable pid=101029)     active_envs, to_eval, outputs = _process_observations(
(EvaluatorTrainable pid=101029) File "/home/ncarrara/anaconda3/envs/myrill/lib/python3.8/site-packages/ray/rllib/evaluation/sampler.py", line 1044, in _process_observations
(EvaluatorTrainable pid=101029)     ] = base_env.try_reset(env_id)
(EvaluatorTrainable pid=101029) File "/home/ncarrara/anaconda3/envs/myrill/lib/python3.8/site-packages/ray/rllib/env/vector_env.py", line 325, in try_reset
(EvaluatorTrainable pid=101029)     else 0: {_DUMMY_AGENT_ID: self.vector_env.reset_at(env_id)}
(EvaluatorTrainable pid=101029) File "/home/ncarrara/anaconda3/envs/myrill/lib/python3.8/site-packages/ray/rllib/env/vector_env.py", line 233, in reset_at
(EvaluatorTrainable pid=101029)     return self.envs[index].reset()

And this is my trainer code:

import ray
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.tune import registry
from ray.tune.registry import ENV_CREATOR
from ray.tune.trainable import Trainable

# MyrillCallbacks and PolicyWrapper are project-specific classes
# (their imports are omitted here).


class EvaluatorTrainable(Trainable):

    def __init__(self, config=None, logger_creator=None):
        super().__init__(config, logger_creator)

    def setup(self, config):
        # Resolve the env creator that was registered with tune.
        env_creator = registry._global_registry.get(ENV_CREATOR, config["env"])
        env_config = config["env_config"]
        env_config["nodes"]["env"]["json"]["name"] = "EvaluatorTrainableEnv"
        env_config["nodes"]["env"]["json"]["unremovable_log_color"] = "grey"
        # callbacks = config["callbacks"]
        callbacks = lambda: MyrillCallbacks(
            legacy_callbacks_dict=None,
            report_on_disk_episode_reset_frequency=1,
            episodes_path="/eval_episodes")

        if "num_workers" not in config:
            config["num_workers"] = 0

        self.use_local_worker = int(config["num_workers"]) == 0

        params = dict(
            disable_env_checking=config["disable_env_checking"],
            policy_spec=PolicyWrapper,
            env_creator=env_creator,
            rollout_fragment_length=0,  # batch_mode will force complete episodes anyway
            batch_mode="complete_episodes",
            callbacks=callbacks,
            env_config=env_config,
            policy_config={
                "env_creator": env_creator,
                "env_config": env_config,
                "config": config
            })

        if self.use_local_worker:
            # Sample directly on the driver with a single RolloutWorker.
            self.local_worker = RolloutWorker(**params)
        else:
            raise Exception("it won't work sorry bro")
            self.remote_workers = [
                RolloutWorker.as_remote().remote(**params)
                for _ in range(int(config["num_workers"]))
            ]

    def step(self):
        if self.use_local_worker:
            # One step == one sample() call == one complete episode,
            # because of batch_mode="complete_episodes".
            self.local_worker.sample()
            return collect_metrics(local_worker=self.local_worker)
        else:
            ray.get([w.sample.remote() for w in self.remote_workers])
            return collect_metrics(remote_workers=self.remote_workers)

Hi @Nicolas_Carrara ,

Not a solution, but an idea for a hack: try setting config["no_done_at_end"] = True.
The done at the end of the episode is what triggers the reset, so if there is, for some reason, an excess reset but you only want one reset overall, this might do the trick.
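Just to show where that key would go (a sketch only, with your other config keys left as they are):

config = {
    # ... existing keys ...
    "no_done_at_end": True,  # the idea: without done=True at episode end,
                             # the sampler should not trigger the extra reset
}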


I want one reset per episode/trajectory, no more, no less = ) I've changed my way of experimenting, so it isn't an issue for me anymore, but I believe we still need more control over how RLlib resets our envs.