Hi, I am trying to get support for FSDP2 in Ray Train. It seems like the `parallel_strategy` param in `prepare_model()` only has options for `ddp` and `fsdp`. Is there a way to wrap the model with FSDP2?
Hi! Since torch only made FSDP2 a public API in torch 2.6 (released a week ago), we do not have support for FSDP2 in `ray.train.torch.prepare_model` yet.

However, Ray Train is highly flexible, so until then you can manually wrap the model in FSDP2:
```diff
- model = ray.train.torch.prepare_model(model)
+ from torch.distributed.fsdp import fully_shard  # requires torch >= 2.6
+ # Fill in any fsdp_kwargs below (e.g. mesh, reshard_after_forward).
+ # fully_shard modifies the model in place and shards it on the current device.
+ fully_shard(model, **fsdp_kwargs)
```
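For context, here is a minimal sketch of where that call sits inside a Ray Train training loop. The toy model, the training loop, and the `ScalingConfig` values are placeholders (and it assumes GPU workers); the Ray-specific pieces are just the standard `TorchTrainer`, `ScalingConfig`, and `ray.train.torch.get_device()` APIs:

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import fully_shard  # torch >= 2.6

import ray.train.torch
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer


def train_loop_per_worker():
    # Ray Train has already initialized the torch.distributed process group
    # for this worker, so fully_shard can build its default (world-sized) mesh.
    device = ray.train.torch.get_device()
    torch.cuda.set_device(device)

    model = nn.Linear(128, 10).to(device)  # placeholder model

    # Instead of ray.train.torch.prepare_model(model): shard in place with FSDP2.
    fully_shard(model)  # pass any fsdp_kwargs here as needed

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    for _ in range(10):  # placeholder training loop
        x = torch.randn(32, 128, device=device)
        loss = model(x).sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()


trainer = TorchTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=2, use_gpu=True),
)
trainer.fit()
```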
The snippets above show the simplest case, where only the entire model is wrapped with `fully_shard`. For more on the FSDP2 API, you can refer to the API docs (torch.distributed.fsdp.fully_shard — PyTorch 2.6 documentation) and the RFC ([RFC] Per-Parameter-Sharding FSDP · Issue #114299 · pytorch/pytorch · GitHub).