Proper cleanup of a long-running Ray actor (Celery-like setup)

Hello to the whole Ray team! First of all, you rock, this is an incredible piece of software!

I am migrating from FastAPI + RabbitMQ + Celery to FastAPI + RabbitMQ + Ray. I am almost there, but I hit a snag at the very end and I hope you can help me out. I think this could be super beneficial for others who want to use the power of Ray for more general applications (outside of ML)!

My setup

I have a FastAPI service that can send thousands of tasks (mainly files to ETL, embed and index). Those tasks should be persisted, so I went with RabbitMQ.
So the flow is:
FastAPI → RabbitMQ → Workers
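
For context, here is roughly what the publishing side looks like. This is a simplified sketch: IndexFileTask stands in for one of my real Pydantic task models, and in the real app the RabbitMQ connection is reused instead of being opened per request.

import aio_pika
from fastapi import FastAPI

app = FastAPI()

@app.post("/tasks")
async def enqueue_task(task: IndexFileTask):
    # Publish the task as a persistent JSON message on the queue the workers poll
    connection = await aio_pika.connect_robust(RAY_RABBITMQ_ADDRESS)
    async with connection:
        channel = await connection.channel()
        await channel.default_exchange.publish(
            aio_pika.Message(
                body=task.model_dump_json().encode(),
                delivery_mode=aio_pika.DeliveryMode.PERSISTENT,  # survive a broker restart
            ),
            routing_key=str(IndexFileTask.queue_name()),
        )
    return {"task_id": task.task_id}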

Workers should then process tasks with really high async concurrency… so I created this worker on Ray :slight_smile:

T = TypeVar("T", bound=Task)

class AbstractWorkerDB(ABC, Generic[T]):
    """
    Abstract base worker class that processes individual tasks
    You need to define the task_class attribute in the concrete implementation
    """

    task_class: Type[T]

    def __init__(self, worker_name: str):
        self.name = worker_name
        self.running = True

        if not jpype.isJVMStarted():
            jpype.startJVM()

        atexit.register(self.__del__)

    async def _async_init(self):
        """Internal async initialization method"""
        self.running = True
        self.logfire = logfire.configure(
            service_name=f"worker_{self.name}",
            scrubbing=logfire.ScrubbingOptions(
                extra_patterns=extra_patterns,
                callback=scrubbing_callback,
            ),
        )
        with self.logfire.span(f"initialize_worker_{self.name}_resources"):

            self._task_adapter = TypeAdapter(self.task_class)

            # Create semaphore to limit the number of concurrent tasks
            self.semaphore = asyncio.Semaphore(RAY_WORKER_CONCURENT_TASKS)

            # RabbitMQ connection for each worker
            self.logfire.info(
                f"Connecting to RabbitMQ at {RAY_RABBITMQ_ADDRESS} and using queue {self.task_class.queue_name()}"
            )
            self.connection = await aio_pika.connect_robust(RAY_RABBITMQ_ADDRESS)
            self.channel = await self.connection.channel()
            await self.channel.set_qos(prefetch_count=RAY_WORKER_CONCURENT_TASKS)
            self.queue = await self.channel.declare_queue(
                str(self.task_class.queue_name()), durable=True
            )

            # Declare dead letter queue for failed tasks
            self.dead_letter_queue = await self.channel.declare_queue(
                FAILED_TASKS_QUEUE_NAME, durable=True
            )

            # set up DB access
            self.db_engine: AsyncEngine = create_async_engine(
                DATABASE_URL,
                pool_size=DB_SESSION_POOL_SIZE,
                max_overflow=DB_SESSION_MAX_OVERFLOW,
                pool_timeout=DB_SESSION_POOL_TIMEOUT,
            )
            self.db_session: async_sessionmaker[AsyncSession] = async_sessionmaker(
                self.db_engine, expire_on_commit=False, class_=AsyncSession
            )
            self.logfire.info("Database pool initialized for worker")

            # TEST
            import jpype

            if not jpype.isJVMStarted():
                jpype.startJVM()

            self.logfire.info(f"Worker {self.name} initialized")

    @abstractmethod
    async def process_task_logic(self, task: T) -> None:
        """
        Abstract method that should be implemented by concrete workers
        to define specific task processing logic
        """
        pass

    async def process_task(self, task: T) -> None:
        """Process a single message with semaphore control and task lifecycle management"""
        task_id = task.task_id
        try:
            async with self.semaphore:
                # Update task status and start time
                task.status = TaskStatus.RUNNING
                task.started_at = datetime.datetime.now()

                # Execute concrete implementation
                await self.process_task_logic(task)

                # Update task completion
                task.status = TaskStatus.COMPLETED
                task.completed_at = datetime.datetime.now()

        except Exception as e:
            error_message = str(e)
            self.logfire.error(
                f"{task_id} : Task failed: {error_message} and send to failed queue",
                exc_info=True,
            )

            task.status = TaskStatus.FAILED
            task.completed_at = datetime.datetime.now()

            failure_data = FailedTask(
                task_id=task_id,
                original_message=task,
                error=error_message,
            )

            await self.channel.default_exchange.publish(
                aio_pika.Message(
                    body=failure_data.model_dump_json().encode(),
                    delivery_mode=aio_pika.DeliveryMode.PERSISTENT,
                ),
                routing_key=FAILED_TASKS_QUEUE_NAME,
            )

    async def run(self):
        """Main worker loop"""
        self.active_tasks = set()
        await self._async_init()
        try:
            async with self.queue.iterator() as queue_iter:
                while True:
                    if not self.running:
                        break
                    try:
                        message = await queue_iter.__anext__()
                    except StopAsyncIteration:
                        break

                    async with message.process():
                        try:
                            task = self._task_adapter.validate_json(
                                message.body.decode()
                            )
                            task_obj = asyncio.create_task(self.process_task(task))
                            self.active_tasks.add(task_obj)
                            task_obj.add_done_callback(self.active_tasks.discard)
                        except Exception as e:
                            self.logfire.error(
                                f"Error processing message: {str(e)}", exc_info=True
                            )
        finally:
            with self.logfire.span(f"worker_{self.name}_run_thread"):
                self.logfire.info(f"Worker {self.name} run thread finished")

    async def shutdown_worker(self):
        """Wait for all tasks to complete and cleanup the worker, you need to call actor.__ray_terminate__.remote()"""
        print("shutdown_worker")
        with self.logfire.span(f"shutdown_worker_{self.name}"):
            self.running = False
            if self.active_tasks:
                self.logfire.info(
                    f"Waiting for {len(self.active_tasks)} tasks to complete"
                )
                try:
                    await asyncio.gather(*self.active_tasks, return_exceptions=True)
                except Exception as e:
                    self.logfire.error(
                        f"Error during task cleanup: {str(e)}", exc_info=True
                    )

            if self.connection:
                await self.connection.close()
                self.logfire.info("RabbitMQ connection disposed for worker")
            if self.db_engine:
                await self.db_engine.dispose()
                self.logfire.info("Database pool disposed for worker")
        return

    def __del__(self):
        if jpype.isJVMStarted():
            jpype.shutdownJVM()
            self.logfire.info("JVM disposed for worker")
        self.logfire.info(f"Worker {self.name} shutdown")

This worker can process up to 30 tasks concurrently (async) and polls the RabbitMQ task queue. Most important: it creates an asyncpg DB pool and cleans up those connections at the end. This is why I went with all of this; I don't want to recreate a connection to the DB (or a Logfire connection) each time I send a task to the Ray cluster.
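
For completeness, a concrete worker is just a subclass of this base class declared as a Ray actor, roughly like the sketch below. MyTask and the body of process_task_logic are placeholders for my real task models and ETL/embedding/indexing logic.

import ray
from sqlalchemy import text

@ray.remote(num_cpus=1)
class TestWorker(AbstractWorkerDB[MyTask]):
    task_class = MyTask

    async def process_task_logic(self, task: MyTask) -> None:
        # Reuse the pooled connections created in _async_init instead of reconnecting per task
        async with self.db_session() as session:
            await session.execute(text("SELECT 1"))  # placeholder for the real work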

Then I needed a way to populate the cluster with those workers.
This is my FastAPI app:


async def cleanup():
    logfire.info(f"Cleanup triggered for FastAPI")
    logfire.info("Starting Ray workers cleanup...")
    try:
        # First shut down all workers
        logfire.info(f"Shutting down workers : {len(workers)} : {workers}")
        for worker in workers:
            await worker.shutdown_worker.remote()  # type: ignore
        for worker in workers:
            worker.__ray_terminate__.remote()  # type: ignore
        logfire.info("All workers cleaned up")

        await cleanup_db()

        # Move Ray shutdown to the main thread using a synchronous call
        if ray.is_initialized():
            logfire.info("Shutting down Ray...")
            ray.shutdown()
            logfire.info("Ray shutdown complete")

        # Shut down the JVM
        if jpype.isJVMStarted():
            jpype.shutdownJVM()

    except Exception as e:
        logfire.error(f"Error during cleanup: {str(e)}")
        raise

@asynccontextmanager
async def lifespan(app: FastAPI):

    # Start the shutdown monitor in a separate thread
    shutdown_thread = threading.Thread(target=listen_to_shutdown_pipe, daemon=True)
    shutdown_thread.start()

    # Rest of startup code
    await init_db()
    setup_retriever()
    await init_ray()

    # create workers
    global workers
    workers = [TestWorker.remote(f"worker_{i}") for i in range(2)]
    for worker in workers:  
        worker.run.remote()  # type: ignore

    try:
        yield
    finally:
        await cleanup()

So far so good!

The problems

For now I am using Docker, deploying the Ray cluster and the uvicorn workers for the FastAPI app in the same container.

The issue is: when I Ctrl+C, docker compose down, or reload uvicorn, a SIGTERM is sent to the uvicorn master process, the Ray worker on the same CPU core catches it too, and I get this:

backend-dev-1           | *** SIGTERM received at time=1737543906 on cpu 1 ***
backend-dev-1           | PC: @     0x7f043882aea6  (unknown)  epoll_wait
backend-dev-1           |     @     0x7f043875e050  (unknown)  (unknown)
backend-dev-1           | [2025-01-22 11:05:06,608 E 2484 2484] logging.cc:447: *** SIGTERM received at time=1737543906 on cpu 1 ***
backend-dev-1           | [2025-01-22 11:05:06,608 E 2484 2484] logging.cc:447: PC: @     0x7f043882aea6  (unknown)  epoll_wait
backend-dev-1           | [2025-01-22 11:05:06,608 E 2484 2484] logging.cc:447:     @     0x7f043875e050  (unknown)  (unknown)
backend-dev-1           | 11:05:06.612 Cleanup triggered for FastAPI

This kills my loop, and when FastAPI then tries to clean up the worker, it crashes because the worker is already dead.

I also tried to put an atexit hook on the Ray actor to launch the cleanup, but this does not work: the SIGTERM is delivered to the underlying Ray worker process, so it never reaches the actor…
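
Concretely, the attempt looked roughly like this (reconstructed from memory, so take it as a sketch): the hook is registered in the actor's constructor, but on SIGTERM the signal is handled by the Ray core worker process, so Python's atexit hooks never get a chance to run.

import atexit
import asyncio
import ray

@ray.remote(num_cpus=1)
class TestWorkerWithAtexit(AbstractWorkerDB[MyTask]):
    task_class = MyTask

    def __init__(self, worker_name: str):
        super().__init__(worker_name)
        # Never fires on SIGTERM: the Ray core worker handles the signal itself
        atexit.register(lambda: asyncio.run(self.shutdown_worker()))

    async def process_task_logic(self, task: MyTask) -> None:
        ...  # same ETL / indexing logic as before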

How you can help me

3 questions:

  • Am I doing this the right way? Is there an easier way of doing what I want (a pool of very long-running workers polling from RabbitMQ)?
  • Is there a way of properly cleaning up the app?
    • Making Ray not catch the SIGTERM, and letting uvicorn properly shut down Ray? (rough sketch of the uvicorn side after this list)
    • Making Ray properly call the async atexit hook when the underlying worker receives SIGTERM?
    • Running FastAPI on Ray, and when the SIGTERM is caught, properly running the cleanup in the lifespan function?
  • I also looked at Ray Serve with FastAPI, but I think this will lead to the same problem: when the container exits or a SIGTERM is sent, the app won't be properly cleaned up (my workers won't be called on shutdown and thus the connections won't be closed…)
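
To make the uvicorn-side idea a bit more concrete, this is roughly what I have in mind: intercept SIGTERM in the FastAPI process myself and run the same cleanup() before the process exits. It is an untested sketch, I am not sure how it interacts with uvicorn's own SIGTERM handling, and it does not stop the Ray worker processes from receiving the signal as well.

import asyncio
import signal
from contextlib import asynccontextmanager
from fastapi import FastAPI

_cleanup_started = False

async def cleanup_once():
    # Guard so cleanup() runs only once, whether triggered by SIGTERM or by lifespan shutdown
    global _cleanup_started
    if _cleanup_started:
        return
    _cleanup_started = True
    await cleanup()

@asynccontextmanager
async def lifespan(app: FastAPI):
    loop = asyncio.get_running_loop()
    # Run our own cleanup as soon as SIGTERM arrives, before everything is torn down
    loop.add_signal_handler(signal.SIGTERM, lambda: asyncio.create_task(cleanup_once()))

    # ... same startup and worker creation as in the lifespan above ...

    try:
        yield
    finally:
        await cleanup_once()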

Thank you very much for your help!!
I am sure this setup could help others who are also interested in going from Celery to Ray! The final setup will be open sourced so others can use it :slight_smile: