[RLlib] Batch size for complete_episodes issue

Hi everyone,

I ran the simple script below to check the size of the training batches, and things did not behave as I expected. I have 1 worker, an environment whose episodes always last 10 steps, and a training batch size of 10 (for simplicity I also set 1 SGD iteration and a minibatch size of 10). When I print the actual size of the batch here (len(batch)), I get 200.

Note that if I keep "batch_mode": "complete_episodes" and set "rollout_fragment_length": 10, the size is correct, as if that argument were still taken into account.

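```python
import gym
import numpy as np
import ray
from gym import spaces
from ray import tune

batch_size = 10  # also used as the fixed episode length below


class TestEnv(gym.Env):
    """Toy env whose episodes always last exactly `batch_size` steps."""

    def __init__(self, config=None):
        self.action_space = spaces.Discrete(2)
        self.observation_space = spaces.Box(low=-1, high=1, shape=(2,))
        self.s = 0

    def step(self, action):
        self.s += 1
        # obs, reward, done, info (old gym API); cast obs to the Box's float32 dtype
        return np.random.rand(2).astype(np.float32), 1, self.s >= batch_size, {}

    def reset(self):
        self.s = 0
        return np.random.rand(2).astype(np.float32)


ray.init(num_cpus=1, local_mode=True)
config = {
    "env": TestEnv,
    "num_gpus": 0,
    "framework": "tfe",
    "num_workers": 1,
    "batch_mode": "complete_episodes",
    "train_batch_size": batch_size,
    "sgd_minibatch_size": batch_size,
    "num_sgd_iter": 1,
}
# Stop after one training iteration; that is enough to inspect the batch size.
tune.run("PPO", config=config, stop={"training_iteration": 1})
```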

Thanks for your help!

Hey @tibogiss, thanks for the post 🙂

Hmm, yeah, “train_batch_size” is not entirely respected by RLlib here because “rollout_fragment_length” is 200 (the default value). So the 1 worker always collects at least 200 steps from the env (and possibly more if “batch_mode=complete_episodes” and it has to finish an episode) before sending the batch back to the driver for learning.
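To make the arithmetic concrete, here is a simplified model of how many timesteps end up in the training batch. It is not RLlib's actual code and the function names are hypothetical; it assumes one env per worker, episodes of a fixed length, and a driver that gathers one batch from each worker per round and keeps concatenating rounds until it has at least train_batch_size steps.

```python
import math


def per_worker_steps(rollout_fragment_length, episode_len, complete_episodes):
    """Steps one worker returns per sample() call (simplified model).

    Assumes num_envs_per_worker=1 and that every episode lasts exactly
    `episode_len` steps.
    """
    if not complete_episodes:
        # truncate_episodes: the step stream is cut at the fragment length.
        return rollout_fragment_length
    # complete_episodes: run whole episodes (at least one) until at least
    # rollout_fragment_length steps have been collected.
    n_episodes = max(1, math.ceil(rollout_fragment_length / episode_len))
    return n_episodes * episode_len


def steps_reaching_learner(train_batch_size, num_workers, per_worker):
    """Size of the batch handed to learn_on_batch (simplified model).

    Assumes the driver collects one batch from each worker per round and
    concatenates rounds until it has >= train_batch_size steps.
    """
    per_round = max(1, num_workers) * per_worker
    rounds = math.ceil(train_batch_size / per_round)
    return rounds * per_round


# Setup from the question, with the default rollout_fragment_length=200.
print(steps_reaching_learner(10, 1, per_worker_steps(200, 10, True)))  # 200
# Same setup with rollout_fragment_length=10.
print(steps_reaching_learner(10, 1, per_worker_steps(10, 10, True)))  # 10
```

In this model a single sampling round already exceeds train_batch_size=10, so the whole 200-step batch is passed on, which matches the len(batch) == 200 observed above.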

2 Options here:

  1. I agree it is confusing, and we may want to change it to: collect batches on the worker(s) (according to rollout_fragment_length) → split the concatenated batch (from n workers) into train_batch_size chunks and do n training calls. However, this would have the same effect as the current logic, because the worker policies are not updated in between the n training calls (they have “over-collected”, so to speak).
  2. I think the better solution would be to warn and automatically adjust rollout_fragment_length in this case: rollout_fragment_length = train_batch_size / num_workers. Something like this.
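Option 2 could look roughly like the sketch below. The helper name is hypothetical and this is not RLlib's implementation; it also divides by num_envs_per_worker (an assumption, since each env on a worker contributes a fragment) and rounds up so that one sampling round still yields at least train_batch_size steps.

```python
import math
import warnings


def auto_adjust_rollout_fragment_length(config):
    """Return a copy of `config` whose rollout_fragment_length is shrunk so
    that one sampling round across all workers yields about
    train_batch_size steps. Hypothetical helper, not part of RLlib.
    """
    cfg = dict(config)
    num_workers = max(1, cfg.get("num_workers", 0))  # 0 workers -> local worker samples
    num_envs = cfg.get("num_envs_per_worker", 1)
    fragment = cfg.get("rollout_fragment_length", 200)  # 200 = default per the thread
    train_batch_size = cfg["train_batch_size"]
    # Round up so that num_workers * num_envs * wanted >= train_batch_size.
    wanted = math.ceil(train_batch_size / (num_workers * num_envs))
    if fragment > wanted:
        warnings.warn(
            f"rollout_fragment_length={fragment} would collect more than "
            f"train_batch_size={train_batch_size} steps per round; "
            f"adjusting it to {wanted}."
        )
        cfg["rollout_fragment_length"] = wanted
    return cfg


cfg = auto_adjust_rollout_fragment_length({"num_workers": 1, "train_batch_size": 10})
print(cfg["rollout_fragment_length"])  # 10
```

Rounding up instead of down keeps the overshoot below num_workers * num_envs_per_worker steps; rounding down would force a second sampling round and roughly double the batch.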

I agree that if a user sets train_batch_size to some value, they would expect exactly that size to arrive in Policy.learn_on_batch.

Let’s try this here (pending tests passing). This should alleviate the confusion about the batch size arriving in Policy.learn_on_batch being different from the user-defined train_batch_size.
Note: The sizes can still differ if batch_mode=complete_episodes, which is unavoidable for episodes that take longer than rollout_fragment_length.

Hi @sven1977,

I am a bit confused about the need to warn about and adjust rollout_fragment_length when batch_mode=complete_episodes. The documentation states that for complete_episodes “The rollout_fragment_length setting will be ignored”, so getting a warning about this parameter would make me believe something is wrong with my configuration (i.e., that it doesn’t run complete episodes). Would you agree, or am I misunderstanding the whole slicing flow?

Sorry for reviving this old post. I hope it’s still relevant.


Hey @2dm, thanks for the catch. It’s not entirely correct that rollout_fragment_length is ignored when batch_mode="complete_episodes". I’ll fix that in the docs right away.

If batch_mode=complete_episodes:

  • RolloutWorker.sample() will run for at least 1 episode.
  • RolloutWorker.sample() will always run for n complete episodes. The returned batch will thus always have dones[-1] == True.
  • RolloutWorker.sample() will run for at least rollout_fragment_length timesteps. # ← will add this to the docs; it seems to be missing there.
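The three rules above can be mimicked with a small sampling loop. This is a sketch under stated assumptions, not RolloutWorker's real code: FixedLengthEnv and sample_complete_episodes are hypothetical names, the env uses the old (obs, reward, done, info) step API, and only the dones flags are collected.

```python
import random


class FixedLengthEnv:
    """Hypothetical toy env whose episodes always last `episode_len` steps."""

    def __init__(self, episode_len):
        self.episode_len = episode_len
        self.t = 0

    def reset(self):
        self.t = 0
        return random.random()

    def step(self, action):
        self.t += 1
        # obs, reward, done, info
        return random.random(), 1.0, self.t >= self.episode_len, {}


def sample_complete_episodes(env, rollout_fragment_length):
    """Collect `dones` flags following the three complete_episodes rules.

    Only stops right after an episode ends (so dones[-1] is True, and at
    least one episode is always run) and only once at least
    rollout_fragment_length steps have been collected.
    """
    dones = []
    env.reset()
    while True:
        _, _, done, _ = env.step(action=0)
        dones.append(done)
        if done:
            if len(dones) >= rollout_fragment_length:
                return dones
            env.reset()  # start the next episode


batch = sample_complete_episodes(FixedLengthEnv(10), rollout_fragment_length=200)
print(len(batch), sum(batch), batch[-1])  # 200 20 True
batch = sample_complete_episodes(FixedLengthEnv(25), rollout_fragment_length=10)
print(len(batch), sum(batch), batch[-1])  # 25 1 True
```

The second call shows the unavoidable case mentioned earlier: an episode longer than rollout_fragment_length still has to be finished, so the batch exceeds the fragment length.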

Doc-fix PR: https://github.com/ray-project/ray/pull/22074
