Rollout storage with callbacks does not capture starting state

When running a callback that returns the observation at each timestep, the callback recorder does not store the starting state of the environment: the first observation it records is the one seen after the agent has taken its first action. The Gym video recorder does capture this initial state, but the callback does not. Is there a reason for this, and is there a way to capture the initial observation? If needed, I can provide saved information about the rollouts in question.

Which callback are you using? It would be best to post your code here so we can debug it!

On a side note, I think you can override the on_episode_start method to also record the first timestep of the episode (i.e., the observation returned by env.reset()), as in the sketch below.
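
A minimal sketch of that idea, assuming an older RLlib API and a single-agent env. Here `get_current_obs` is a hypothetical accessor; use whatever method your env actually provides for reading its current observation:

    from ray.rllib.agents.callbacks import DefaultCallbacks

    class InitialObsCallbacks(DefaultCallbacks):
        def on_episode_start(self, *, worker, base_env, policies, episode,
                             env_index, **kwargs):
            # Reach through the vectorized BaseEnv to the raw env and read
            # the observation produced by env.reset().
            # NOTE: get_current_obs() is a placeholder for your env's own
            # observation accessor.
            env = base_env.get_unwrapped()[env_index]
            episode.user_data["observations"] = [env.get_current_obs()]

episode.user_data is a plain dict that RLlib provides on each episode object for exactly this kind of per-episode bookkeeping.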

Here is the code I used; it does not alter the default callback much. I was able to get the first observation by calling a method I added to my env, gen_obs, during on_episode_start. However, if I call episode.last_observation_for() at episode start, it returns None instead of the first observation. It would make sense for the episode object to expose the starting observation so it can be accessed without writing custom environment methods, but I could not find a way to get the first observation through the episode object as I expected. Note also that if you call episode.last_observation_for() during the on_episode_step callback, the returned observation is the one seen after the action taken at that timestep, not before it.

    from typing import Dict

    from ray.rllib.agents.callbacks import DefaultCallbacks
    from ray.rllib.env import BaseEnv
    from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker
    from ray.rllib.policy import Policy


    class RolloutStorageCallbacks(DefaultCallbacks):
        def __init__(self):
            super().__init__()
            self.episode_num = 0

        def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv,
                             policies: Dict[str, Policy],
                             episode: MultiAgentEpisode, env_index: int,
                             **kwargs):
            # Make sure this episode has just been started (only initial obs
            # logged so far).
            assert episode.length == 0, \
                "ERROR: `on_episode_start()` callback should be called right " \
                "after env reset!"

            # Reset per-episode storage.
            self.episode = {}
            self.step_num = 0
            self.first_obs = None
            self.reward_total = 0
            self.episode_num += 1

            # episode.last_observation_for() returns None at this point, so
            # read the starting observation directly from the underlying env
            # via the custom gen_obs() method.
            env = base_env.get_unwrapped()[0]
            obs = env.gen_obs()
            self.first_obs = obs

            # Store the starting observation as step 0.
            self.episode[self.step_num] = {
                "observation": obs,
            }

        def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv,
                            episode: MultiAgentEpisode, env_index: int,
                            **kwargs):
            self.step_num += 1

            # The observation returned here is the one seen *after* the
            # action taken at this timestep, not before it.
            self.episode[self.step_num] = {
                "observation": episode.last_observation_for(),
            }
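
For reference, a callbacks class like this is registered through the trainer config. A minimal sketch, assuming the class name RolloutStorageCallbacks from above; PPOTrainer and CartPole-v0 are just placeholders for your own algorithm and env:

    import ray
    from ray.rllib.agents.ppo import PPOTrainer

    ray.init()
    trainer = PPOTrainer(
        env="CartPole-v0",
        config={
            "callbacks": RolloutStorageCallbacks,  # pass the class, not an instance
            "num_workers": 0,
        },
    )
    trainer.train()

One caveat: storing per-episode state on self (self.episode, self.step_num, etc.) assumes only one episode is in flight per worker. With num_envs_per_worker > 1, concurrent episodes would overwrite each other, so episode.user_data is generally the safer place for this bookkeeping.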