To my understanding, this is because on the very first initialization of the env, the obs returned by reset() is delivered through the poll() method of _MultiAgentEnvState, and therefore passes through callbacks.on_episode_step:
```python
class _MultiAgentEnvState:
    def __init__(self, env: MultiAgentEnv):
        assert isinstance(env, MultiAgentEnv)
        self.env = env
        self.initialized = False

    def poll(self) -> Tuple[MultiAgentDict, MultiAgentDict, MultiAgentDict,
                            MultiAgentDict, MultiAgentDict]:
        if not self.initialized:
            self.reset()
            self.initialized = True
        obs, rew, dones, info = (self.last_obs, self.last_rewards,
                                 self.last_dones, self.last_infos)
        self.last_obs = {}
        self.last_rewards = {}
        self.last_dones = {"__all__": False}
        self.last_infos = {}
        # NOTE: on the first poll(), these are the reset() observations,
        # returned like a regular step result, so downstream they pass
        # through callbacks.on_episode_step.
        return obs, rew, dones, info
    [...]
```
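To make this visible, here is a minimal debugging sketch (my own code, not RLlib's; the class name is hypothetical): a DefaultCallbacks subclass that logs every on_episode_step call. If the reasoning above is right, the reset obs of the very first episode should show up in this log, while the reset obs of later episodes should not:

```python
from ray.rllib.agents.callbacks import DefaultCallbacks


class ResetObsDebugCallbacks(DefaultCallbacks):
    """Hypothetical debug callback: logs each on_episode_step call."""

    def on_episode_step(self, *, worker, base_env, episode, **kwargs):
        # For the first episode, the reset obs arrives via poll() and is
        # logged here; for later episodes the reset happens inside
        # _process_observations and never reaches this hook.
        print("on_episode_step: episode_id={}, length={}".format(
            episode.episode_id, episode.length))
```

This can be plugged in via `config["callbacks"] = ResetObsDebugCallbacks` in the trainer config.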
After the environment has been initialized, however (i.e., after the first step of the first episode), the reset is done at the end of _process_observations and does not go through callbacks.on_episode_step:
```python
def _process_observations(
    [...]
        if hit_horizon and soft_horizon:
            episode.soft_reset()
            resetted_obs: Dict[AgentID, EnvObsType] = agent_obs
        else:
            del active_episodes[env_id]
            # NOTE: from the second episode on, the env is reset here;
            # nothing in this branch invokes callbacks.on_episode_step.
            resetted_obs: Dict[AgentID, EnvObsType] = base_env.try_reset(
                env_id)
        if resetted_obs is None:
            # Reset not supported, drop this env from the ready list.
            if horizon != float("inf"):
                raise ValueError(
                    "Setting episode horizon requires reset() support "
                    "from the environment.")
        elif resetted_obs != ASYNC_RESET_RETURN:
            # Creates a new episode if this is not async return.
            # If reset is async, we will get its result in some future poll
            episode: MultiAgentEpisode = active_episodes[env_id]
            if observation_fn:
                resetted_obs: Dict[AgentID, EnvObsType] = observation_fn(
                    agent_obs=resetted_obs,
                    worker=worker,
                    base_env=base_env,
                    policies=policies,
                    episode=episode)
            # type: AgentID, EnvObsType
            for agent_id, raw_obs in resetted_obs.items():
                policy_id: PolicyID = episode.policy_for(agent_id)
                policy: Policy = _get_or_raise(policies, policy_id)
                prep_obs: EnvObsType = _get_or_raise(
                    preprocessors, policy_id).transform(raw_obs)
                filtered_obs: EnvObsType = _get_or_raise(
                    obs_filters, policy_id)(prep_obs)
                episode._set_last_observation(agent_id, filtered_obs)

                item = PolicyEvalData(
                    env_id, agent_id, filtered_obs,
                    episode.last_info_for(agent_id) or {},
                    episode.rnn_state_for(agent_id),
                    np.zeros_like(
                        flatten_to_single_ndarray(
                            policy.action_space.sample())), 0.0)
                to_eval[policy_id].append(item)
    return active_envs, to_eval, outputs
```
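A possible workaround that sidesteps these sampler internals entirely is to intercept the reset observations at the env level, where every reset passes through regardless of which code path triggered it. Here is a minimal sketch, assuming wrapping the env is acceptable (ResetObsWrapper and MyEnv are hypothetical names):

```python
from ray.rllib.env.multi_agent_env import MultiAgentEnv


class ResetObsWrapper(MultiAgentEnv):
    """Hypothetical wrapper that sees every reset() observation."""

    def __init__(self, env: MultiAgentEnv):
        self.env = env

    def reset(self):
        obs = self.env.reset()
        # Runs for every episode, including the resets that bypass
        # callbacks.on_episode_step in _process_observations.
        print("reset obs:", obs)
        return obs

    def step(self, action_dict):
        return self.env.step(action_dict)
```

The wrapped env can then be registered as usual, e.g. `tune.register_env("my_env", lambda cfg: ResetObsWrapper(MyEnv(cfg)))`.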