States of Recurrent models for multiple workers/envs

Hi everyone,
I would like to clarify whether the state argument that is passed to the forward() function of recurrent models could mix observations from different envs. Here is my question in detail:
I have a custom RNN model, a custom env and PPOTrainer. I call ray.tune with 1 worker and 10 envs per worker (rollout_fragment_length=300 and train_batch_size=1500). In addition, because the model is heavy and the env performs a large number of steps per episode, I set remote_worker_envs to True and remote_env_batch_wait_ms to 50.
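For reference, the setup described above roughly corresponds to the following PPO config dict. This is a sketch assuming the Ray 1.x-era dict-style config keys; "my_rnn_model" and "my_env" are hypothetical names standing in for the registered custom model and env.

```python
# Sketch of the described setup (Ray 1.x-style dict config, assumed).
config = {
    "env": "my_env",                 # hypothetical registered custom env
    "num_workers": 1,                # 1 rollout worker
    "num_envs_per_worker": 10,       # 10 envs on that worker
    "rollout_fragment_length": 300,
    "train_batch_size": 1500,
    "remote_worker_envs": True,      # step each env in its own remote actor
    "remote_env_batch_wait_ms": 50,  # poll timeout when waiting for ready envs
    "model": {
        "custom_model": "my_rnn_model",  # hypothetical registered custom RNN
    },
}

# e.g. ray.tune.run("PPO", config=config)
```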
This raised the following concern: in each forward() call, I concatenate the current observation with the last N observations, which are passed in the “state” argument of forward(), exactly like the GTrXLNet model does. I would like to make sure that the concatenated observations from the state argument belong to the same env as the corresponding row of input_dict (along the time axis, of course; I don’t mind having observations from different envs along the batch axis). Otherwise, it would be a big mess if observations from different environments were mixed into a single tensor. Does RLlib guarantee this property, or should I take care of it myself?
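To make the concern concrete, here is a minimal plain-Python sketch of the "concatenate current obs with the last N obs from the state" idea. It does not use the RLlib ModelV2 API (there the state would be tensors returned by forward() and get_initial_state()); the names `N`, `initial_state` and `forward` here are hypothetical. The key assumption is that row i of the state batch belongs to the same env as row i of the observation batch.

```python
from typing import List, Tuple

Obs = List[float]

N = 3  # hypothetical: number of past observations kept in the state


def initial_state(obs_dim: int) -> List[Obs]:
    """Zero-filled history of N past observations for a fresh episode."""
    return [[0.0] * obs_dim for _ in range(N)]


def forward(
    obs_batch: List[Obs], state_batch: List[List[Obs]]
) -> Tuple[List[Obs], List[List[Obs]]]:
    """Row i of obs_batch and row i of state_batch must come from the same env."""
    model_inputs, new_states = [], []
    for obs, past in zip(obs_batch, state_batch):
        # Concatenate the last N observations (oldest first) with the current
        # one along the time axis, flattened here into one feature list.
        model_inputs.append([x for o in past for x in o] + list(obs))
        # New state: drop the oldest observation, append the current one.
        new_states.append(past[1:] + [list(obs)])
    return model_inputs, new_states
```

If the rows of `state_batch` were ever shuffled relative to `obs_batch`, each input would silently mix histories of different envs, which is exactly the "big mess" described above.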


Great question, @alonit! Yes, RLlib should guarantee this. The reason is that our Sampler/SampleCollector classes keep each episode completely separate from all the others (e.g. when carrying each episode's last internal states over to its next step).

For example, see ray/rllib/evaluation/ and the _process_observations function.
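As a simplified illustration of what "keeping each episode separate" means (under our own assumptions, not RLlib's actual Sampler/SampleCollector code), states can be stored per episode id, gathered into a batch for whichever envs are ready, and scattered back afterwards. With remote_worker_envs and a remote_env_batch_wait_ms timeout, a poll may return only a subset of the envs, so the batch composition and row order can change from step to step; keying by episode keeps each row's state matched to its own episode anyway. All names below (`PerEpisodeStates`, `step`, `end_episode`) are hypothetical.

```python
from typing import Dict, List, Tuple

Obs = List[float]
State = List[Obs]  # last N observations of one episode, oldest first

N = 2  # hypothetical history length


def initial_state(obs_dim: int) -> State:
    """Zero-filled history for a freshly started episode."""
    return [[0.0] * obs_dim for _ in range(N)]


def forward(
    obs_batch: List[Obs], state_batch: List[State]
) -> Tuple[List[Obs], List[State]]:
    """Batched model step: row i of obs_batch pairs with row i of state_batch."""
    outputs, new_states = [], []
    for obs, past in zip(obs_batch, state_batch):
        outputs.append([x for o in past for x in o] + list(obs))
        new_states.append(past[1:] + [list(obs)])
    return outputs, new_states


class PerEpisodeStates:
    """Hypothetical bookkeeping: one state per episode id, never shared."""

    def __init__(self, obs_dim: int) -> None:
        self.obs_dim = obs_dim
        self.states: Dict[str, State] = {}

    def step(self, ready_obs: Dict[str, Obs]) -> Dict[str, Obs]:
        # ready_obs may hold any subset of episodes, in any order.
        episode_ids = list(ready_obs)
        # Gather: each batch row gets the state of its own episode.
        state_batch = [
            self.states.get(eid, initial_state(self.obs_dim)) for eid in episode_ids
        ]
        obs_batch = [ready_obs[eid] for eid in episode_ids]
        outputs, new_states = forward(obs_batch, state_batch)
        # Scatter: write each new state back under the same episode id.
        for eid, new_state in zip(episode_ids, new_states):
            self.states[eid] = new_state
        return dict(zip(episode_ids, outputs))

    def end_episode(self, episode_id: str) -> None:
        # Forget the finished episode; a new episode starts from initial_state().
        self.states.pop(episode_id, None)
```

For instance, after stepping `{"ep0": [1.0], "ep1": [10.0]}` and then only `{"ep1": [20.0]}`, a step with `{"ep1": [30.0], "ep0": [2.0]}` (reversed order) still gives `ep0` the input `[0.0, 1.0, 2.0]` and `ep1` the input `[10.0, 20.0, 30.0]`.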