Multithreading a Ray Python actor with torch.distributed

1. Severity of the issue: (select one)

  • None: I’m just curious or want clarification.
  • Low: Annoying but doesn’t hinder my work.
  • Medium: Significantly affects my productivity but can find a workaround.
  • High: Completely blocks me.

2. Environment:

  • Ray version: 2.40.0
  • Python version: 3.12.7
  • OS: linux kernel 5.10.0
  • Cloud/Infrastructure:
  • Other libs/tools (if relevant):

3. What happened vs. what you expected:

  • Expected: I want to initialize a torch.distributed process group inside a Ray actor, and at the same time use multi-threading in the actor to speed up the subsequent steps.
  • Actual: I cannot control which thread a remote method runs on, so init_process_group might run on thread 1 while broadcast runs on thread 2. In that case the torch process-group state is not shared across threads:

import ray
import torch.distributed

@ray.remote(max_concurrency=2)
class Actor:
    def __init__(self):
        pass

    def initialize(self):
        # may run on thread 1 of the actor's thread pool
        torch.distributed.init_process_group(...)

    def run(self):
        # may run on thread 2, which never called init_process_group
        torch.distributed.broadcast(...)

I found a workaround by limiting torch-related methods to a single-threaded concurrency group, like this:

@ray.remote(max_concurrency=2, concurrency_groups={"torch": 1})
class Actor:
    def __init__(self):
        pass

    @ray.method(concurrency_group="torch")
    def initialize(self):
        torch.distributed.init_process_group(...)

    @ray.method(concurrency_group="torch")
    def run(self):
        torch.distributed.broadcast(...)

But torch operations are often hidden deep inside methods, and developers can easily forget to add the ray.method decorator. Is there a recommended practice for situations like this? For example, can I make init_process_group work for all threads?

Hi tianyi-ge! Welcome to the Ray community :slight_smile: I did a bit of digging in the PyTorch docs, so maybe this can help you out.

From reading the docs, I don’t think it’s possible to make init_process_group() work across all threads; it’s essentially bound to the thread that initialized it. If you call it on one thread but later issue collectives like broadcast() from another, you’ll probably run into unexpected behavior.

Using Ray’s concurrency groups, as you mentioned, is definitely one way to solve this. Another is setting max_concurrency=1, which I know isn’t ideal for some use cases; a minimal sketch of that option is below.
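Something like this, untested and with the same (...) placeholders as in your example: with max_concurrency=1 (which is also the default for regular, non-async actors), every method call runs on the actor’s single thread, so init_process_group and broadcast always see the same thread state.

import ray
import torch.distributed

# max_concurrency=1: all method calls execute sequentially on one thread,
# so the process group initialized here is the one broadcast() later uses.
@ray.remote(max_concurrency=1)
class Actor:
    def initialize(self):
        torch.distributed.init_process_group(...)

    def run(self):
        torch.distributed.broadcast(...)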

Thanks christina!

From the docs, I guess torch currently does not play well with a “thread pool” like the one a Ray actor provides. One has to have full control over the threads and decide which thread each torch method runs on. I’m wondering whether nesting multiprocessing (or a dedicated worker thread) inside the Ray actor is a good idea; a rough sketch of what I mean is below. I’ll leave this open for a few days in case others have better ideas.
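Here is the rough, untested sketch I have in mind: keep the actor multi-threaded for everything else, but funnel every torch.distributed call through one dedicated thread that the actor owns (a single-worker ThreadPoolExecutor; the _torch_thread name and the (...) args are just placeholders for illustration).

import ray
import torch.distributed
from concurrent.futures import ThreadPoolExecutor

@ray.remote(max_concurrency=2)
class Actor:
    def __init__(self):
        # One dedicated thread owns every torch.distributed call,
        # regardless of which actor thread the remote method runs on.
        self._torch_thread = ThreadPoolExecutor(max_workers=1)

    def initialize(self):
        self._torch_thread.submit(
            lambda: torch.distributed.init_process_group(...)
        ).result()

    def run(self):
        # broadcast executes on the same thread that ran init_process_group
        self._torch_thread.submit(
            lambda: torch.distributed.broadcast(...)
        ).result()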