Module 'ray.train' has no attribute 'torch'

I was following the latest how-to on plugging Ray into the existing FashionMNIST example below:
https://docs.ray.io/en/master/train/examples/pytorch/convert_existing_pytorch_code_to_ray_train.html

I am getting the error module 'ray.train' has no attribute 'torch' on Ray version 2.10.0.

In [115]: ray.__version__
Out[115]: '2.10.0'

In [116]: torch.__version__
Out[116]: '2.4.0a0+git5e878be'

~ kindly help ~

This seems to be due to lazy imports. Will get this fixed in the example, but in the meantime, can you import ray.train.torch directly?

>>> import ray.train
>>> ray.train.torch
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AttributeError: module 'ray.train' has no attribute 'torch'
>>> import ray.train.torch
>>> ray.train.torch
<module 'ray.train.torch' from '/Users/matt/workspace/ray/python/ray/train/torch/__init__.py'>
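For context, this is standard Python behavior: importing a package does not automatically import its submodules, so a submodule only shows up as an attribute of the package once something has imported it explicitly. A quick illustration with the stdlib xml package (nothing Ray-specific here):

>>> import xml
>>> xml.dom
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AttributeError: module 'xml' has no attribute 'dom'
>>> import xml.dom
>>> xml.dom
<module 'xml.dom' from '...'>

Since ray.train lazy-imports its framework-specific submodules, the explicit import ray.train.torch is needed before ray.train.torch can be referenced as an attribute.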

I'm afraid it didn't work for me.

Please find a snippet of the logs below:

----------- Begin ----------------------
File ~/.pyenv/versions/3.10.6/lib/python3.10/site-packages/ray/train/torch/__init__.py:15
     13 from ray.train.torch.torch_predictor import TorchPredictor
     14 from ray.train.torch.torch_trainer import TorchTrainer
---> 15 from ray.train.torch.train_loop_utils import (
     16     accelerate,
     17     backward,
     18     enable_reproducibility,
     19     get_device,
     20     get_devices,
     21     prepare_data_loader,
     22     prepare_model,
     23     prepare_optimizer,
     24 )
     26 __all__ = [
     27     "TorchTrainer",
     28     "TorchCheckpoint",
    (...)
     39     "TorchDetectionPredictor",
     40 ]

File ~/.pyenv/versions/3.10.6/lib/python3.10/site-packages/ray/train/torch/train_loop_utils.py:31
     29     FullyShardedDataParallel = None
     30 else:
---> 31     from torch.distributed.fsdp import FullyShardedDataParallel
     33 try:
     34     from torch.profiler import profile

File pytorch/torch/distributed/fsdp/__init__.py:1
----> 1 from ._flat_param import FlatParameter as FlatParameter
      2 from .fully_sharded_data_parallel import (
      3     BackwardPrefetch,
      4     CPUOffload,
    (...)
     18     StateDictType,
     19 )
     21 __all__ = [
     22     "BackwardPrefetch",
     23     "CPUOffload",
    (...)
     37     "StateDictType",
     38 ]

File pytorch/torch/distributed/fsdp/_flat_param.py:30
     28 import torch.nn.functional as F
     29 from torch import Tensor
---> 30 from torch.distributed.fsdp._common_utils import (
     31     _FSDPDeviceHandle,
     32     _named_parameters_with_duplicates,
     33     _no_dispatch_record_stream,
     34     _set_fsdp_flattened,
     35     HandleTrainingState,
     36 )
     37 from torch.distributed.utils import (
     38     _alloc_storage,
     39     _data_ptr_allocated,
     40     _free_storage,
     41     _p_assert,
     42 )
     43 from torch.nn.parameter import _ParameterMeta  # type: ignore[attr-defined]

File pytorch/torch/distributed/fsdp/_common_utils.py:31
     29 import torch.nn as nn
     30 from torch.distributed._composable_state import _get_module_state, _State
---> 31 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     32     _CHECKPOINT_PREFIX,
     33 )
     34 from torch.distributed.device_mesh import DeviceMesh
     35 from torch.distributed.fsdp._fsdp_extensions import FSDPExtensions

File pytorch/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py:9
      7 import torch.nn as nn
      8 from torch.autograd.graph import save_on_cpu
----> 9 from torch.distributed.utils import _pack_kwargs, _replace_by_prefix, _unpack_kwargs
     10 from torch.utils.checkpoint import checkpoint as torch_utils_checkpoint
     12 _CHECKPOINT_WRAPPED_MODULE = "_checkpoint_wrapped_module"

File pytorch/torch/distributed/utils.py:267
    262     moved_kwargs.extend([{} for _ in range(len(moved_inputs) - len(moved_kwargs))])
    263     return tuple(moved_inputs), tuple(moved_kwargs)
    266 def _verify_param_shape_across_processes(
--> 267     process_group: dist.ProcessGroup, tensors: List[torch.Tensor], logger: Optional[dist.Logger] = None
    268 ):
    269     return dist._verify_params_across_processes(process_group, tensors, logger)
    272 def _sync_module_states(
    273     module: nn.Module,
    274     process_group: dist.ProcessGroup,
    (...)
    278     broadcast_buffers: bool = True,
    279 ) -> None:

AttributeError: module 'torch.distributed' has no attribute 'Logger'

--------------- End of Snip -----------------------

Please let me know if you need any further information.

thanks~

Could you share the script that you’re running?

For the time being, I am just trying to run it through IPython.

---- Snip------
zaheer$: ipython

Python 3.10.6 (main, Nov 17 2022, 16:52:45) [Clang 14.0.0 (clang-1400.0.29.202)]
Type 'copyright', 'credits' or 'license' for more information
IPython 8.11.0 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import ray.train.torch

AttributeError                            Traceback (most recent call last)
Cell In[1], line 1
----> 1 import ray.train.torch

File ~/.pyenv/versions/3.10.6/lib/python3.10/site-packages/ray/train/torch/__init__.py:15
     13 from ray.train.torch.torch_predictor import TorchPredictor
     14 from ray.train.torch.torch_trainer import TorchTrainer
---> 15 from ray.train.torch.train_loop_utils import (
     16     accelerate,
     17     backward,
     18     enable_reproducibility,
     19     get_device,
     20     get_devices,
     21     prepare_data_loader,
     22     prepare_model,
     23     prepare_optimizer,
     24 )
     26 __all__ = [
     27     "TorchTrainer",
     28     "TorchCheckpoint",
    (...)
     39     "TorchDetectionPredictor",
     40 ]
---------- end of snip-----------

How are you installing torch? Does importing from torch directly work for you?
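To help narrow it down, could you run something like the following in a fresh interpreter? (A quick diagnostic sketch; it checks which torch build is actually being picked up and whether the attribute your traceback complains about exists in it.)

import torch
import torch.distributed as dist

print(torch.__version__)        # version string reported by your build
print(torch.__file__)           # which installation is actually on sys.path
print(dist.is_available())      # whether distributed support was compiled in
print(hasattr(dist, "Logger"))  # the attribute the traceback says is missing

If hasattr(dist, "Logger") prints False, the problem is in the torch build itself rather than in Ray.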

import ray.train.torch

The above command doesn't work; it gives the error below:

AttributeError: module 'torch.distributed' has no attribute 'Logger'

Regarding PyTorch, I installed it by compiling PyTorch from source (pytorch 2.4.xx).

Here is the issue I raised on the PyTorch GitHub.

I would like to get rid of this by uninstalling the source-compiled PyTorch and installing via conda, to see whether that helps (pytorch 2.2.xx).

Please let me know if you can help me downgrade from 2.4.xx (uninstall) to 2.2.xx (conda install).

FYI: I tried a conda install, but it still picks up the compiled version (2.4.xx); it looks like 2.4.xx needs to be removed from the machine first.

thanks~

Maybe you can try installing 2.2 in a new conda environment?
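Something along these lines should give you a clean environment to test in (an untested sketch; the environment name torch22 is arbitrary, and you may want to adjust the Python version or pin a specific 2.2.x build):

conda create -n torch22 python=3.10 -y
conda activate torch22
pip install "torch==2.2.*" "ray[train]==2.10.0"
python -c "import ray.train.torch; print('ok')"

Since the 2.4.xx source build lives in your existing environment, a fresh environment sidesteps having to uninstall it.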

With 2.2.x it's working well (2.4.x fails to work), though I could only get it working on CPU.

For GPU with "mps" it throws an error; I could share the error file for your later offline debugging.

Also, I don't know how I can upload an error .txt file here.

thanks ~
