However, when I try this method, the data in worker_func remains an ObjectRef and throws errors when I try to process it. Here is a minimal example of what I hope to get working:
import ray
import ray.util.multiprocessing as mp
import torch

@ray.remote
def f(args_):
    idx, A_ = args_
    return A_[idx]

ray.init()
A = torch.rand(5, 5)
A_ = ray.put(A)
with mp.Pool() as p:
    args_ = zip(range(5), [A_] * 5)
    results = p.map(f.remote, args_)
print(ray.get(results))
TypeError: 'ray._raylet.ObjectRef' object is not subscriptable
I'm very new to Ray, so any pointers are appreciated. Thanks!
Hey @Qian_Huang, that error means that A_ is still an ObjectRef inside f. This happens because Ray does not dereference refs automatically unless they are passed as top-level arguments; here the ref is nested inside a tuple, so it is passed through unchanged. You can dereference it yourself with ray.get(A_):
import ray
import ray.util.multiprocessing as mp
import torch

@ray.remote
def f(args_):
    idx, A_ = args_
    A_ = ray.get(A_)  # <---- added ray.get()
    return A_[idx]

A = torch.rand(5, 5)
A_ = ray.put(A)
with mp.Pool() as p:
    args_ = zip(range(5), [A_] * 5)
    results = p.map(f.remote, args_)
print(ray.get(results))
Another thing I am trying to figure out is how to use tqdm with this to show progress. With a normal pool I can use imap, but imap seems to behave differently in Ray: Ray seems to just send the entire args_ to f:
from tqdm import tqdm
import ray
import ray.util.multiprocessing as mp
import torch
import time

def f(args):
    idx, A_ = args
    time.sleep(idx)
    return ray.get(A_)[idx]

ray.init()
A = torch.rand(5, 5)
A_ = ray.put(A)
pool = mp.Pool()
args_ = zip(range(5), [A_] * 5)
for result in tqdm(pool.imap(f, args_), total=5):
    print(result)
Huh, that's strange. Possibly a bug in ray.util.multiprocessing (cc @eoakes). As a workaround, you could avoid passing A_ as an argument and instead let f capture it as a global variable:
from tqdm import tqdm
import ray
import ray.util.multiprocessing as mp
import torch
import time

ray.init()
A = torch.rand(5, 5)
A_ = ray.put(A)

def f(idx):
    time.sleep(idx)
    return ray.get(A_)[idx]

pool = mp.Pool()
args_ = range(5)
for result in tqdm(pool.imap(f, args_), total=5):
    print(result)