Find index of env in DefaultCallbacks

I have been using DefaultCallbacks as shown here and managed to log my data in the on_episode_end method.

def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv,
                       policies: Dict[str, Policy], episode: MultiAgentEpisode,
                       env_index: int, **kwargs):
        env = base_env.get_unwrapped()[env_index].unwrapped
        episode.custom_metrics[f"env{env_index}/some_value"] = env.some_value

I’m running PPO with num_workers=4 and wanted to log some values from the environments individually (I have some random driving process that is seeded and thus will make the environments evolve differently on each worker).

What I did not realise until recently is that the env_index variable is always zero. I had assumed it would contain an index that I should use to index into base_env to get the environment that had triggered the end-of-episode call.

This might explain a lot of confusing data I have looked at. Unfortunately, I have always just defaulted to looking at environment 0 in TensorBoard, and I didn’t realise until now that this was the data from all environments jumbled together.

I see that env_index seems to be obsoleted. Is this what it was once used for, and did it maybe just break when I updated the Ray version a while ago? Can I still find out which of my envs is currently logging a value, or should I just keep a unique ID in each of them?

Hi @albheim,

The env_index is used for vector envs to distinguish between multiple environments on the same worker. How many environments each worker holds is determined by the value of num_envs_per_worker. To tell the workers themselves apart, the worker has a worker_index member that you could use.
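To make the distinction concrete, here is a small sketch in plain Python (it does not call RLlib) of how the two indices could combine into one label per environment. The helper names env_label and all_labels are hypothetical, and the numbering of remote workers from 1 (with 0 reserved for the local worker) is an assumption about how RLlib assigns worker_index that is worth checking against your Ray version.

```python
from typing import List


def env_label(worker_index: int, env_index: int) -> str:
    """Build a label that is unique per (worker, sub-environment) pair."""
    return f"worker{worker_index}/env{env_index}"


def all_labels(num_workers: int, num_envs_per_worker: int) -> List[str]:
    """List the labels expected for a given configuration.

    Assumption: remote rollout workers are numbered 1..num_workers,
    and sub-environments on each worker are numbered from 0.
    """
    return [
        env_label(w, e)
        for w in range(1, num_workers + 1)
        for e in range(num_envs_per_worker)
    ]


if __name__ == "__main__":
    # With num_workers=4 and one env per worker, env_index is 0 everywhere,
    # so only worker_index distinguishes the environments.
    for label in all_labels(num_workers=4, num_envs_per_worker=1):
        print(label)
```

Under these assumptions, a setup with one environment per worker only ever sees env_index 0, which would match what was observed in the question.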


Just to clarify for anyone else showing up here, the corresponding code in the DefaultCallbacks would look something like this:

def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv,
                       policies: Dict[str, Policy], episode: MultiAgentEpisode,
                       env_index: int, **kwargs):
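        # Assuming one env per worker (num_envs_per_worker=1), take sub-env 0.
        env = base_env.get_unwrapped()[0].unwrapped
        # worker_index identifies the rollout worker that ran this episode.
        worker_index = worker.worker_index
        episode.custom_metrics[f"env{worker_index}/some_value"] = env.some_value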
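If you later raise num_envs_per_worker above 1, the key could include both indices so that sub-environments on the same worker stay separate. The following is only a sketch: PerEnvMetrics is a hypothetical duck-typed stand-in rather than a real DefaultCallbacks subclass, and the SimpleNamespace objects merely imitate the attributes used in this thread (worker_index, get_unwrapped, custom_metrics) so the example runs without Ray installed.

```python
from types import SimpleNamespace


class PerEnvMetrics:
    """Stand-in for a DefaultCallbacks subclass; in real code, inherit from it."""

    def on_episode_end(self, *, worker, base_env, policies, episode,
                       env_index=None, **kwargs):
        # Fall back to sub-env 0 if no index is passed in.
        idx = env_index or 0
        env = base_env.get_unwrapped()[idx].unwrapped
        # Combine both indices so each (worker, sub-env) pair gets its own key.
        key = f"worker{worker.worker_index}/env{idx}/some_value"
        episode.custom_metrics[key] = env.some_value


if __name__ == "__main__":
    # Minimal fakes that imitate only the attributes the callback touches.
    inner_env = SimpleNamespace(some_value=3.5)
    wrapped_env = SimpleNamespace(unwrapped=inner_env)
    base_env = SimpleNamespace(get_unwrapped=lambda: [wrapped_env])
    worker = SimpleNamespace(worker_index=2)
    episode = SimpleNamespace(custom_metrics={})

    PerEnvMetrics().on_episode_end(
        worker=worker, base_env=base_env, policies={},
        episode=episode, env_index=0,
    )
    print(episode.custom_metrics)
```

With a single environment per worker this reduces to the version above, just with a slightly longer metric name.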