Handle "CUDA out of memory" exception in a Ray Serve replica

How severely does this issue affect your experience of using Ray?

  • Medium: It makes my task significantly harder to complete, but I can work around it.

I’m serving an AI model on a small Ray cluster that receives images of arbitrary sizes. For speed, it should try to run inference on the GPU, but if an image is too large it has to fall back to the CPU. Normally I catch a “CUDA out of memory” exception with a simple try/except in Python, but that doesn’t seem to work in a Ray Serve replica serving HTTP requests. Once the CUDA OOM exception is thrown, the replica doesn’t continue; it stops the task with the following output:
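The GPU-first, CPU-fallback pattern described above can be sketched as a small framework-agnostic helper. This is a minimal sketch, not the poster's code: `run_with_fallback` and `is_cuda_oom` are hypothetical names, and the OOM check assumes (as the code later in this thread does) that PyTorch's message begins with "CUDA out of memory." Recent PyTorch releases also raise `torch.cuda.OutOfMemoryError`, a `RuntimeError` subclass, which is a sturdier check where available.

```python
from typing import Callable, TypeVar

T = TypeVar("T")


def is_cuda_oom(exc: BaseException) -> bool:
    """Heuristic OOM check: a RuntimeError whose message starts like PyTorch's."""
    return isinstance(exc, RuntimeError) and str(exc).startswith("CUDA out of memory.")


def run_with_fallback(
    primary: Callable[[], T],
    fallback: Callable[[], T],
    should_fallback: Callable[[BaseException], bool] = is_cuda_oom,
) -> T:
    """Return primary(); if it fails with an error accepted by should_fallback,
    return fallback() instead. Any other error propagates unchanged."""
    try:
        return primary()
    except Exception as exc:
        if not should_fallback(exc):
            raise  # a genuine failure: do not hide it
    # Only reached after a recoverable failure. Running the fallback outside the
    # except block lets Python release the first exception and its traceback.
    return fallback()


if __name__ == "__main__":
    def gpu_attempt() -> str:
        # Stand-in for a GPU forward pass that runs out of memory.
        raise RuntimeError("CUDA out of memory. Tried to allocate 116.00 MiB")

    print(run_with_fallback(gpu_attempt, lambda: "cpu result"))  # prints: cpu result
```

With PyTorch, `primary` might be something like `lambda: model(x.cuda())` and `fallback` the same call on CPU tensors with a CPU-resident model (`model` and `x` are hypothetical). Any non-OOM `RuntimeError` is re-raised, so real bugs still surface.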

replica.py:510 - HANDLE __call__ ERROR 2384.0ms

future: <Task finished coro=<_wrap_awaitable() done, defined at C:\Python37\lib\asyncio\tasks.py:623> exception=RayTaskError(RuntimeError)(RuntimeError...
RuntimeError: CUDA out of memory. Tried to allocate 116.00 MiB (GPU

Hi @bananajoe182, sorry you’re running into this. Could you share more details about the try/except code: what does it look like, and where are you calling it, both in the case where it works and in the case where it doesn’t? It’s interesting that it works for you in ordinary Python but not in a Ray Serve replica.

Hey @architkulkarni,
so this code handles the runtime exception ‘CUDA out of memory…’ as it should (it continues on the CPU) when run locally, but somehow aborts the whole process when the exception is thrown in a Serve replica.

import numpy as np
import ray
import utils
import torch
from starlette.requests import Request

from ray import serve
from ray.exceptions import RayTaskError

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

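# Lama.bind() below requires the class to be a Serve deployment.
# To reserve a GPU for the replica, pass ray_actor_options={"num_gpus": 1}.
@serve.deployment
class Lama:
    def __init__(self):
        model_path = 'big-lama.pt'
        # map_location puts the weights on the same device as the input tensors.
        self.model = torch.jit.load(model_path, map_location=device)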

    async def __call__(self, http_request: Request) -> bytes:
        form = await http_request.form()
        image_stream: bytes = await form['file'].read()
        info_stream: bytes = await form['info'].read()

        # np.frombuffer replaces the deprecated binary mode of np.fromstring.
        img = np.frombuffer(image_stream, dtype='<f4')

        info = np.frombuffer(info_stream, count=3, dtype=np.intc)
        width, height, channels = info[0], info[1], info[2]

        img = np.reshape(img, (channels, height, width))
        img = np.transpose(img, (1, 2, 0))
        img = np.flipud(img)

        rgb = img[:, :, :3].transpose(2, 0, 1)
        mask = img[np.newaxis, :, :, 3]

        rgb = utils.pad_img_to_modulo(rgb, 8)
        mask = utils.pad_img_to_modulo(mask, 8)

        rgb_tensor = torch.from_numpy(rgb).to(device).unsqueeze(0)
        mask_tensor = torch.from_numpy(mask).to(device).unsqueeze(0)
        mask_tensor = (mask_tensor > 0) * 1

        # DO image processing here
        result = self.infer(rgb_tensor, mask_tensor)
        img = result[0, :, :height, :width].permute(1, 2, 0).detach().cpu().numpy()
        img = np.flipud(img)
        img = np.transpose(img, (2, 0, 1))
        output_file = img.tobytes()

        rgb_tensor = None
        mask_tensor = None
        result = None

        return output_file

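    def infer(self, rgb_tensor, mask_tensor):
        try:
            result = self.model(rgb_tensor, mask_tensor)
        except RuntimeError as e:
            if str(e).startswith('CUDA out of memory.'):
                print('Using cpu...')
                torch.cuda.empty_cache()  # release cached blocks left by the failed attempt
                cpu = torch.device('cpu')
                # The model must be on the same device as its inputs: move it to
                # the CPU for this request, then move it back afterwards.
                model_device = next(self.model.parameters()).device
                self.model.to(cpu)
                try:
                    result = self.model(rgb_tensor.to(cpu), mask_tensor.to(cpu))
                finally:
                    self.model.to(model_device)
            else:
                raise  # any other RuntimeError is a real failure

        return result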

lama = Lama.bind()

Thanks for sharing the error; that does seem pretty bizarre. You’re catching the exception within the infer function, so I don’t see how it could propagate out of the replica. Is there more of the traceback that shows exactly where the exception is being raised? (Perhaps somewhere in the actor logs?)

I also wonder whether you can reproduce this with a minimal example that takes CUDA out of the picture (e.g. a Serve replica that does nothing except manually raise an exception inside __call__ and catch it with try/except).
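Along the lines of the minimal example suggested above, a replica that only raises and catches an exception might look like this. It is a hedged sketch, not code from the thread: the class name `RaiseAndCatch` and the simulated message are hypothetical, and it assumes a Ray 2.x Serve API where `serve.deployment` can wrap a class and `bind()` builds an application without starting a cluster.

```python
import asyncio

try:
    from ray import serve
except ImportError:  # lets the handler logic be exercised without Ray installed
    serve = None


class RaiseAndCatch:
    """Replica that raises a RuntimeError inside __call__ and catches it itself."""

    async def __call__(self, request=None) -> str:
        try:
            # Simulated error; no GPU or PyTorch involved.
            raise RuntimeError("CUDA out of memory. (simulated)")
        except RuntimeError as e:
            # If this string reaches the HTTP client, try/except works in a replica.
            return f"caught: {e}"


# Application object for `serve run <module>:app`; bind() only builds the graph.
app = serve.deployment(RaiseAndCatch).bind() if serve is not None else None

if __name__ == "__main__":
    # Local check of the handler without Serve.
    print(asyncio.run(RaiseAndCatch()()))
```

Saved as, say, `repro.py` (a hypothetical file name), it can be started with `serve run repro:app` and queried at Serve's default HTTP address, http://127.0.0.1:8000/. A `caught: ...` response would show that try/except behaves normally inside a replica, pointing the investigation back at the model code.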