Init device mesh in PyTorch distributed

1. Severity of the issue: (select one)

  • None: I’m just curious or want clarification.
  • Low: Annoying but doesn’t hinder my work.
  • Medium: Significantly affects my productivity but can find a workaround.
  • High: Completely blocks me.

2. Environment:

  • Ray version: 2.4
  • Python version: 3.10
  • OS: Ubuntu
  • Cloud/Infrastructure:
  • Other libs/tools (if relevant):

3. What happened vs. what you expected:

  • Expected:
    I am planning to use a torch device mesh when using Ray Train. However, Ray Train initializes dist.init_process_group() by default (ref). A device mesh is needed because FSDP2 requires one to customize sharding strategies. Is there a workaround for this? A minimal sketch of the setup is included below.
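
A minimal sketch of the setup described above, using the current Ray Train API names (import paths may differ slightly across Ray versions); the assert inside the worker just illustrates that the default process group already exists when the training function starts:

```python
import torch.distributed as dist
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer


def train_loop_per_worker():
    # Ray Train's TorchTrainer sets up the default process group before this
    # function runs, so the group already exists here.
    assert dist.is_initialized()
    # ... model setup and training loop; the question is how to also build a
    # DeviceMesh for FSDP2 at this point ...


trainer = TorchTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=8, use_gpu=True),
)
trainer.fit()
```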

Hi, this is a cool use case. Can you share more about what you want to do with your sharding strategies?

FSDP2 expects a device_mesh to get the device placement and infer the sharding strategy. At the moment I’m looking at hybrid sharding. However, device_mesh doesn’t seem to be an option in Ray Train. It doesn’t seem to be a blocker, since init_device_mesh already checks whether dist.is_initialized() and doesn’t throw an error for duplicate initialization, so it can simply be called inside the training function (a sketch is below).
It would be great to give users the choice between init_device_mesh and init_process_group.
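
A minimal sketch of that workaround, assuming 8 GPU workers spread over 2 nodes with 4 GPUs each and FSDP2's fully_shard (whose import path varies by PyTorch version); the mesh shape, dimension names, and toy model are illustrative assumptions, not part of the Ray Train API:

```python
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
# FSDP2's fully_shard lives under torch.distributed.fsdp in recent releases.
from torch.distributed.fsdp import fully_shard

from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer


def train_loop_per_worker():
    # Ray Train already initialized the default process group, and
    # init_device_mesh() reuses it instead of re-initializing.
    world_size = dist.get_world_size()
    gpus_per_node = 4  # illustrative assumption: 2 nodes x 4 GPUs

    # 2D mesh for hybrid sharding: replicate across nodes, shard within a node.
    mesh_2d = init_device_mesh(
        "cuda",
        (world_size // gpus_per_node, gpus_per_node),
        mesh_dim_names=("replicate", "shard"),
    )

    model = nn.Linear(1024, 1024).cuda()
    # FSDP2: passing a 2D mesh selects hybrid sharding (HSDP).
    fully_shard(model, mesh=mesh_2d)

    # ... optimizer setup and training loop ...


trainer = TorchTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=8, use_gpu=True),
)
trainer.fit()
```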