Using Ray RDT with existing distributed process groups (FSDP/SP)

Hi everyone,

I am currently the lead developer of Raylight, a ComfyUI extension that allows workflows to run in parallel using Ray. I am looking for advice about model state distribution.

The current issue is that lower-end systems do not have enough RAM to store the full state_dict on every worker.

My system can be simplified as follows:

class RayWorker:
    def __init__(self, distributed_config):
        # Build the sequence-parallel group layout, then join the NCCL world.
        self.process_group = group_manager(distributed_config["sequence_parallel"])
        dist.init_process_group(
            "nccl",
            rank=distributed_config["rank"],
            world_size=distributed_config["world_size"],
        )

    def load_model(self, model_path, fsdp_config):
        # Every worker currently loads the full state_dict from disk.
        self.model = load_safetensors(model_path)
        self.model = shard_model(self.model, fsdp_config)

    def run_model(self, context, additional_config):
        return self.model(context)

A possible solution would be to initialize the model on the meta device and then distribute weights using DCP or a broadcast from rank 0.
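A minimal sketch of the meta-device route, assuming torch.distributed is already initialized. `build_model_skeleton` and `load_safetensors` are hypothetical placeholders for ComfyUI's model construction and checkpoint loading:

```python
import torch
import torch.distributed as dist

def distribute_from_rank0(model_path, build_model_skeleton):
    """Materialize weights on rank 0 only, then broadcast to all ranks."""
    # Allocate parameter metadata only -- no RAM is used for weights.
    with torch.device("meta"):
        model = build_model_skeleton()  # hypothetical model constructor

    # Give every rank real (uninitialized) storage on its GPU.
    model = model.to_empty(device="cuda")

    if dist.get_rank() == 0:
        # Only rank 0 pays the memory cost of reading the full checkpoint.
        state_dict = load_safetensors(model_path)  # placeholder loader
        model.load_state_dict(state_dict)

    # Broadcast each parameter and buffer from rank 0 to the other ranks.
    for tensor in list(model.parameters()) + list(model.buffers()):
        dist.broadcast(tensor.data, src=0)

    return model
```

The broadcast loop trades one collective per tensor for never holding a second full copy of the weights on non-zero ranks.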

However, some pipelines do not use FSDP. Because of the way ComfyUI selects models, it first reads the safetensors file to determine which model class to use. This means the model is loaded before meta initialization is possible. In pipelines that do not use FSDP, each worker ends up loading the full state_dict.

I experimented with DTensor replication and broadcasting, but the results were limited.

I recently looked at the implementation of RDT (Ray Direct Transport) and it looks promising for this use case. However, I have several questions before attempting a refactor.

First, does using RDT with NCCL or NIXL require creating another process group, or does it manage its own internally? I am concerned about possible parallel degree collisions. For example:

RDT group size * SP degree * FSDP degree > world_size

Could this cause mismatches with the existing process group configuration?
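For concreteness, the constraint I am worried about is just a product over group sizes (the degrees below are illustrative):

```python
def degrees_fit(world_size, **degrees):
    """Return True if the product of parallel degrees fits in world_size."""
    product = 1
    for size in degrees.values():
        product *= size
    return product <= world_size

# 8 GPUs: SP=2 and FSDP=4 exactly fill the world...
print(degrees_fit(8, sp=2, fsdp=4))         # True
# ...so layering a separate RDT group of size 2 on top would overcommit.
print(degrees_fit(8, sp=2, fsdp=4, rdt=2))  # False
```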

Second, the documentation mentions that RDT only supports tensors. Since state_dict is a Python object containing tensors, does that mean I need to manually recurse through the state_dict and transfer each tensor individually?
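If per-tensor transfer is indeed required, the recursion itself is simple. A sketch that flattens a possibly nested state_dict into (key, tensor) pairs and rebuilds it on the receiving side; it treats any non-dict leaf as a tensor, and uses "/" as the separator since torch state_dict keys already contain dots:

```python
def flatten_state_dict(state_dict, prefix=""):
    """Return a flat {path: tensor} dict from a possibly nested state_dict."""
    flat = {}
    for key, value in state_dict.items():
        full_key = f"{prefix}{key}"
        if isinstance(value, dict):
            flat.update(flatten_state_dict(value, prefix=full_key + "/"))
        else:
            flat[full_key] = value  # leaf: a tensor, transfer it individually
    return flat

def unflatten_state_dict(flat):
    """Rebuild the nested structure from "/"-separated paths."""
    nested = {}
    for path, tensor in flat.items():
        parts = path.split("/")
        node = nested
        for part in parts[:-1]:
            node = node.setdefault(part, {})
        node[parts[-1]] = tensor
    return nested
```

Each tensor in the flat dict can then be shipped individually, and the receiver reassembles the original structure before calling load_state_dict.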

Refactoring this part of the system will take significant time, so I want to make sure I understand the expected integration pattern before implementing it.

Any guidance would be appreciated. Thanks!

  1. Process Groups: RDT (Ray Direct Transport) with NCCL or NIXL requires you to create a separate collective group for the actors involved in RDT transfers; it does not automatically reuse your existing PyTorch process groups. Each RDT collective group is independent, and actors can only be in one collective group per tensor transport at a time. You must ensure that the group membership and parallelism do not conflict with your existing process groups, as mismatches (e.g., RDT group size × SP × FSDP > world_size) can cause issues. See RDT limitations and collective group creation.

  2. Tensor Support: RDT only supports transferring torch.Tensor objects, not arbitrary Python objects. If you want to use RDT to distribute a state_dict, you must manually extract and transfer each tensor in the state_dict. You cannot send the entire state_dict as a single object via RDT; you need to recurse through it and handle each tensor individually. See RDT limitations and discussion on tensor-only support.
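Based on my reading of the RDT docs, the two points above combine into roughly the sketch below. Treat the exact API (`tensor_transport` on `ray.method`, `create_collective_group` in `ray.experimental.collective`) as an assumption to verify against your Ray version; `load_safetensors` is a placeholder loader:

```python
import ray
from ray.experimental.collective import create_collective_group

@ray.remote(num_gpus=1)
class Loader:
    def load(self, model_path):
        # Only this actor holds the full checkpoint in memory.
        self.state_dict = load_safetensors(model_path)  # placeholder
        return list(self.state_dict.keys())

    # tensor_transport routes the returned torch.Tensor over NCCL
    # instead of the object store; it applies to tensors only.
    @ray.method(tensor_transport="nccl")
    def get_tensor(self, key):
        return self.state_dict[key]

@ray.remote(num_gpus=1)
class Worker:
    def __init__(self):
        self.shards = {}

    def receive(self, key, tensor):
        # RDT delivers `tensor` directly GPU-to-GPU.
        self.shards[key] = tensor

loader = Loader.remote()
workers = [Worker.remote() for _ in range(2)]
# RDT needs its own collective group over the participating actors;
# it does not reuse the existing torch.distributed process group.
create_collective_group([loader] + workers, backend="nccl")

for key in ray.get(loader.load.remote("model.safetensors")):
    ref = loader.get_tensor.remote(key)  # one tensor per transfer
    ray.get([w.receive.remote(key, ref) for w in workers])
```

Note the per-key loop: this is the manual recursion from point 2, with the RDT group membership kept disjoint from the SP/FSDP groups to avoid the oversubscription in point 1.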

Would you like more detail on how to implement this pattern or on handling group membership safely?
