Hi everyone,

I am reading through `rollout_ops.py` and trying to understand how rollouts for remote workers run. At line 82 I find:
rollouts = from_actors(workers.remote_workers())
which gives me `ParallelIterator[from_actors[shards=2]]`, a `ParallelIterator` over a list of `_ActorSet`s. As far as I understand, the `_ActorSet` holds the actors (in this case `RolloutWorker`s, which subclass `ParallelIteratorWorker`). This `ParallelIterator` cannot be iterated over directly, so calling `next(rollouts)` does not work. I guess that is because these workers are remote?
Sampling somehow works when `batch_across_shards()` is called on the `ParallelIterator`, which creates a `LocalIterator` (on the local worker, I guess):
rollouts.batch_across_shards()
Out: LocalIterator[ParallelIterator[from_actors[shards=2]].batch_across_shards()]
On this `LocalIterator` I can call `next()`, which then gives me a list of batches:
next(rollouts.batch_across_shards())
Out: [<SampleBatch, len() = 200>, <SampleBatch, len() = 200>]
I guess that is because in this case the `LocalIterator` collects one batch from each of the 2 workers?
Now, what I am trying to understand is: how is this sampling achieved, and why does it need the `ParallelIteratorWorker`? Is it simply a wrapper for the `RolloutWorker`, so that the latter can be orchestrated in a parallel (remote) setting?
Looking into the `batch_across_shards()` method, I can see that the `par_iter_next()` method of the `ParallelIteratorWorker` gets called for all workers:
futures = [a.par_iter_next.remote() for a in active]
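The fan-out on that line can be imitated without Ray. Below is a minimal sketch, assuming a thread pool as a stand-in for remote actors and a blocking wait on all futures (roughly what a `ray.get` on the list would do). `FakeShard` and `batch_across_shards_sketch` are hypothetical names for illustration, not RLlib code.

```python
from concurrent.futures import ThreadPoolExecutor


class FakeShard:
    """Hypothetical stand-in for one remote worker (not an RLlib class)."""

    def __init__(self, shard_id):
        self.shard_id = shard_id
        self.calls = 0

    def par_iter_next(self):
        # Pretend each call produces one batch of samples.
        self.calls += 1
        return f"batch-{self.shard_id}-{self.calls}"


def batch_across_shards_sketch(shards, pool):
    """Yield one list per round, holding one item from every shard."""
    while True:
        # Fan out: start one request per shard,
        # analogous to a.par_iter_next.remote() in the line quoted above.
        futures = [pool.submit(s.par_iter_next) for s in shards]
        # Fan in: block until every shard has answered, keeping shard order.
        yield [f.result() for f in futures]


shards = [FakeShard(0), FakeShard(1)]
with ThreadPoolExecutor(max_workers=2) as pool:
    local_it = batch_across_shards_sketch(shards, pool)
    first = next(local_it)
    second = next(local_it)
```

Each `next()` on the local generator triggers exactly one `par_iter_next()` call per shard, which would match the observation that one `next()` returns a list with one batch per worker.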
`par_iter_next()` uses the `local_it` iterator that gets initialized in `par_iter_init()`:
it = LocalIterator(lambda timeout: self.item_generator,
                   SharedMetrics())
Here, `self.item_generator` is an attribute of the `ParallelIteratorWorker` that is set in the constructor from a parameter. In the `RolloutWorker` initialization it gets passed in via:
def gen_rollouts():
    while True:
        yield self.sample()

ParallelIteratorWorker.__init__(self, gen_rollouts, False)
So, there is the `sample()` function. Long way to walk.
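To check my understanding of that chain (generator passed into the constructor, wrapped in `par_iter_init()`, pulled by `par_iter_next()`), here is a condensed sketch in plain Python. All class names are hypothetical, the `repeat` flag and `SharedMetrics` are left out, and the fake `sample()` just returns a list of step indices, so this is only an approximation of the pattern, not the Ray implementation.

```python
class IteratorWorkerSketch:
    """Hypothetical, simplified mimic of the ParallelIteratorWorker pattern."""

    def __init__(self, item_generator):
        # Accept either a generator function or a ready-made iterable.
        if callable(item_generator):
            self.item_generator = item_generator()
        else:
            self.item_generator = item_generator
        self.local_it = None

    def par_iter_init(self, transforms):
        # Chain the per-shard transforms onto the base generator.
        it = iter(self.item_generator)
        for fn in transforms:
            it = fn(it)
        self.local_it = it

    def par_iter_next(self):
        # One call returns one item; the driver only needs this method.
        return next(self.local_it)


class RolloutWorkerSketch(IteratorWorkerSketch):
    """Hypothetical worker whose sample() returns a fake batch."""

    def __init__(self, batch_size):
        self.batch_size = batch_size
        self.steps = 0

        def gen_rollouts():
            while True:
                yield self.sample()

        IteratorWorkerSketch.__init__(self, gen_rollouts)

    def sample(self):
        # Fake batch: a list of consecutive step indices.
        batch = list(range(self.steps, self.steps + self.batch_size))
        self.steps += self.batch_size
        return batch


worker = RolloutWorkerSketch(batch_size=3)
# A transform that maps each batch to its length.
worker.par_iter_init([lambda it: (len(b) for b in it)])
sizes = [worker.par_iter_next(), worker.par_iter_next()]
```

If this reading is right, the base class gives every worker the same small interface (`par_iter_init()` and `par_iter_next()`), so the driver-side iterator never has to know about `sample()` itself, and transforms can be applied on the worker before an item is sent back.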
Again: why is this `ParallelIteratorWorker` needed? I am trying to understand the design thinking of the development team.
Thanks for any participation in this discussion; I would like to understand more deeply how RLlib works and how it gets so fast.