FSDP2 support for PyTorch Ray Train

Hi, I am trying to use FSDP2 with Ray Train. It seems like the parallel_strategy param in prepare_model() only has options for DDP and FSDP. Is there a way to wrap the model with FSDP2?

Hi! Since PyTorch only made FSDP2 a public API in torch 2.6 (released a week ago), we do not have support for FSDP2 in ray.train.prepare_model yet.

However, Ray Train is highly flexible, so until then you can manually wrap the model in FSDP2 yourself:

- model = ray.train.prepare_model(model)
+ from torch.distributed.fsdp import fully_shard  # torch >= 2.6
+ # fill in fsdp_kwargs as needed (e.g. mesh, mp_policy, reshard_after_forward)
+ fsdp_kwargs = {}
+ # fully_shard modifies the model in place and shards it onto the current device
+ fully_shard(model, **fsdp_kwargs)

This is the simplest example, where the entire model is wrapped with a single fully_shard call. For more on FSDP2 API usage, you can refer to the API docs: torch.distributed.fsdp.fully_shard — PyTorch 2.6 documentation and the RFC: [RFC] Per-Parameter-Sharding FSDP · Issue #114299 · pytorch/pytorch · GitHub
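For context, here is a minimal sketch of how the manual wrapping can sit inside a Ray Train training function. `MyModel`, the toy training loop, and the empty `fsdp_kwargs` are placeholders for illustration, not part of the Ray or PyTorch APIs; adapt them to your own model and sharding configuration. `TorchTrainer` initializes the torch distributed process group on each worker, so `fully_shard` can fall back to the default device mesh when no `mesh` is passed.

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import fully_shard  # torch >= 2.6

import ray.train.torch
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer


class MyModel(nn.Module):  # placeholder model for this sketch
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(128, 128) for _ in range(4))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def train_func(config):
    device = ray.train.torch.get_device()
    model = MyModel().to(device)

    # Shard each submodule first, then the root module; this mirrors
    # FSDP1's auto-wrap behavior and gives per-layer communication groups.
    fsdp_kwargs = {}  # e.g. mesh, mp_policy, reshard_after_forward, ...
    for layer in model.layers:
        fully_shard(layer, **fsdp_kwargs)
    fully_shard(model, **fsdp_kwargs)

    # Create the optimizer after fully_shard so it references the sharded params.
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    for _ in range(10):  # stand-in training loop
        x = torch.randn(8, 128, device=device)
        loss = model(x).sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()


trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(num_workers=2, use_gpu=True),
)
trainer.fit()
```

One thing to keep in mind: construct the optimizer after calling fully_shard, so it holds references to the sharded (DTensor) parameters rather than the original ones.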