Ray Train v1.9.1: returns an AttributeError: module 'ray.train' has no attribute 'torch'

Hi,

I’m trying to data-parallelize this code, and I’m launching my training function with:

from ray.train import Trainer

trainer = Trainer(backend="torch", num_workers=4)
trainer.start()
trainer.run(train_func)
trainer.shutdown()

I’m getting this error:

Traceback (most recent call last):
  File "a2d2_code/train_ray.py", line 116, in <module>
    train_loader = train.torch.prepare_data_loader(train_loader)
AttributeError: module 'ray.train' has no attribute 'torch'

Ray is v1.9.1
Torch is v1.8.1+cu111

I’m surprised by this error, since the doc says to use this snippet for torch:

Hey @Lacruche,

By default, ray.train does not import its torch submodule, because that submodule needs additional Torch dependencies that are only required if you use Ray Train with Torch. You should be able to fix this by importing it explicitly:

import ray.train.torch
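For context on why the explicit import is needed: in plain Python, importing a package does not automatically import its submodules, and a submodule only becomes an attribute of the package once something imports it. Below is a minimal stdlib-only sketch of that behaviour using a throwaway package. The names demo_pkg and demo_pkg/torch.py are hypothetical stand-ins for ray.train and ray.train.torch; this is not Ray code and assumes nothing about Ray internals beyond the lazy import described above.

```python
import importlib
import os
import sys
import tempfile


def demo_submodule_import():
    """Show that importing a package does not bind a submodule until it is imported."""
    with tempfile.TemporaryDirectory() as root:
        # Hypothetical package "demo_pkg" with an EMPTY __init__.py, so importing
        # the package does not pull in its "torch" submodule (not real Ray code).
        pkg_dir = os.path.join(root, "demo_pkg")
        os.makedirs(pkg_dir)
        open(os.path.join(pkg_dir, "__init__.py"), "w").close()
        with open(os.path.join(pkg_dir, "torch.py"), "w") as f:
            f.write("def prepare_data_loader(loader):\n    return loader\n")

        sys.path.append(root)
        importlib.invalidate_caches()  # make the freshly written files visible
        try:
            pkg = importlib.import_module("demo_pkg")
            # Accessing the submodule before importing it raises AttributeError,
            # the same kind of error as "module 'ray.train' has no attribute 'torch'".
            try:
                getattr(pkg, "torch")
                error = None
            except AttributeError as exc:
                error = str(exc)
            # Importing the submodule explicitly binds it as an attribute of the package.
            importlib.import_module("demo_pkg.torch")
            fixed = pkg.torch.prepare_data_loader([1, 2, 3])
        finally:
            sys.path.remove(root)
            # Forget the throwaway modules so the demo can be rerun cleanly.
            sys.modules.pop("demo_pkg.torch", None)
            sys.modules.pop("demo_pkg", None)
    return error, fixed


if __name__ == "__main__":
    error, fixed = demo_submodule_import()
    print(error)
    print(fixed)
```

The first printed line is the AttributeError message; the second call only succeeds because of the explicit import of the submodule, which is the role import ray.train.torch plays here.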

Secondly, are you by any chance calling train.torch outside of train_func? Your traceback points to line 116 at module level (in <module>), which suggests you are. There’s some extra logic so that specifying backend="torch" triggers the import above, so you won’t need to import it yourself inside the training function. Also, prepare_data_loader will raise an exception if called outside the training function, even when it is imported properly 🙂