Ray Train examples are broken

How severely does this issue affect your experience of using Ray?

  • High: It blocks me from completing my task.

I am trying to run torch_quick_start.py on a GPU machine (4 GPUs).

I changed the trainer code by uncommenting this line:

trainer = Trainer(backend="torch", num_workers=4, use_gpu=True)

However, it gives me the error below.

```
   return func(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/ray/train/backend.py", line 498, in end_training
    output = session.finish()
  File "/home/ubuntu/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/ray/train/session.py", line 102, in finish
    func_output = self.training_thread.join()
  File "/home/ubuntu/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/ray/train/utils.py", line 94, in join
    raise self.exc
  File "/home/ubuntu/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/ray/train/utils.py", line 87, in run
    self.ret = self._target(*self._args, **self._kwargs)
  File "pt.py", line 62, in train_func_distributed
    output = model(input)
  File "/home/ubuntu/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "pt.py", line 21, in forward
    return self.layer2(self.relu(self.layer1(input)))
  File "/home/ubuntu/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 103, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat1 in method wrapper_addmm)```

I am using the latest PyTorch (1.11) with CUDA 11.

Hey @goswamig, thanks for reporting this! Seems like we’re missing the step to move the input data to the GPU device. I can make a fix for this.
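In the meantime, here is a minimal sketch of the missing step in plain PyTorch. It is not the actual patch: the `NeuralNetwork` class only mirrors the `layer1`/`relu`/`layer2` names from the traceback, and the layer sizes, batch shape, learning rate, and the `train_step` helper are assumptions made for illustration. The idea is to move each batch to the same device as the model before the forward pass. The sketch falls back to CPU when no GPU is available, and inside a Ray Train worker the device would more likely come from Ray's own utilities than from this manual check.

```python
import torch
import torch.nn as nn


class NeuralNetwork(nn.Module):
    # Two-layer network mirroring the names in the traceback.
    # The sizes are assumed for illustration.
    def __init__(self, input_size=1, layer_size=15, output_size=1):
        super().__init__()
        self.layer1 = nn.Linear(input_size, layer_size)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(layer_size, output_size)

    def forward(self, input):
        return self.layer2(self.relu(self.layer1(input)))


def train_step(model, optimizer, loss_fn, input, labels, device):
    # Hypothetical helper: move the batch to the model's device first.
    # Skipping this is what triggers the "cuda:0 and cpu" RuntimeError.
    input = input.to(device)
    labels = labels.to(device)

    output = model(input)
    loss = loss_fn(output, labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


# Use the GPU when one is visible, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = NeuralNetwork().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

# Synthetic batch created on the CPU, as a data loader would typically yield it.
input = torch.randn(20, 1)
labels = torch.randn(20, 1)

loss = train_step(model, optimizer, loss_fn, input, labels, device)
print(f"loss: {loss:.4f}")
```

Applying the same two `.to(device)` calls to `input` and the labels in `train_func_distributed` should work around the error until the example is fixed.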