Running a list of functions with limited parallelism and autoscaling

Hello! I am looking for a way to run a list of remote functions with limited parallelism (N). As soon as a remote function finishes, another one should be scheduled, until all remote functions have finished. At most N remote functions should be in flight at a time.

In the docs, I found ray.util.multiprocessing.Pool (Distributed multiprocessing.Pool — Ray v1.9.2); however, running something like Pool(N) will fail if the cluster does not have N CPUs. Is there a way I can instantiate a Pool with N processes and have the cluster autoscale to that size, while still processing functions on the CPUs that are already available?

Example pseudo-code:

from ray.util.multiprocessing import Pool

data = [x for x in range(1000)]
# Assume the cluster has 20 CPUs but can scale up to 100.
pool = Pool(100)  # The cluster begins to scale to 100 CPUs.
results = list(pool.imap(f, data))
# The cluster runs `f` 1000 times, with at most 100 at the same time.

The other approach I found in the docs is to use ray.wait (Pattern: Using ray.wait to limit the number of in-flight tasks — Ray v1.9.2), but this has the drawback of running in fixed chunks, as sketched below.
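For context, a fixed-chunk version might look roughly like the sketch below (my own illustration, not the exact code from the linked pattern; f, N, and the placeholder computation are assumptions). Each chunk of N tasks has to finish completely before the next chunk is submitted, so a single slow task leaves the remaining slots idle:

import ray

@ray.remote
def f(d):
    return d  # placeholder computation

data = [x for x in range(1000)]
N = 100
results = []
for start in range(0, len(data), N):
    refs = [f.remote(d) for d in data[start:start + N]]
    # Block until the whole chunk is done before submitting the next chunk.
    ray.wait(refs, num_returns=len(refs))
    results.extend(ray.get(refs))

In contrast, the workaround below refills a slot as soon as any single task finishes.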

My current workaround is to combine ray.wait with a queue (I use a PriorityQueue to preserve the result order). Is there a better way?

import queue

import ray

@ray.remote
def f(index, d):
    result = compute(d)
    return (index, result)

parallelism = 100
data = [x for x in range(1000)]
finished_results_q = queue.PriorityQueue()
in_progress_runs = set()
index = 0
while index < len(data):
    if len(in_progress_runs) < parallelism:
        # A slot is free: submit the next task.
        in_progress_runs.add(f.remote(index, data[index]))
        index += 1
        continue
    # All slots are busy: ray.wait blocks until one task finishes.
    [finished_result_ref], _ = ray.wait(list(in_progress_runs))
    finished_results_q.put(ray.get(finished_result_ref))
    in_progress_runs.remove(finished_result_ref)
# Drain the tasks that are still in flight.
for result in ray.get(list(in_progress_runs)):
    finished_results_q.put(result)
# Pop in priority (index) order to get the results sorted by input index.
results = [finished_results_q.get() for _ in range(finished_results_q.qsize())]

I think ray.wait should work in this case. I'll try writing some example code; let me know if there's a missing piece here.

import ray

N = 10            # maximum number of tasks in flight
num_tasks = 1000  # total number of tasks to run

@ray.remote
def f():
    pass

# Launch the first N tasks.
tasks = [f.remote() for _ in range(N)]
submitted = N

while tasks:
    # ray.wait returns as soon as a single task has finished (num_returns=1 by default).
    done, tasks = ray.wait(tasks)
    assert len(done) == 1
    result = ray.get(done[0])
    # Do something with result
    # Refill the free slot until all the work has been submitted.
    if submitted < num_tasks:
        tasks.append(f.remote())
        submitted += 1
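
If you also need the results back in the original input order (the reason for the PriorityQueue in the workaround above), a small variation of the same ray.wait pattern is to return the index with each result and sort at the end. This is just a sketch; run_one, max_in_flight, and the squaring are placeholder names and work, not anything from the Ray docs:

import ray

@ray.remote
def run_one(index, d):
    return (index, d * d)  # placeholder computation

data = [x for x in range(1000)]
max_in_flight = 100

# Fill the first max_in_flight slots.
in_flight = [run_one.remote(i, d) for i, d in enumerate(data[:max_in_flight])]
next_index = len(in_flight)
indexed_results = []

while in_flight:
    # Block until any one task finishes, then refill the free slot.
    done, in_flight = ray.wait(in_flight)
    indexed_results.append(ray.get(done[0]))
    if next_index < len(data):
        in_flight.append(run_one.remote(next_index, data[next_index]))
        next_index += 1

# Restore the original input order.
results = [r for _, r in sorted(indexed_results)]

This keeps at most max_in_flight tasks running without needing a separate PriorityQueue.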

Thank you! I did not think of that.