Hi,
I’m trying to data-parallelize this code.
I’m launching my training function with a Trainer:
trainer = Trainer(backend="torch", num_workers=4)
trainer.start()
trainer.run(train_func)
trainer.shutdown()
I’m getting this error:
Traceback (most recent call last):
  File "a2d2_code/train_ray.py", line 116, in <module>
    train_loader = train.torch.prepare_data_loader(train_loader)
AttributeError: module 'ray.train' has no attribute 'torch'
Ray is v1.9.1
Torch is v1.8.1+cu111
I’m surprised by this error, since the doc says to use this snippet for torch:
Hey @Lacruche,
By default, ray.train will not import the torch module, since it requires additional Torch dependencies that are only needed if the user is using Ray Train with Torch. You should be able to fix this by explicitly importing it:
import ray.train.torch
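This is ordinary Python package behavior rather than anything Ray-specific: importing a package does not load its submodules unless the package's __init__.py does so itself. Here is a minimal sketch that reproduces the same AttributeError with a throwaway package. The names demo_pkg, heavy, and prepare are hypothetical and exist only for this illustration.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Build a throwaway package on disk: demo_pkg/__init__.py and demo_pkg/heavy.py.
# Both names are hypothetical and only serve this illustration.
pkg_root = Path(tempfile.mkdtemp())
pkg_dir = pkg_root / "demo_pkg"
pkg_dir.mkdir()
(pkg_dir / "__init__.py").write_text("")  # deliberately does NOT import heavy
(pkg_dir / "heavy.py").write_text("def prepare():\n    return 'ready'\n")
sys.path.insert(0, str(pkg_root))
importlib.invalidate_caches()

import demo_pkg

# Before the explicit import, the submodule is not an attribute of the package.
try:
    demo_pkg.heavy
    missing_before = False
except AttributeError:
    missing_before = True

# An explicit import loads the submodule and binds it on the parent package.
import demo_pkg.heavy

result = demo_pkg.heavy.prepare()
print(missing_before, result)
```

The same thing happens with ray.train and its torch submodule: the attribute only exists once something has imported ray.train.torch.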
Secondly, are you by any chance calling train.torch outside of train_func? There’s some extra logic so specifying backend="torch" will trigger the above import, so you won’t need to call it yourself in the training function. Also, prepare_data_loader will raise an exception if called outside of the training function (even when it is imported properly)!
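For reference, here is a rough sketch of how the script could be laid out so that prepare_data_loader runs inside the training function. It assumes Ray 1.9.x with Torch installed. The dataset, tensor shapes, and batch size are placeholders I made up, not taken from your script. The imports are kept local to the functions only so that the snippet can be loaded on a machine without Ray, and you can move them to the top of the file if you prefer.

```python
def train_func():
    # Runs on every worker started by the Trainer.
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    import ray.train.torch  # explicit import; harmless if already loaded
    from ray import train

    # Placeholder dataset (assumption): 64 samples with 8 features each.
    dataset = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
    train_loader = DataLoader(dataset, batch_size=16)

    # Called inside the training function, not at module level.
    train_loader = train.torch.prepare_data_loader(train_loader)

    for features, targets in train_loader:
        pass  # the actual training step would go here


def main():
    # Driver side: create the Trainer and hand it the training function.
    import ray.train.torch  # makes ray.train.torch resolvable on the driver too
    from ray.train import Trainer

    trainer = Trainer(backend="torch", num_workers=4)
    trainer.start()
    trainer.run(train_func)
    trainer.shutdown()
```

Calling main() from the script then starts the four workers, and each one builds and prepares its own data loader. In your traceback the call sits at line 116 "in <module>", so it is currently executed at module level rather than inside train_func.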