Callbacks.on_episode_step called an extra time during the first episode played (after the first call to env.reset)

I am using RLlib v1.0.0.

During the first episode, the callback on_episode_step is called once right after the environment reset (using fake actions) and then once per step in the environment.

But during the following episodes, on_episode_step is only called once per step in the environment (not after the reset).

Is that the expected behavior?

One problem with this is that, after the first episode, we can no longer observe the first observation of an episode (the one returned by reset) in the callback.
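To check whether this happens in a given setup, a counting callback can help. The sketch below is a plain class that only mirrors the on_episode_step hook name and keyword-style arguments, so it runs without Ray. In RLlib 1.0 it would presumably subclass DefaultCallbacks, which is an assumption here and not shown. The `replay` helper and the fake episode objects are hypothetical and only reproduce the call pattern described above.

```python
from collections import defaultdict
from types import SimpleNamespace


class StepCountingCallbacks:
    """Counts on_episode_step calls per episode (hypothetical helper)."""

    def __init__(self):
        self.step_calls = defaultdict(int)

    def on_episode_step(self, *, episode, **kwargs):
        # Assumption: `episode.episode_id` uniquely identifies the episode.
        self.step_calls[episode.episode_id] += 1


def replay(callbacks, episode_id, n_env_steps, extra_call_after_reset):
    """Replays the call pattern described in the post for one episode."""
    episode = SimpleNamespace(episode_id=episode_id)
    if extra_call_after_reset:
        # The additional call that is only seen in the first episode.
        callbacks.on_episode_step(episode=episode)
    for _ in range(n_env_steps):
        callbacks.on_episode_step(episode=episode)


callbacks = StepCountingCallbacks()
replay(callbacks, episode_id=0, n_env_steps=3, extra_call_after_reset=True)
replay(callbacks, episode_id=1, n_env_steps=3, extra_call_after_reset=False)
print(dict(callbacks.step_calls))  # {0: 4, 1: 3}
```

With episodes of equal length, a first episode that shows one more call than the others would match the behavior I describe.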


To my understanding, this is because on the first initialization of the env, the obs returned by reset comes back through the poll method of _MultiAgentEnvState (and therefore goes through callbacks.on_episode_step):

    class _MultiAgentEnvState:
        def __init__(self, env: MultiAgentEnv):
            assert isinstance(env, MultiAgentEnv)
            self.env = env
            self.initialized = False

        def poll(self) -> Tuple[MultiAgentDict, MultiAgentDict, MultiAgentDict,
                                MultiAgentDict, MultiAgentDict]:
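            if not self.initialized:
                # The first poll triggers the initial reset, so the reset obs
                # is returned from poll() like a regular step result.
                self.reset()
                self.initialized = True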
            obs, rew, dones, info = (self.last_obs, self.last_rewards,
                                     self.last_dones, self.last_infos)
            self.last_obs = {}
            self.last_rewards = {}
            self.last_dones = {"__all__": False}
            self.last_infos = {}
            return obs, rew, dones, info 


In contrast, once the environment has been initialized (from the first step of the first episode onward), the reset is done at the end of _process_observations and doesn’t go through callbacks.on_episode_step:

    def _process_observations(...):
        # ...
                if hit_horizon and soft_horizon:
                    # ...
                    resetted_obs: Dict[AgentID, EnvObsType] = agent_obs
                # ...
                    del active_episodes[env_id]
                    resetted_obs: Dict[AgentID, EnvObsType] = base_env.try_reset(...)
                if resetted_obs is None:
                    # Reset not supported, drop this env from the ready list.
                    if horizon != float("inf"):
                        raise ValueError(
                            "Setting episode horizon requires reset() support "
                            "from the environment.")
                elif resetted_obs != ASYNC_RESET_RETURN:
                    # Creates a new episode if this is not async return.
                    # If reset is async, we will get its result in some future poll
                    episode: MultiAgentEpisode = active_episodes[env_id]
                    if observation_fn:
                        resetted_obs: Dict[AgentID, EnvObsType] = observation_fn(...)
                    # type: AgentID, EnvObsType
                    for agent_id, raw_obs in resetted_obs.items():
                        policy_id: PolicyID = episode.policy_for(agent_id)
                        policy: Policy = _get_or_raise(policies, policy_id)
                        prep_obs: EnvObsType = _get_or_raise(
                            preprocessors, policy_id).transform(raw_obs)
                        filtered_obs: EnvObsType = _get_or_raise(
                            obs_filters, policy_id)(prep_obs)
                        episode._set_last_observation(agent_id, filtered_obs)

                        item = PolicyEvalData(
                            env_id, agent_id, filtered_obs,
                            episode.last_info_for(agent_id) or {},
                            # ...
                                    policy.action_space.sample())), 0.0)
                        # ...

        return active_envs, to_eval, outputs
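The two reset paths can be illustrated with a toy model. This is not RLlib code: `LazyEnvState` and `sample` are hypothetical stand-ins that only mimic the control flow I describe, with a lazy reset inside the first poll and later resets done inside the sampling loop, after the point where the step callback would fire.

```python
class LazyEnvState:
    """Toy env wrapper that performs its first reset lazily inside poll()."""

    def __init__(self, episode_len):
        self.episode_len = episode_len
        self.initialized = False
        self.t = 0
        self.last_obs = None

    def reset(self):
        self.t = 0
        self.last_obs = "reset"
        return self.last_obs

    def step(self):
        self.t += 1
        self.last_obs = f"step{self.t}"

    def poll(self):
        # The very first poll performs the reset, so the reset observation
        # takes the same path as ordinary step observations.
        if not self.initialized:
            self.reset()
            self.initialized = True
        done = self.t >= self.episode_len
        return self.last_obs, done


def sample(num_episodes, episode_len):
    """Returns, per episode, the observations seen by the step callback."""
    env = LazyEnvState(episode_len)
    seen_by_step_callback = [[]]
    while True:
        obs, done = env.poll()
        seen_by_step_callback[-1].append(obs)  # stands in for on_episode_step
        if done:
            if len(seen_by_step_callback) == num_episodes:
                return seen_by_step_callback
            # Later resets happen here, after the callback point, so their
            # observation never reaches the step callback.
            env.reset()
            seen_by_step_callback.append([])
        env.step()


for i, obs_list in enumerate(sample(num_episodes=2, episode_len=2)):
    print(i, obs_list)
# 0 ['reset', 'step1', 'step2']
# 1 ['step1', 'step2']
```

Only the first episode's list contains the reset observation, which matches the asymmetry in the callback calls.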

@sven1977 could you take a look?

Hey @Maxime_Riche, great catch. We should definitely fix this for consistency. I’ll create a PR now …

I guess the expected behavior would be to not call on_episode_step right after a reset (since no step has been taken and the on_episode_start callback is called anyway).
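On the user side, the same idea could be sketched as a guard inside the callback. This assumes the episode object exposes a `length` counter that is still 0 when the callback fires right after a reset. That assumption is not verified against RLlib 1.0.0 here, and the class and fake episode below are hypothetical, written so that they run without Ray.

```python
from types import SimpleNamespace


class GuardedCallbacks:
    """Ignores step callbacks that arrive before any env step was taken."""

    def __init__(self):
        self.handled = []

    def on_episode_step(self, *, episode, **kwargs):
        # Assumption: `episode.length` counts env steps taken so far and is
        # 0 for the call that directly follows a reset.
        if episode.length == 0:
            return
        self.handled.append((episode.episode_id, episode.length))


cb = GuardedCallbacks()
ep = SimpleNamespace(episode_id=0, length=0)
cb.on_episode_step(episode=ep)  # post-reset call, skipped by the guard
for _ in range(3):
    ep.length += 1
    cb.on_episode_step(episode=ep)
print(cb.handled)  # [(0, 1), (0, 2), (0, 3)]
```

With such a guard, custom step logic would see the same number of calls in every episode, whether or not the extra post-reset call happens.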

PR that fixes the above issue:

