Ray is not compatible with Numba CUDA when the kernel is imported from a module

The same issue is reported here:

I want to combine Ray and Numba CUDA for parallel computing. However, Ray does not work properly with Numba.

If the Ray tasks and the Numba kernel are defined in the same notebook, everything works properly:

import ray
from numba import cuda
from numba.cuda.random import create_xoroshiro128p_states, xoroshiro128p_uniform_float32
import numpy as np
import math
import os

# Stop any running Ray instance, then start a fresh one that declares 5 GPU
# resources for task scheduling (each task below requests num_gpus=1).
ray.shutdown()
ray.init(ignore_reinit_error=True,num_gpus=5)

import ray
from numba import cuda
from numba.cuda.random import create_xoroshiro128p_states, xoroshiro128p_uniform_float32
import numpy as np
import math
import os
import cupy as cp

# Ray task that defines, compiles and launches a Numba CUDA kernel. The kernel
# is a closure defined *inside* the task, so Numba compiles it in the Ray
# worker process that executes f1.
@ray.remote(num_gpus=1)
def f1():
    """Launch a placeholder CUDA kernel on one Ray-assigned GPU.

    Returns:
        The host copy of ``rng_states``, a device array of
        ``threads_per_block * blocks`` elements created with
        ``cuda.device_array``.
    """

    @cuda.jit
    def compute_pi():
        """Placeholder kernel: loops without writing any output.

        The name and the comment below describe a Monte Carlo estimate of pi,
        but the body only assigns constants to locals and stores nothing.
        """
        thread_id = cuda.grid(1)

        # Intended: compute pi by drawing random (x, y) points and finding
        # what fraction lie inside a unit circle. As written, the loop only
        # assigns constants and never updates `inside`.
        inside = 0
        for i in range(1000000):
            x = 1
            y = 1

    threads_per_block = 64
    blocks = 24
    # NOTE(review): cuda.device_array presumably allocates without
    # initializing (like numpy.empty), and the kernel never writes to this
    # array, so the returned values are arbitrary — confirm. The name is a
    # leftover from the xoroshiro RNG example; no RNG state is created here.
    rng_states = cuda.device_array(threads_per_block * blocks)

    # Launch `blocks` blocks of `threads_per_block` threads each.
    compute_pi[blocks, threads_per_block]()

    
    return rng_states.copy_to_host()

# CuPy task used as a comparison: it does not use Numba at all.
@ray.remote(num_gpus=1)
def ff():
    """Create a small CuPy array on the GPU, print a marker and return it.

    Returns:
        A CuPy array holding ``[1, 2, 3, 4, 5]``.
    """
    # NOTE(review): Ray typically sets CUDA_VISIBLE_DEVICES itself for tasks
    # that request num_gpus; overriding it here pins the task to GPU 4 (in
    # PCI bus order) regardless of Ray's assignment. It presumably only takes
    # effect if CUDA has not yet been initialized in this worker — confirm.
    os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"]="4"
    
    x_on_gpu1 = cp.array([1, 2, 3, 4, 5])
    
    print('hi')
    return x_on_gpu1

# Submit each task to a Ray worker and block until its result is available.
print(ray.get(ff.remote()))

print(ray.get(f1.remote()))

However, if I move f1 and ff into a file named f.py:

# The following is in a f.py file
import ray
from numba import cuda
from numba.cuda.random import create_xoroshiro128p_states, xoroshiro128p_uniform_float32
import numpy as np
import math
import os
import cupy as cp

@ray.remote(num_gpus=1)
def f1():
    """Launch a placeholder CUDA kernel on one Ray-assigned GPU.

    When this function is imported from ``f`` by the driver and run in a Ray
    worker, compiling the nested ``compute_pi`` fails with
    ``ModuleNotFoundError: ... import of module f failed``. Numba looks up the
    kernel's module (``f``) in the worker's ``sys.modules`` and does not find
    it (``KeyError: 'f'`` in the traceback).

    Returns:
        The host copy of ``rng_states``, a device array of
        ``threads_per_block * blocks`` elements created with
        ``cuda.device_array``.
    """

    @cuda.jit
    def compute_pi():
        """Placeholder kernel: loops without writing any output.

        The name and the comment below describe a Monte Carlo estimate of pi,
        but the body only assigns constants to locals and stores nothing.
        """
        thread_id = cuda.grid(1)

        # Intended: compute pi by drawing random (x, y) points and finding
        # what fraction lie inside a unit circle. As written, the loop only
        # assigns constants and never updates `inside`.
        inside = 0
        for i in range(1000000):
            x = 1
            y = 1

    threads_per_block = 64
    blocks = 24
    # NOTE(review): cuda.device_array presumably allocates without
    # initializing (like numpy.empty), and the kernel never writes to this
    # array, so the returned values are arbitrary — confirm.
    rng_states = cuda.device_array(threads_per_block * blocks)

    # Launch `blocks` blocks of `threads_per_block` threads each; the first
    # launch triggers compilation, which is where the import error surfaces.
    compute_pi[blocks, threads_per_block]()

    
    return rng_states.copy_to_host()

@ray.remote(num_gpus=1)
def ff():
    """CuPy-only comparison task; works even when imported from ``f``.

    Returns:
        A CuPy array holding ``[1, 2, 3, 4, 5]``.
    """
    # NOTE(review): Ray typically sets CUDA_VISIBLE_DEVICES itself for tasks
    # that request num_gpus; overriding it here pins the task to GPU 4 (in
    # PCI bus order) regardless of Ray's assignment. It presumably only takes
    # effect if CUDA has not yet been initialized in this worker — confirm.
    os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"]="4"
    
    x_on_gpu1 = cp.array([1, 2, 3, 4, 5])
    
    print('hi')
    return x_on_gpu1

and then import them into the notebook, the code fails:

import ray
from numba import cuda
from numba.cuda.random import create_xoroshiro128p_states, xoroshiro128p_uniform_float32
import numpy as np
import math
import os

# Stop any running Ray instance, then start a fresh one that declares 5 GPU
# resources for task scheduling (the imported tasks request num_gpus=1 each).
ray.shutdown()
ray.init(ignore_reinit_error=True,num_gpus=5)

import ray
from numba import cuda
from numba.cuda.random import create_xoroshiro128p_states, xoroshiro128p_uniform_float32
import numpy as np
import math
import os
import cupy as cp

from f import ff,f1


# ff (CuPy only) succeeds; f1 fails in the worker while Numba compiles its
# nested kernel, and ray.get re-raises that failure here as a RayTaskError.
print(ray.get(ff.remote()))

print(ray.get(f1.remote()))

The first call, ff.remote(), works, but the second, f1.remote(), fails with:

RayTaskError(ModuleNotFoundError)         Traceback (most recent call last)
<ipython-input-7-d13267e4beec> in <module>
     22 print(ray.get(ff.remote()))
     23 
---> 24 print(ray.get(f1.remote()))

~/anaconda3/envs/crbmg/lib/python3.8/site-packages/ray/_private/client_mode_hook.py in wrapper(*args, **kwargs)
     45         if client_mode_enabled and _client_hook_enabled:
     46             return getattr(ray, func.__name__)(*args, **kwargs)
---> 47         return func(*args, **kwargs)
     48 
     49     return wrapper

~/anaconda3/envs/crbmg/lib/python3.8/site-packages/ray/worker.py in get(object_refs, timeout)
   1454                     worker.core_worker.dump_object_store_memory_usage()
   1455                 if isinstance(value, RayTaskError):
-> 1456                     raise value.as_instanceof_cause()
   1457                 else:
   1458                     raise value

RayTaskError(ModuleNotFoundError): ray::f1() (pid=3606402, ip=210.45.78.32)
  File "/data2/zhangjuenjie/anaconda3/envs/crbmg/lib/python3.8/weakref.py", line 131, in __getitem__
    o = self.data[key]()
KeyError: '_ZN08NumbaEnv1f2f112$3clocals$3e14compute_pi$241E'

During handling of the above exception, another exception occurred:

ray::f1() (pid=3606402, ip=210.45.78.32)
  File "/data2/zhangjuenjie/anaconda3/envs/crbmg/lib/python3.8/site-packages/numba/core/funcdesc.py", line 93, in lookup_module
    return sys.modules[self.modname]
KeyError: 'f'

During handling of the above exception, another exception occurred:

