RaySGD PyTorch fail: "TypeError: can't pickle SSLContext objects"

Hi,

I have a simple PyTorch training loop that works fine on a single GPU of an ml.g4dn.12xlarge SageMaker Notebook instance (4x NVIDIA T4).

I'm following the Ray tutorial to scale it across all 4 GPUs:

```python
from ray.util.sgd.v2 import Trainer

trainer = Trainer(backend="torch", num_workers=4)
trainer.start()
results = trainer.run(train_function)  # train_function: my existing single-GPU training loop
trainer.shutdown()
```
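`trainer.run` has to serialize `train_function`, together with everything it captures, so it can be sent to each worker. A rough, stdlib-only heuristic for narrowing down the culprit is sketched below. `find_unpicklable` is a hypothetical helper, not a Ray API: it lists the closure variables and module globals referenced by a function that plain `pickle` rejects. Ray itself uses cloudpickle, which handles more than `pickle` does (lambdas, locally defined functions), and this sketch only inspects the function's own code object, so treat a hit as a lead rather than proof.

```python
import pickle
import types


def find_unpicklable(fn):
    """Return names of closure variables and module globals used by `fn`
    that the standard `pickle` module refuses to serialize.

    Heuristic only: nested functions/comprehensions are not scanned, and
    modules are skipped because cloudpickle ships them by reference.
    """
    candidates = {}
    # Variables captured from an enclosing scope (closures).
    for name, cell in zip(fn.__code__.co_freevars, fn.__closure__ or ()):
        try:
            candidates[name] = cell.cell_contents
        except ValueError:  # cell exists but was never assigned
            pass
    # Module-level globals referenced by name in the function body.
    for name in fn.__code__.co_names:
        if name in fn.__globals__:
            candidates[name] = fn.__globals__[name]

    bad = []
    for name, obj in candidates.items():
        if isinstance(obj, types.ModuleType):
            continue
        try:
            pickle.dumps(obj)
        except Exception:
            bad.append(name)
    return bad
```

In the notebook, something like `print(find_unpicklable(train_function))` before calling `trainer.run(...)` would point at names worth moving inside the function. For example, an HTTP or S3 client created at module level is a plausible suspect, though the post does not show `train_function`, so that is only an assumption.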

I can see 4 processes start, but I quickly get this error: "TypeError: can't pickle SSLContext objects"

```
Traceback (most recent call last):
  File "train.py", line 289, in <module>
    results = trainer.run(train_single)
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/ray/util/sgd/v2/trainer.py", line 238, in run
    run_dir=self.latest_run_dir,
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/ray/util/sgd/v2/trainer.py", line 511, in __init__
    train_func, checkpoint, checkpoint_strategy, run_dir=run_dir)
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/ray/util/sgd/v2/trainer.py", line 526, in _start_training
    lambda: self._executor.start_training(
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/ray/util/sgd/v2/trainer.py", line 537, in _run_with_error_handling
    return func()
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/ray/util/sgd/v2/trainer.py", line 531, in <lambda>
    latest_checkpoint_id=latest_checkpoint_id
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/ray/util/sgd/v2/backends/backend.py", line 435, in start_training
    checkpoint=checkpoint_dict))
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/ray/util/sgd/v2/worker_group.py", line 267, in execute_single_async
    func, *args, **kwargs)
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/ray/actor.py", line 118, in remote
    return self._remote(args, kwargs)
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/ray/util/tracing/tracing_helper.py", line 408, in _start_span
    return method(self, args, kwargs, *_args, **_kwargs)
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/ray/actor.py", line 160, in _remote
    return invocation(args, kwargs)
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/ray/actor.py", line 154, in invocation
    num_returns=num_returns)
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/ray/actor.py", line 916, in _actor_method_call
    list_args, name, num_returns, self._ray_actor_method_cpus)
  File "python/ray/_raylet.pyx", line 1525, in ray._raylet.CoreWorker.submit_actor_task
  File "python/ray/_raylet.pyx", line 1530, in ray._raylet.CoreWorker.submit_actor_task
  File "python/ray/_raylet.pyx", line 351, in ray._raylet.prepare_args
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/ray/serialization.py", line 348, in serialize
    return self._serialize_to_msgpack(value)
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/ray/serialization.py", line 328, in _serialize_to_msgpack
    self._serialize_to_pickle5(metadata, python_objects)
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/ray/serialization.py", line 288, in _serialize_to_pickle5
    raise e
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/ray/serialization.py", line 285, in _serialize_to_pickle5
    value, protocol=5, buffer_callback=writer.buffer_callback)
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/ray/cloudpickle/cloudpickle_fast.py", line 73, in dumps
    cp.dump(obj)
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/ray/cloudpickle/cloudpickle_fast.py", line 580, in dump
    return Pickler.dump(self, obj)
TypeError: can't pickle SSLContext objects
```
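The last frames show the failure inside Ray's serializer (`serialization.py` calling `cloudpickle_fast.dumps`) while `submit_actor_task` prepares its arguments. Because `start_training` forwards the training function in those arguments, the most likely reading is that something reachable from the training function holds an SSL context. The underlying limitation reproduces without Ray, because an `ssl.SSLContext` cannot be pickled at all:

```python
import pickle
import ssl

ctx = ssl.create_default_context()

try:
    pickle.dumps(ctx)
except TypeError as exc:
    # Python 3.7 reports "can't pickle SSLContext objects"; newer Python
    # versions word the same TypeError slightly differently.
    print(f"TypeError: {exc}")
```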

Does this sound familiar to anyone? How can I use Ray to distribute PyTorch training code that already works well on one GPU?
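A common way around this, assuming the offending object is a client or context created outside the training function, is to construct it inside the function so each worker builds its own copy. When the object has to live on something that gets shipped to the workers, another option is to keep it out of the pickled state and rebuild it lazily. `LazySSLContext` below is a hypothetical, stdlib-only illustration of that pattern, not part of Ray:

```python
import pickle
import ssl


class LazySSLContext:
    """Hypothetical wrapper: pickles only its settings, never the SSLContext."""

    def __init__(self, cafile=None):
        self.cafile = cafile  # plain, picklable configuration
        self._ctx = None      # built on first use, once per process

    @property
    def ctx(self):
        # Create the SSLContext lazily in whichever process touches it first.
        if self._ctx is None:
            self._ctx = ssl.create_default_context(cafile=self.cafile)
        return self._ctx

    def __getstate__(self):
        # Copy the instance state but drop the unpicklable SSLContext.
        state = self.__dict__.copy()
        state["_ctx"] = None
        return state


resource = LazySSLContext()
resource.ctx  # a context now exists in this process
clone = pickle.loads(pickle.dumps(resource))  # succeeds: the context is not pickled
clone.ctx     # the copy builds its own context on demand
```

The same idea applies to whatever object the training function actually captures: ship only picklable configuration (paths, bucket names, credentials settings) and create the live connection on the worker.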