Hi everyone,
I am currently the lead developer of Raylight, a ComfyUI extension that allows workflows to run in parallel using Ray. I am looking for advice about model state distribution.
The current issue is that lower-end systems do not have enough RAM to hold the full state_dict on every worker.
My system can be simplified as follows:
import torch.distributed as dist

class RayWorker:
    def __init__(self, distributed_config, rank, world_size):
        self.process_group = group_manager(distributed_config["sequence_parallel"])
        dist.init_process_group("nccl", rank=rank, world_size=world_size)

    def load_model(self, model_path, fsdp_config):
        # Every worker reads the full state_dict into host RAM before sharding.
        self.model = load_safetensors(model_path)
        self.model = shard_model(self.model)

    def run_model(self, context, additional_config):
        y = self.model(context)
        return y
A possible solution would be to initialize the model on the meta device and then distribute the weights using DCP (torch.distributed.checkpoint) or a broadcast from rank 0.
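A minimal sketch of the meta-device plus broadcast idea, assuming PyTorch 2.x. The helper names `build_on_meta` and `materialize_from_rank0` are hypothetical and not part of Raylight or PyTorch. Only the source rank holds the full state_dict; the other ranks allocate empty storage and receive each tensor through `dist.broadcast`. It does not shard anything, so it only addresses the host RAM copies, not per-GPU memory.

```python
import torch
import torch.distributed as dist
import torch.nn as nn


def build_on_meta(factory):
    """Construct a module without allocating storage for its parameters."""
    with torch.device("meta"):
        return factory()


def materialize_from_rank0(model, full_state_dict=None, device="cpu", src=0):
    """Allocate real storage, then fill it with the source rank's weights.

    Only the source rank needs `full_state_dict`; other ranks pass None.
    """
    model.to_empty(device=device)  # uninitialized storage on the target device
    distributed = dist.is_available() and dist.is_initialized()
    is_src = (not distributed) or dist.get_rank() == src
    if is_src:
        model.load_state_dict(full_state_dict)
    if distributed:
        # Iterate in state_dict order so every rank issues matching collectives.
        for tensor in model.state_dict().values():
            dist.broadcast(tensor, src=src)
    return model
```

With the NCCL backend, `device` would need to be the worker's CUDA device. Non-persistent buffers are not part of the state_dict and would have to be re-initialized separately.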
However, some pipelines do not use FSDP. Because of the way ComfyUI selects models, it first reads the safetensors file to determine which model class to use. This means the model is loaded before meta initialization is possible. In pipelines that do not use FSDP, each worker ends up loading the full state_dict.
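One way around the load-before-detect ordering might be to read only the safetensors header, which lists every tensor's name, dtype, and shape without touching the weight data. The sketch below uses only the standard library and relies on the published file layout (an 8-byte little-endian header length followed by a JSON header). The function names are hypothetical, and whether ComfyUI's model detection can work from names and shapes alone is an assumption.

```python
import json
import struct


def read_safetensors_header(path):
    """Return the JSON header of a safetensors file without reading tensor data."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))  # little-endian u64 length
        header = json.loads(f.read(header_len))
    header.pop("__metadata__", None)  # optional free-form metadata entry
    return header


def tensor_shapes(path):
    """Map each tensor name to its shape as a tuple."""
    header = read_safetensors_header(path)
    return {name: tuple(info["shape"]) for name, info in header.items()}
```

The resulting name-to-shape mapping is small enough to compute on every worker, or to compute once and pass along with the task.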
I experimented with DTensor replication and broadcasting, but the results were limited.
I recently looked at the implementation of Ray Direct Transport (RDT) and found it interesting because of the potential benefits. However, I have several questions before attempting a refactor.
First, does using RDT with NCCL or NIXL require creating another process group, or does it manage its own internally? I am concerned about possible parallel degree collisions. For example:
RDT degree * SP degree * FSDP degree > world_size
Could this cause mismatches with the existing process group configuration?
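The collision concern can be written as a startup check. This is a sketch with a hypothetical helper, `check_parallel_degrees`; whether RDT contributes a degree of its own is exactly the open question, so the function accepts any set of named degrees rather than assuming one.

```python
from math import prod


def check_parallel_degrees(world_size, **degrees):
    """Return the leftover replica count, or raise if the degrees do not fit."""
    total = prod(degrees.values())
    if total > world_size:
        raise ValueError(
            f"degrees {degrees} need {total} ranks, only {world_size} available"
        )
    if world_size % total != 0:
        raise ValueError(
            f"product {total} of {degrees} does not divide world_size {world_size}"
        )
    return world_size // total
```

For example, `check_parallel_degrees(8, sp=2, fsdp=2)` leaves a factor of 2 for plain replication, while `check_parallel_degrees(4, sp=2, fsdp=4)` raises because 8 ranks would be required.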
Second, the documentation mentions that RDT only supports tensors. Since state_dict is a Python object containing tensors, does that mean I need to manually recurse through the state_dict and transfer each tensor individually?
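If per-tensor transfer does turn out to be necessary, the recursion itself is short. The sketch below is transport-agnostic and says nothing about how RDT handles containers. `flatten_tensors` is a hypothetical helper, and the caller supplies the `is_tensor` predicate so the block has no PyTorch dependency.

```python
def flatten_tensors(obj, is_tensor, prefix=""):
    """Yield (dotted_path, tensor) pairs from nested dicts, lists, and tuples."""
    if is_tensor(obj):
        yield prefix, obj
    elif isinstance(obj, dict):
        for key, value in obj.items():
            path = f"{prefix}.{key}" if prefix else str(key)
            yield from flatten_tensors(value, is_tensor, path)
    elif isinstance(obj, (list, tuple)):
        for index, value in enumerate(obj):
            path = f"{prefix}.{index}" if prefix else str(index)
            yield from flatten_tensors(value, is_tensor, path)
    # Non-tensor leaves (ints, strings, None) are skipped; they would need
    # to travel through ordinary serialization instead.
```

With PyTorch, the predicate would be `lambda x: isinstance(x, torch.Tensor)`. A module's `state_dict()` is normally already a flat mapping from dotted names to tensors, so a single loop over `.items()` may be enough; the recursion matters mainly for nested structures such as optimizer state.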
Refactoring this part of the system will take significant time, so I want to make sure I understand the expected integration pattern before implementing it.
Any guidance would be appreciated. Thanks!