ray::f1() (pid=3606402, ip=210.45.78.32)
  File "python/ray/_raylet.pyx", line 473, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 476, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 480, in ray._raylet.execute_task
  File "/data2/zhangjuenjie/zjj/CRBMG/f.py", line 28, in f1
    compute_pi[blocks, threads_per_block]()
  File "/data2/zhangjuenjie/anaconda3/envs/crbmg/lib/python3.8/site-packages/numba/cuda/compiler.py", line 772, in __call__
    return self.dispatcher.call(args, self.griddim, self.blockdim,
  File "/data2/zhangjuenjie/anaconda3/envs/crbmg/lib/python3.8/site-packages/numba/cuda/compiler.py", line 878, in call
    kernel = self.compile(argtypes)
  File "/data2/zhangjuenjie/anaconda3/envs/crbmg/lib/python3.8/site-packages/numba/cuda/compiler.py", line 947, in compile
    kernel = compile_kernel(self.py_func, argtypes,
  File "/data2/zhangjuenjie/anaconda3/envs/crbmg/lib/python3.8/site-packages/numba/core/compiler_lock.py", line 32, in _acquire_compile_lock
    return func(*args, **kwargs)
  File "/data2/zhangjuenjie/anaconda3/envs/crbmg/lib/python3.8/site-packages/numba/cuda/compiler.py", line 57, in compile_kernel
    cres = compile_cuda(pyfunc, types.void, args, debug=debug, inline=inline)
  File "/data2/zhangjuenjie/anaconda3/envs/crbmg/lib/python3.8/site-packages/numba/core/compiler_lock.py", line 32, in _acquire_compile_lock
    return func(*args, **kwargs)
  File "/data2/zhangjuenjie/anaconda3/envs/crbmg/lib/python3.8/site-packages/numba/cuda/compiler.py", line 40, in compile_cuda
    cres = compiler.compile_extra(typingctx=typingctx,
  File "/data2/zhangjuenjie/anaconda3/envs/crbmg/lib/python3.8/site-packages/numba/core/compiler.py", line 602, in compile_extra
    return pipeline.compile_extra(func)
  File "/data2/zhangjuenjie/anaconda3/envs/crbmg/lib/python3.8/site-packages/numba/core/compiler.py", line 352, in compile_extra
    return self._compile_bytecode()
  File "/data2/zhangjuenjie/anaconda3/envs/crbmg/lib/python3.8/site-packages/numba/core/compiler.py", line 414, in _compile_bytecode
    return self._compile_core()
  File "/data2/zhangjuenjie/anaconda3/envs/crbmg/lib/python3.8/site-packages/numba/core/compiler.py", line 394, in _compile_core
    raise e
  File "/data2/zhangjuenjie/anaconda3/envs/crbmg/lib/python3.8/site-packages/numba/core/compiler.py", line 385, in _compile_core
    pm.run(self.state)
  File "/data2/zhangjuenjie/anaconda3/envs/crbmg/lib/python3.8/site-packages/numba/core/compiler_machinery.py", line 339, in run
    raise patched_exception
  File "/data2/zhangjuenjie/anaconda3/envs/crbmg/lib/python3.8/site-packages/numba/core/compiler_machinery.py", line 330, in run
    self._runPass(idx, pass_inst, state)
  File "/data2/zhangjuenjie/anaconda3/envs/crbmg/lib/python3.8/site-packages/numba/core/compiler_lock.py", line 32, in _acquire_compile_lock
    return func(*args, **kwargs)
  File "/data2/zhangjuenjie/anaconda3/envs/crbmg/lib/python3.8/site-packages/numba/core/compiler_machinery.py", line 289, in _runPass
    mutated |= check(pss.run_pass, internal_state)
  File "/data2/zhangjuenjie/anaconda3/envs/crbmg/lib/python3.8/site-packages/numba/core/compiler_machinery.py", line 262, in check
    mangled = func(compiler_state)
  File "/data2/zhangjuenjie/anaconda3/envs/crbmg/lib/python3.8/site-packages/numba/core/typed_passes.py", line 449, in run_pass
    NativeLowering().run_pass(state)
  File "/data2/zhangjuenjie/anaconda3/envs/crbmg/lib/python3.8/site-packages/numba/core/typed_passes.py", line 371, in run_pass
    lower = lowering.Lower(targetctx, library, fndesc, interp,
  File "/data2/zhangjuenjie/anaconda3/envs/crbmg/lib/python3.8/site-packages/numba/core/lowering.py", line 37, in __init__
    self.env = Environment.from_fndesc(self.fndesc)
  File "/data2/zhangjuenjie/anaconda3/envs/crbmg/lib/python3.8/site-packages/numba/core/environment.py", line 23, in from_fndesc
    inst = cls(fndesc.lookup_globals())
  File "/data2/zhangjuenjie/anaconda3/envs/crbmg/lib/python3.8/site-packages/numba/core/funcdesc.py", line 81, in lookup_globals
    return self.global_dict or self.lookup_module().__dict__
  File "/data2/zhangjuenjie/anaconda3/envs/crbmg/lib/python3.8/site-packages/numba/core/funcdesc.py", line 95, in lookup_module
    raise ModuleNotFoundError(
ModuleNotFoundError: can't compile f1.<locals>.compute_pi: import of module f failed

This is very strange, since each piece works on its own but fails when they are combined.

Here is a minimal reproducible example:

# f.py
import ray
import numba

@ray.remote
def foo():
    """Minimal reproducer: JIT-compile and call a nested function in a task.

    When this module (``f``) is imported by a driver script and ``foo`` runs
    in a Ray worker, compiling ``bar`` fails with ``ModuleNotFoundError``.
    Numba cannot find ``f`` in the worker's ``sys.modules``.
    """
    @numba.jit
    def bar():
        return 1

    # First call triggers compilation, which is where the error is raised.
    bar()
# driver script
import ray
from f import foo

if __name__ == '__main__':
    ray.init()
    # Run foo in a Ray worker; ray.get re-raises any exception from the task
    # in this process, wrapped in a RayTaskError.
    ref = foo.remote()
    ray.get(ref)

Error message:

Traceback (most recent call last):
  File "/Users/siyuan/Code/ray/python/aaa.py", line 7, in <module>
    ray.get(ref)
  File "/Users/siyuan/Code/ray/python/ray/_private/client_mode_hook.py", line 47, in wrapper
    return func(*args, **kwargs)
  File "/Users/siyuan/Code/ray/python/ray/worker.py", line 1440, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(ModuleNotFoundError): ray::foo() (pid=98015, ip=10.0.0.109)
  File "/Users/siyuan/.pyenv/versions/3.6.9/lib/python3.6/weakref.py", line 137, in __getitem__
    o = self.data[key]()
KeyError: '_ZN08NumbaEnv1f3foo12$3clocals$3e7bar$241E'

During handling of the above exception, another exception occurred:

ray::foo() (pid=98015, ip=10.0.0.109)
  File "/Users/siyuan/.pyenv/versions/3.6.9/lib/python3.6/site-packages/numba/core/funcdesc.py", line 93, in lookup_module
    return sys.modules[self.modname]
KeyError: 'f'

During handling of the above exception, another exception occurred:

ray::foo() (pid=98015, ip=10.0.0.109)
  File "python/ray/_raylet.pyx", line 484, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 491, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 501, in ray._raylet.execute_task
  File "/Users/siyuan/Code/ray/python/f.py", line 11, in foo
    bar()
  File "/Users/siyuan/.pyenv/versions/3.6.9/lib/python3.6/site-packages/numba/core/dispatcher.py", line 433, in _compile_for_args
    raise e
  File "/Users/siyuan/.pyenv/versions/3.6.9/lib/python3.6/site-packages/numba/core/dispatcher.py", line 366, in _compile_for_args
    return self.compile(tuple(argtypes))
  File "/Users/siyuan/.pyenv/versions/3.6.9/lib/python3.6/site-packages/numba/core/compiler_lock.py", line 32, in _acquire_compile_lock
    return func(*args, **kwargs)
  File "/Users/siyuan/.pyenv/versions/3.6.9/lib/python3.6/site-packages/numba/core/dispatcher.py", line 857, in compile
    cres = self._compiler.compile(args, return_type)
  File "/Users/siyuan/.pyenv/versions/3.6.9/lib/python3.6/site-packages/numba/core/dispatcher.py", line 77, in compile
    status, retval = self._compile_cached(args, return_type)
  File "/Users/siyuan/.pyenv/versions/3.6.9/lib/python3.6/site-packages/numba/core/dispatcher.py", line 91, in _compile_cached
    retval = self._compile_core(args, return_type)
  File "/Users/siyuan/.pyenv/versions/3.6.9/lib/python3.6/site-packages/numba/core/dispatcher.py", line 109, in _compile_core
    pipeline_class=self.pipeline_class)
  File "/Users/siyuan/.pyenv/versions/3.6.9/lib/python3.6/site-packages/numba/core/compiler.py", line 602, in compile_extra
    return pipeline.compile_extra(func)
  File "/Users/siyuan/.pyenv/versions/3.6.9/lib/python3.6/site-packages/numba/core/compiler.py", line 352, in compile_extra
    return self._compile_bytecode()
  File "/Users/siyuan/.pyenv/versions/3.6.9/lib/python3.6/site-packages/numba/core/compiler.py", line 414, in _compile_bytecode
    return self._compile_core()
  File "/Users/siyuan/.pyenv/versions/3.6.9/lib/python3.6/site-packages/numba/core/compiler.py", line 394, in _compile_core
    raise e
  File "/Users/siyuan/.pyenv/versions/3.6.9/lib/python3.6/site-packages/numba/core/compiler.py", line 385, in _compile_core
    pm.run(self.state)
  File "/Users/siyuan/.pyenv/versions/3.6.9/lib/python3.6/site-packages/numba/core/compiler_machinery.py", line 339, in run
    raise patched_exception
  File "/Users/siyuan/.pyenv/versions/3.6.9/lib/python3.6/site-packages/numba/core/compiler_machinery.py", line 330, in run
    self._runPass(idx, pass_inst, state)
  File "/Users/siyuan/.pyenv/versions/3.6.9/lib/python3.6/site-packages/numba/core/compiler_lock.py", line 32, in _acquire_compile_lock
    return func(*args, **kwargs)
  File "/Users/siyuan/.pyenv/versions/3.6.9/lib/python3.6/site-packages/numba/core/compiler_machinery.py", line 289, in _runPass
    mutated |= check(pss.run_pass, internal_state)
  File "/Users/siyuan/.pyenv/versions/3.6.9/lib/python3.6/site-packages/numba/core/compiler_machinery.py", line 262, in check
    mangled = func(compiler_state)
  File "/Users/siyuan/.pyenv/versions/3.6.9/lib/python3.6/site-packages/numba/core/object_mode_passes.py", line 120, in run_pass
    lowered = backend_object_mode()
  File "/Users/siyuan/.pyenv/versions/3.6.9/lib/python3.6/site-packages/numba/core/object_mode_passes.py", line 118, in backend_object_mode
    state.flags)
  File "/Users/siyuan/.pyenv/versions/3.6.9/lib/python3.6/site-packages/numba/core/object_mode_passes.py", line 77, in _py_lowering_stage
    lower = pylowering.PyLower(targetctx, library, fndesc, interp)
  File "/Users/siyuan/.pyenv/versions/3.6.9/lib/python3.6/site-packages/numba/core/lowering.py", line 37, in __init__
    self.env = Environment.from_fndesc(self.fndesc)
  File "/Users/siyuan/.pyenv/versions/3.6.9/lib/python3.6/site-packages/numba/core/environment.py", line 23, in from_fndesc
    inst = cls(fndesc.lookup_globals())
  File "/Users/siyuan/.pyenv/versions/3.6.9/lib/python3.6/site-packages/numba/core/funcdesc.py", line 81, in lookup_globals
    return self.global_dict or self.lookup_module().__dict__
  File "/Users/siyuan/.pyenv/versions/3.6.9/lib/python3.6/site-packages/numba/core/funcdesc.py", line 96, in lookup_module
    f"can't compile {self.qualname}: "
ModuleNotFoundError: can't compile foo.<locals>.bar: import of module f failed
(pid=98015) /Users/siyuan/Code/ray/python/f.py:7: NumbaWarning: 
(pid=98015) Compilation is falling back to object mode WITH looplifting enabled because Function bar failed at nopython mode lowering due to: can't compile foo.<locals>.bar: import of module f failed
(pid=98015)   @numba.jit

This is likely because Numba, when compiling `bar`, obtains its globals by looking up the module in which `bar` was defined (`f`) in `sys.modules` (see `lookup_module` in the traceback above). That works in the driver script, which imported `f`. However, `foo` is executed by a Ray worker process, not by the driver. In that process `f` is not present in `sys.modules` (`KeyError: 'f'`), so Numba fails to compile and run `bar()`.

To satisfy Numba, here is a workaround:

# f.py
import ray
import numba

import sys
# Capture this module's name and module object when the module is executed,
# so that `foo` can re-register the module in `sys.modules` of whichever
# process runs the task. `__name__` is the idiomatic spelling of
# `globals()['__name__']` at module level.
module_name = __name__
current_module = sys.modules[module_name]

@ray.remote
def foo():
    """Run the nested-``numba.jit`` reproducer with the sys.modules workaround.

    Before ``bar`` is compiled, this module is re-registered in
    ``sys.modules`` under ``module_name``. Numba's lookup of the module that
    defines ``bar`` then succeeds in the worker process.
    """
    # tell numba that it is actually running in `f.py`
    sys.modules[module_name] = current_module
    @numba.jit
    def bar():
        return 1

    bar()

With this workaround I can run the script successfully.

Here is a cleaner fix:

# f.py
import ray
import numba

def bar():
    """Return the constant 1; plain Python, JIT-compiled inside the Ray task."""
    one = 1
    return one

@ray.remote
def foo():
    """JIT-compile the module-level ``bar`` inside the task and call it.

    ``bar`` is defined at module level rather than nested in this task, and
    ``numba.jit`` is applied here, in the worker, at call time.
    """
    # NOTE(review): numba.jit(bar) builds a new dispatcher on every task call,
    # so bar is presumably recompiled each time the task runs — confirm.
    _bar = numba.jit(bar)
    _bar()

So your code can be written like this:

# The following is in a f.py file
import ray
from numba import cuda
from numba.cuda.random import create_xoroshiro128p_states, xoroshiro128p_uniform_float32
import numpy as np
import math
import os
import cupy as cp


def compute_pi():
    """Placeholder CUDA kernel body; compiled with ``cuda.jit`` inside ``f1``.

    Despite its name it computes nothing: the loop assigns constants to
    locals and no output is written.
    """
    thread_id = cuda.grid(1)

    # Intended: compute pi by drawing random (x, y) points and finding what
    # fraction lie inside a unit circle. Not implemented here; `inside` is
    # never updated.
    inside = 0
    for i in range(1000000):
        x = 1
        y = 1

@ray.remote(num_gpus=1)
def f1():
    """Compile ``compute_pi`` with ``cuda.jit`` in the worker and launch it.

    Returns:
        The host copy of ``rng_states``, a device array of
        ``threads_per_block * blocks`` elements created with
        ``cuda.device_array``.
    """
    threads_per_block = 64
    blocks = 24
    # NOTE(review): cuda.device_array presumably allocates without
    # initializing (like numpy.empty), and the kernel never writes to this
    # array, so the returned values are arbitrary — confirm.
    rng_states = cuda.device_array(threads_per_block * blocks)
    # cuda.jit is applied here to the module-level compute_pi, instead of
    # decorating a function nested inside this task.
    _compute_pi = cuda.jit(compute_pi)
    # Launch `blocks` blocks of `threads_per_block` threads each.
    _compute_pi[blocks, threads_per_block]()
    return rng_states.copy_to_host()

@ray.remote(num_gpus=1)
def ff():
    """Create a small CuPy array on the GPU, print a marker and return it.

    Returns:
        A CuPy array holding ``[1, 2, 3, 4, 5]``.
    """
    # NOTE(review): Ray typically sets CUDA_VISIBLE_DEVICES itself for tasks
    # that request num_gpus; overriding it here pins the task to GPU 4 (in
    # PCI bus order) regardless of Ray's assignment. It presumably only takes
    # effect if CUDA has not yet been initialized in this worker — confirm.
    os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"]="4"
    
    x_on_gpu1 = cp.array([1, 2, 3, 4, 5])
    
    print('hi')
    return x_on_gpu1
2 Likes

Thanks a lot, it works.