When I run a callback that records the observation at each timestep, the recorder never stores the environment's starting state: the first observation it captures is the one produced after the agent has taken its first action. The Gym video recorder does capture this initial state, but the callback does not. Is there a reason for this, and is there a way to capture the initial observation? If needed, I can provide saved data from the rollouts in question.
What callback are you using? Best to post the code here so we can debug!
On a side note, I think you can modify the `on_episode_start` method to include the first timestep of the episode (i.e., the observation returned by `env.reset()`).
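Something along these lines, as an untested sketch: `last_obs` is a hypothetical attribute your env would need to expose, and `episode.user_data` is the episode's general-purpose dict for stashing values.

```python
from ray.rllib.agents.callbacks import DefaultCallbacks

class InitialObsCallbacks(DefaultCallbacks):
    def on_episode_start(self, *, worker, base_env, policies,
                         episode, env_index, **kwargs):
        # Grab the underlying (unwrapped) env for this episode.
        env = base_env.get_unwrapped()[env_index]
        # `last_obs` is hypothetical: your env has to expose its
        # post-reset observation somehow.
        episode.user_data["initial_obs"] = env.last_obs
```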
Here is the code I used; it does not alter the default callbacks much. I was able to get the first observation by calling a method I added to my env, `gen_obs()`, during the `on_episode_start` callback. However, if I call `episode.last_observation_for()` at episode start, it returns `None` instead of the first observation. It would make sense for the episode object to hold the starting observation so it can be accessed without writing custom environment methods, but I could not find a way to get the first observation through the episode object as I expected. Note also that if you call `episode.last_observation_for()` during the `on_episode_step` callback, the returned observation is the one seen *after* the action taken on that timestep, not before.
```python
from typing import Dict

from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.env import BaseEnv
from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker
from ray.rllib.policy import Policy


# Callbacks subclass (the name is arbitrary) holding the recording logic.
class MyCallbacks(DefaultCallbacks):
    def __init__(self):
        super().__init__()
        # The episode counter must exist before on_episode_start
        # increments it for the first time.
        self.episode_num = 0

    def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv,
                         policies: Dict[str, Policy],
                         episode: MultiAgentEpisode, env_index: int,
                         **kwargs):
        # Make sure this episode has just been started (only initial obs
        # logged so far).
        assert episode.length == 0, \
            "ERROR: `on_episode_start()` callback should be called right " \
            "after env reset!"
        self.episode = {}
        self.step_num = 0
        self.reward_total = 0
        self.episode_num += 1
        # The episode object holds no observation yet, so ask the
        # (unwrapped) environment for its post-reset observation via the
        # custom `gen_obs()` method.
        env = base_env.get_unwrapped()[0]
        obs = env.gen_obs()
        self.first_obs = obs
        # Record the initial observation as step 0.
        self.episode[self.step_num] = {
            "observation": obs,
        }

    def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv,
                        episode: MultiAgentEpisode, env_index: int,
                        **kwargs):
        self.step_num += 1
        # `last_observation_for()` returns the observation seen *after*
        # the action taken on this step.
        self.episode[self.step_num] = {
            "observation": episode.last_observation_for(),
        }
```
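For completeness, here is a sketch of how a callbacks class like this gets wired into training, assuming the methods above live in a `DefaultCallbacks` subclass named `MyCallbacks` as shown; the env name and stopping criterion are placeholders.

```python
import ray
from ray import tune

ray.init()
tune.run(
    "PPO",
    config={
        "env": "CartPole-v0",
        # RLlib expects the callbacks *class* here, not an instance.
        "callbacks": MyCallbacks,
    },
    stop={"training_iteration": 1},
)
```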