How do remote workers sample?

Hi everyone,

I am reading through rollout_ops.py and trying to understand how rollouts run on remote workers. At line 82 I find:

rollouts = from_actors(workers.remote_workers())

which gives me ParallelIterator[from_actors[shards=2]], a ParallelIterator over a list of _ActorSets. As far as I understand, an _ActorSet holds the actors (in this case RolloutWorkers, which subclass ParallelIteratorWorker). This ParallelIterator cannot be iterated over directly (calling next(rollouts) does not work). I guess that is because these workers are remote?

Sampling does work once batch_across_shards() is called on the ParallelIterator, which creates a LocalIterator (on the local worker, I guess):

rollouts.batch_across_shards()
Out: LocalIterator[ParallelIterator[from_actors[shards=2]].batch_across_shards()]

On this LocalIterator I can call next(), which then gives me a list of batches:

next(rollouts.batch_across_shards())
Out: [<SampleBatch, len() = 200>, <SampleBatch, len() = 200>]

I guess this is because the LocalIterator collects one batch from each of the 2 workers?
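To make my mental model concrete, here is a toy sketch in plain Python (no Ray involved) of what I think happens. ToyShard and toy_batch_across_shards are names I made up for illustration; they are not RLlib or Ray APIs, and the strings only stand in for SampleBatch objects.

```python
from typing import Iterator, List


class ToyShard:
    """Hypothetical stand-in for one remote worker (not a Ray actor)."""

    def __init__(self, name: str, batch_size: int):
        self.name = name
        self.batch_size = batch_size
        self.calls = 0

    def par_iter_next(self) -> str:
        # Stands in for "produce the next sample batch on this worker".
        self.calls += 1
        return f"<{self.name} batch {self.calls}, len={self.batch_size}>"


def toy_batch_across_shards(shards: List[ToyShard]) -> Iterator[List[str]]:
    """Yield one list per step, holding one item from every shard."""
    while True:
        # With Ray this would be remote calls whose results are gathered;
        # here it is just a sequential loop over local objects.
        yield [shard.par_iter_next() for shard in shards]


shards = [ToyShard("w1", 200), ToyShard("w2", 200)]
local_it = toy_batch_across_shards(shards)
first = next(local_it)
print(first)
```

If this picture is right, every next() on the local side triggers exactly one next item on each worker, which would explain the list of two batches above.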

Now, what I am trying to understand is: how is this sampling achieved, and why does it need the ParallelIteratorWorker? Is it simply a wrapper for the RolloutWorker so that the latter can be orchestrated in a parallel (remote) setting?

Looking into the batch_across_shards() method I can see that the par_iter_next() method of the ParallelIteratorWorker gets called for all workers:

futures = [a.par_iter_next.remote() for a in active]

par_iter_next() uses the local_it iterator that gets initialized in par_iter_init():

it = LocalIterator(lambda timeout: self.item_generator,
                   SharedMetrics())

Here self.item_generator is an attribute of the ParallelIteratorWorker that is created in the constructor from the generator passed in as a parameter. In the RolloutWorker initialization this is passed in via:

def gen_rollouts():
    while True:
        yield self.sample()

ParallelIteratorWorker.__init__(self, gen_rollouts, False)
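To check that I follow this pattern, I rebuilt it as a small standalone sketch without Ray. ToyIteratorWorker and ToyRolloutWorker are hypothetical names of mine, and the sketch leaves out everything the real classes do beyond holding a generator and stepping it; the dict returned by sample() only stands in for a SampleBatch.

```python
class ToyIteratorWorker:
    """Hypothetical, heavily simplified version of the iterator-worker idea."""

    def __init__(self, item_generator):
        # Takes a generator function and calls it once to get the generator.
        self.item_generator = item_generator()
        self.local_it = None

    def par_iter_init(self):
        # Prepare the local iterator that par_iter_next() will step.
        self.local_it = iter(self.item_generator)

    def par_iter_next(self):
        assert self.local_it is not None, "call par_iter_init() first"
        return next(self.local_it)


class ToyRolloutWorker(ToyIteratorWorker):
    def __init__(self):
        self.num_samples = 0

        def gen_rollouts():
            # Endless generator: each step produces one fresh sample.
            while True:
                yield self.sample()

        ToyIteratorWorker.__init__(self, gen_rollouts)

    def sample(self):
        # Stand-in for real environment sampling.
        self.num_samples += 1
        return {"batch_id": self.num_samples, "count": 200}


worker = ToyRolloutWorker()
worker.par_iter_init()
b1 = worker.par_iter_next()
b2 = worker.par_iter_next()
print(b1, b2)
```

Because the generator is lazy, sample() only runs when par_iter_next() is called, so whoever drives the iterator decides when each worker samples.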

So, there is the sample() function. Long way to walk :slight_smile:

Again: why is this ParallelIteratorWorker needed? I would like to understand the design thinking of the development team.

Thanks for any participation in this discussion. I would like to understand more deeply how RLlib works and why it is so fast.