Running the simple script below, where my goal was to check the size of training batches, I realised that things did not behave as I expected. Here I have 1 worker, an environment whose episodes always last 10 steps, and a train batch size of 10 (for simplicity I also set 1 SGD iteration and a minibatch of size 10). Yet when I print the actual size of the batch (len(batch)), I get a size of 200.
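For reference, a minimal sketch of the config I mean (key names follow RLlib's legacy dict-style config; the actual env is my custom 10-step env, not shown here):

```python
# Hypothetical repro config (RLlib legacy dict-style config).
config = {
    "num_workers": 1,          # a single rollout worker
    "train_batch_size": 10,    # what I *expected* len(batch) to be
    "sgd_minibatch_size": 10,  # one minibatch == the whole train batch
    "num_sgd_iter": 1,         # a single SGD pass for simplicity
    # "rollout_fragment_length" is left at its default of 200, which
    # turns out to be why the batch that arrives has 200 timesteps.
}
```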
It is worth noting that if I keep "batch_mode": "complete_episodes" and also pass "rollout_fragment_length": 10, the size is correct, as if the argument were still being taken into account.
Hmm, yeah, "train_batch_size" is not entirely respected by RLlib here b/c "rollout_fragment_length" defaults to 200. So the single worker always collects at least 200 steps (and maybe even more if "batch_mode=complete_episodes" and it has to finish an episode) from the env before sending the batch back to the driver (for learning).
Two options here:
I agree it is confusing, and we may want to change it to: collect batches on the worker(s) (according to rollout_fragment_length) → split the concatenated batch (from n workers) into train_batch_size chunks and do n training calls. However, this would have the same effect as the current logic, since the worker policies are not updated in between the n training calls (they have "over-collected", so to speak).
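A rough sketch of that alternative flow in plain Python (split_into_chunks is a hypothetical helper standing in for the splitting step, and lists stand in for sample batches; this is not actual RLlib code):

```python
def split_into_chunks(batch, chunk_size):
    """Split a concatenated sample batch (modelled here as a list of
    timesteps) into train_batch_size-sized chunks, one per training call."""
    return [batch[i:i + chunk_size] for i in range(0, len(batch), chunk_size)]

# 1 worker collecting rollout_fragment_length=200 steps, train_batch_size=10:
collected = list(range(200))            # stand-in for 200 timesteps
chunks = split_into_chunks(collected, 10)
# 20 training calls of size 10 instead of 1 call of size 200 -- but since
# the policy is not updated between these calls, the net effect matches
# training once on the full 200-step batch.
```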
I think the better solution would be to warn and automatically adjust rollout_fragment_length in this case: rollout_fragment_length = train_batch_size / num_workers, or something like this.
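In plain Python, the proposed adjustment could look something like this (a sketch with a hypothetical helper name, not RLlib's actual implementation):

```python
def adjusted_fragment_length(train_batch_size, num_workers):
    """Proposed auto-adjustment: make each worker collect just enough
    steps so the concatenated batch matches train_batch_size."""
    # Round up, so num_workers * fragment_length >= train_batch_size
    # even when train_batch_size is not divisible by num_workers.
    return -(-train_batch_size // num_workers)  # ceiling division

# With the repro config above (batch size 10, 1 worker), each rollout
# fragment would shrink from the default 200 down to 10.
```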
I agree: if a user sets train_batch_size to some value, we would expect a batch of that size to arrive in Policy.learn_on_batch.
Let’s try this here. Pending tests passing. This should alleviate the confusion about the batch size arriving in Policy.learn_on_batch being different from the user-defined train_batch_size.
Note: the sizes could still differ if batch_mode=complete_episodes, which is unavoidable for episodes that take longer than rollout_fragment_length.
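To illustrate why: under complete_episodes, a worker keeps sampling until the episode that crosses the target boundary finishes, so the batch gets rounded up to a whole number of episodes. A simplified model of this (hypothetical helper, not RLlib code):

```python
def complete_episodes_batch_size(target_steps, episode_len):
    """Simplified model: with batch_mode=complete_episodes, the batch
    is rounded up to a whole number of fixed-length episodes."""
    episodes = -(-target_steps // episode_len)  # ceiling division
    return episodes * episode_len

# Episode longer than the target: the batch overshoots unavoidably.
# e.g. target 10 steps, 25-step episodes -> a 25-step batch.
# Episode length dividing the target: sizes match exactly.
# e.g. target 200 steps, 10-step episodes -> a 200-step batch.
```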
I am a bit confused about the need to warn about and adjust rollout_fragment_length when batch_mode=complete_episodes. The documentation states that for complete_episodes "The rollout_fragment_length setting will be ignored", so getting a warning about this parameter would make me believe something is wrong with my configuration (i.e. that it doesn’t actually run complete episodes). Would you agree, or do I misunderstand the whole slicing flow?
Sorry for reviving this old post. I hope it’s still relevant.