I have a remote worker that invokes a parallelized numba function, like:
import numba
import time
import ray
@numba.njit(parallel=True)
def nba_sum(x):
    r = 0
    for i in numba.prange(x.shape[0]):
        r += x[i]
    return r
@ray.remote
def sum(shared_token):
    x = <get the array view from shared token>
    nba_sum(x)  # warm-up call, so the numba function is already compiled
    t = time.time()
    nba_sum(x)
    return time.time() - t
x = <some very large array view to shared memory>
shared_token = <shared token to the memory of x>
nba_sum(x)  # warm-up call, so the local timing also excludes compilation
t = time.time()
nba_sum(x)
print('local time =', time.time() - t)
print('remote time =', ray.get(sum.remote(shared_token)))
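For context, here is a minimal sketch of how the shared-memory placeholders above could be implemented with the standard library's `multiprocessing.shared_memory`; the token layout and the helper names `make_shared_array` / `attach_shared_array` are my assumptions, not part of my actual script:

```python
import numpy as np
from multiprocessing import shared_memory

# Hypothetical implementation of the placeholders: the "token" is just
# the shared-memory block's name plus the array's shape and dtype.
def make_shared_array(shape, dtype=np.float64):
    size = int(np.prod(shape)) * np.dtype(dtype).itemsize
    shm = shared_memory.SharedMemory(create=True, size=size)
    x = np.ndarray(shape, dtype=dtype, buffer=shm.buf)
    token = (shm.name, shape, np.dtype(dtype).str)
    return shm, x, token

def attach_shared_array(token):
    name, shape, dtype_str = token
    shm = shared_memory.SharedMemory(name=name)
    # Zero-copy view over the same physical pages.
    x = np.ndarray(shape, dtype=np.dtype(dtype_str), buffer=shm.buf)
    return shm, x

shm, x, token = make_shared_array((1024,))
x[:] = 1.0
shm2, view = attach_shared_array(token)
total = float(view.sum())
shm2.close()
shm.close()
shm.unlink()
```

In the real script the attach step would run inside the Ray worker process, so both sides operate on the same physical memory without copying.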
I tested the script on a machine with 256 logical CPUs (two AMD EPYC 7763 64-core processors), and the remote time is consistently slower than the local one, by more than 20%.
Can I avoid this cost?