I'm afraid it didn't work for me.
Please find the relevant log snippet below:
----------- Begin ----------------------
~/.pyenv/versions/3.10.6/lib/python3.10/site-packages/ray/train/torch/train_loop_utils.py:31
29 FullyShardedDataParallel = None
30 else:
---> 31 from torch.distributed.fsdp import FullyShardedDataParallel
33 try:
34 from torch.profiler import profile
pytorch/torch/distributed/fsdp/__init__.py:1
----> 1 from ._flat_param import FlatParameter as FlatParameter
2 from .fully_sharded_data_parallel import (
3 BackwardPrefetch,
4 CPUOffload,
(...)
18 StateDictType,
19 )
21 __all__ = [
22 "BackwardPrefetch",
23 "CPUOffload",
(...)
37 "StateDictType",
38 ]
~/.pyenv/versions/3.10.6/lib/python3.10/site-packages/ray/train/torch/__init__.py:15
13 from ray.train.torch.torch_predictor import TorchPredictor
14 from ray.train.torch.torch_trainer import TorchTrainer
---> 15 from ray.train.torch.train_loop_utils import (
16 accelerate,
17 backward,
18 enable_reproducibility,
19 get_device,
20 get_devices,
21 prepare_data_loader,
22 prepare_model,
23 prepare_optimizer,
24 )
26 __all__ = [
27 "TorchTrainer",
28 "TorchCheckpoint",
(...)
39 "TorchDetectionPredictor",
40 ]
pytorch/torch/distributed/fsdp/_flat_param.py:30
28 import torch.nn.functional as F
29 from torch import Tensor
---> 30 from torch.distributed.fsdp._common_utils import (
31 _FSDPDeviceHandle,
32 _named_parameters_with_duplicates,
33 _no_dispatch_record_stream,
34 _set_fsdp_flattened,
35 HandleTrainingState,
36 )
37 from torch.distributed.utils import (
38 _alloc_storage,
39 _data_ptr_allocated,
40 _free_storage,
41 _p_assert,
42 )
43 from torch.nn.parameter import _ParameterMeta # type: ignore[attr-defined]
pytorch/torch/distributed/fsdp/_common_utils.py:31
29 import torch.nn as nn
30 from torch.distributed._composable_state import _get_module_state, _State
---> 31 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
32 _CHECKPOINT_PREFIX,
33 )
34 from torch.distributed.device_mesh import DeviceMesh
35 from torch.distributed.fsdp._fsdp_extensions import FSDPExtensions
pytorch/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py:9
7 import torch.nn as nn
8 from torch.autograd.graph import save_on_cpu
----> 9 from torch.distributed.utils import _pack_kwargs, _replace_by_prefix, _unpack_kwargs
10 from torch.utils.checkpoint import checkpoint as torch_utils_checkpoint
12 _CHECKPOINT_WRAPPED_MODULE = "_checkpoint_wrapped_module"
pytorch/torch/distributed/utils.py:267
262 moved_kwargs.extend([{} for _ in range(len(moved_inputs) - len(moved_kwargs))])
263 return tuple(moved_inputs), tuple(moved_kwargs)
266 def _verify_param_shape_across_processes(
--> 267 process_group: dist.ProcessGroup, tensors: List[torch.Tensor], logger: Optional[dist.Logger] = None
268 ):
269 return dist._verify_params_across_processes(process_group, tensors, logger)
272 def _sync_module_states(
273 module: nn.Module,
274 process_group: dist.ProcessGroup,
(...)
278 broadcast_buffers: bool = True,
279 ) -> None:
AttributeError: module 'torch.distributed' has no attribute 'Logger'
--------------- End of Snip -----------------------
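
In case it helps with diagnosis, here is a small check I can run in the same environment. This is just a sketch: it imports torch and ray directly (only the ray.train.torch import chain above is failing) and prints the installed versions plus whether torch.distributed exposes Logger.

# Minimal environment check (assumes plain `import torch` and `import ray` succeed;
# the failure above happens only when importing ray.train.torch).
import torch
import torch.distributed as dist
import ray

print("torch version:", torch.__version__)
print("torch location:", torch.__file__)
print("ray version:", ray.__version__)
print("torch.distributed available:", dist.is_available())
# The traceback bottoms out on torch.distributed having no 'Logger' attribute:
print("has dist.Logger:", hasattr(dist, "Logger"))

I am happy to post the output of this check if it would be useful.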
Please let me know if you need any further information.
Thanks!