Using Ray as a replacement for Celery (generic task executor)

Hello to all! I will be presenting my solution to replace Celery with Ray workers, along with some of our code. Comments and discussions are welcome!

Our objective

We have a lot of async (I/O) intensive work, but with some CPU usage as well, mainly:

  • ETL jobs on files: taking a PDF, running OCR, chunking, embedding and indexing, you know… classic
  • User threads or workflow runs
  • Ultra high I/O-intensive work, mainly LLM API calls

Right now we use Celery, but we have these problems:

  • Not good at handling async I/O jobs
  • Can’t use GPUs
  • You can’t specify CPU requirements
  • Not great at scaling

But we like the queuing with RabbitMQ (persistent and reliable).

Our use of Ray

After some drawbacks with Ray Core only ([CORE] actor atexit is not called when SIGTERM receive · Issue #50004 · ray-project/ray), where I couldn’t make it work with proper cleanup, we went with Serve + Core.

So our API receives a user file, saves it to S3, and sends a task with the file ID to RabbitMQ.
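
Conceptually, the submit side is just a persistent publish to the task’s queue. A minimal sketch of what that looks like with aio_pika (simplified, not our exact task_manager code; RAY_RABBITMQ_ADDRESS is the same setting the workers use):


import aio_pika


async def submit_task(task) -> str:
    """Serialize a task and push it to its RabbitMQ queue (simplified sketch)."""
    connection = await aio_pika.connect_robust(RAY_RABBITMQ_ADDRESS)
    async with connection:
        channel = await connection.channel()
        # Durable queue so tasks survive a broker restart
        await channel.declare_queue(task.queue_name(), durable=True)
        await channel.default_exchange.publish(
            aio_pika.Message(
                body=task.model_dump_json().encode(),
                delivery_mode=aio_pika.DeliveryMode.PERSISTENT,
            ),
            routing_key=task.queue_name(),
        )
    return task.task_id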

Then we have a pool of workers that poll for jobs and process them. We created 2 types of workers:

  • High I/O-bound workers (mainly for LLM calls)
  • Mixed async and CPU workload workers

For example, for the mixed workload the worker looks something like this:


import asyncio
import atexit
import datetime
from abc import ABC, abstractmethod
from typing import Generic, Type, TypeVar

import aio_pika
import jpype
import logfire
from pydantic import TypeAdapter
from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)

# Project-specific constants and models (RAY_RABBITMQ_ADDRESS, queue names,
# DB settings, Logfire scrubbing config, TaskStatus, FailedTask, ...) are
# imported from our own modules and omitted here.

T = TypeVar("T")


class AbstractWorkerDB(ABC, Generic[T]):
    """
    Abstract base worker class that processes individual tasks
    You need to define the task_class attribute in the concrete implementation
    """

    task_class: Type[T]

    def __init__(self, worker_name: str):
        self.name = worker_name
        self.running = True

        if not jpype.isJVMStarted():
            jpype.startJVM()

        atexit.register(self.__del__)

    async def _async_init(self):
        """Internal async initialization method"""
        self.running = True
        self.logfire = logfire.configure(
            service_name=f"worker_{self.name}",
            scrubbing=logfire.ScrubbingOptions(
                extra_patterns=extra_patterns,
                callback=scrubbing_callback,
            ),
        )
        with self.logfire.span(f"initialize_worker_{self.name}_resources"):

            self._task_adapter = TypeAdapter(self.task_class)

            # Create semaphore to limit the number of concurrent tasks
            self.semaphore = asyncio.Semaphore(RAY_WORKER_CONCURENT_TASKS)

            # RabbitMQ connection for each worker
            self.logfire.info(
                f"Connecting to RabbitMQ at {RAY_RABBITMQ_ADDRESS} and using queue {self.task_class.queue_name()}"
            )
            self.connection = await aio_pika.connect_robust(RAY_RABBITMQ_ADDRESS)
            self.channel = await self.connection.channel()
            await self.channel.set_qos(prefetch_count=RAY_WORKER_CONCURENT_TASKS)
            self.queue = await self.channel.declare_queue(
                str(self.task_class.queue_name()), durable=True
            )

            # Declare dead letter queue for failed tasks
            self.dead_letter_queue = await self.channel.declare_queue(
                FAILED_TASKS_QUEUE_NAME, durable=True
            )

            # set up DB access
            self.db_engine: AsyncEngine = create_async_engine(
                DATABASE_URL,
                pool_size=DB_SESSION_POOL_SIZE,
                max_overflow=DB_SESSION_MAX_OVERFLOW,
                pool_timeout=DB_SESSION_POOL_TIMEOUT,
            )
            self.db_session: async_sessionmaker[AsyncSession] = async_sessionmaker(
                self.db_engine, expire_on_commit=False, class_=AsyncSession
            )
            self.logfire.info("Database pool initialized for worker")

            # TEST
            import jpype

            if not jpype.isJVMStarted():
                jpype.startJVM()

            self.logfire.info(f"Worker {self.name} initialized")

    @abstractmethod
    async def process_task_logic(self, task: T) -> None:
        """
        Abstract method that should be implemented by concrete workers
        to define specific task processing logic
        """
        pass

    async def process_task(self, task: T) -> None:
        """Process a single message with semaphore control and task lifecycle management"""
        task_id = task.task_id
        try:
            async with self.semaphore:
                # Update task status and start time
                task.status = TaskStatus.RUNNING
                task.started_at = datetime.datetime.now()

                # Execute concrete implementation
                await self.process_task_logic(task)

                # Update task completion
                task.status = TaskStatus.COMPLETED
                task.completed_at = datetime.datetime.now()

        except Exception as e:
            error_message = str(e)
            self.logfire.error(
                f"{task_id} : Task failed: {error_message} and send to failed queue",
                exc_info=True,
            )

            task.status = TaskStatus.FAILED
            task.completed_at = datetime.datetime.now()

            failure_data = FailedTask(
                task_id=task_id,
                original_message=task,
                error=error_message,
            )

            await self.channel.default_exchange.publish(
                aio_pika.Message(
                    body=failure_data.model_dump_json().encode(),
                    delivery_mode=aio_pika.DeliveryMode.PERSISTENT,
                ),
                routing_key=FAILED_TASKS_QUEUE_NAME,
            )

    async def run(self):
        """Main worker loop"""
        self.active_tasks = set()
        await self._async_init()
        try:
            async with self.queue.iterator() as queue_iter:
                while True:
                    if not self.running:
                        break
                    try:
                        message = await queue_iter.__anext__()
                    except StopAsyncIteration:
                        break

                    async with message.process():
                        try:
                            task = self._task_adapter.validate_json(
                                message.body.decode()
                            )
                            task_obj = asyncio.create_task(self.process_task(task))
                            self.active_tasks.add(task_obj)
                            task_obj.add_done_callback(self.active_tasks.discard)
                        except Exception as e:
                            self.logfire.error(
                                f"Error processing message: {str(e)}", exc_info=True
                            )
        finally:
            with self.logfire.span(f"worker_{self.name}_run_thread"):
                self.logfire.info(f"Worker {self.name} run thread finished")

    async def shutdown_worker(self):
        """Wait for all tasks to complete and cleanup the worker, you need to call actor.__ray_terminate__.remote()"""
        with self.logfire.span(f"shutdown_worker_{self.name}"):
            self.running = False
            if self.active_tasks:
                self.logfire.info(
                    f"Waiting for {len(self.active_tasks)} tasks to complete"
                )
                try:
                    await asyncio.gather(*self.active_tasks, return_exceptions=True)
                except Exception as e:
                    self.logfire.error(
                        f"Error during task cleanup: {str(e)}", exc_info=True
                    )

            if self.connection:
                await self.connection.close()
                self.logfire.info("RabbitMQ connection disposed for worker")
            if self.db_engine:
                await self.db_engine.dispose()
                self.logfire.info("Database pool disposed for worker")
        return
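
For reference, the worker above leans on a few Pydantic models from our codebase (the task_class it deserializes, TaskStatus, FailedTask). A simplified sketch of what they look like (illustrative, not the exact code):


import datetime
import enum
import uuid
from typing import Optional

from pydantic import BaseModel, Field


class TaskStatus(str, enum.Enum):
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"


class BaseTask(BaseModel):
    """Base payload shared by every task type (simplified)."""

    task_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    status: TaskStatus = TaskStatus.PENDING
    started_at: Optional[datetime.datetime] = None
    completed_at: Optional[datetime.datetime] = None

    @classmethod
    def queue_name(cls) -> str:
        # One RabbitMQ queue per task type
        return cls.__name__.lower()


class FailedTask(BaseModel):
    """Envelope published to the dead-letter queue when processing fails."""

    task_id: str
    original_message: BaseTask
    error: str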

As you can see, we have some particularities:

  • We need to maintain a DB connection pool to avoid recreating it for every task we process
  • Logfire for observability, which also needs to stay alive across tasks
  • Async jobs, so we pull up to 30 tasks at a time (the prefetch count) and process them concurrently

This way we have an async, task-specific worker.
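
To make it concrete, a worker like the TestWorker used in the app init below just subclasses this base class and is declared as a Ray actor. A rough sketch (building on the task model sketch above; the num_cpus value is only an example):


import ray


class TestTask(BaseTask):
    """Example task type; it gets its own RabbitMQ queue via queue_name()."""


@ray.remote(num_cpus=1)
class TestWorker(AbstractWorkerDB[TestTask]):
    """One Ray actor per worker, pulling TestTask messages from RabbitMQ."""

    task_class = TestTask

    async def process_task_logic(self, task: TestTask) -> None:
        # The actual mixed CPU + async I/O work (OCR, chunking, embedding,
        # LLM calls, ...) goes here; the DB pool, semaphore and Logfire are
        # already set up by the base class.
        async with self.db_session() as session:
            ...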

Then when we init the app:


import asyncio
from contextlib import asynccontextmanager

import logfire
import ray
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from ray import serve

# Project-specific imports (TestWorker/TestTask, task manager, DB helpers,
# routers, ENVIRONMENT, Logfire scrubbing config) are omitted here.

workers: list = []


async def cleanup():
    logfire.info("Cleanup triggered for FastAPI")
    try:
        logfire.info("Starting Ray workers cleanup...")
        # First shutdown all workers in parallel
        logfire.info(f"Shutting down workers : {len(workers)} : {workers}")
        await asyncio.gather(*[worker.shutdown_worker.remote() for worker in workers])  # type: ignore
        for worker in workers:
            worker.__ray_terminate__.remote()  # type: ignore
        logfire.info("All workers cleaned up")

        await cleanup_db()

    except Exception as e:
        logfire.error(f"Error during cleanup: {str(e)}")
        raise


async def init_ray():
    if not ray.is_initialized() and ENVIRONMENT == "prod":
        ray.init(address="auto", namespace="xxx")
    elif not ray.is_initialized() and ENVIRONMENT == "local":
        ray.init(address="auto", ignore_reinit_error=True, namespace="xxx")


async def init_workers():
    global workers
    # Get remaining available CPUs in the cluster (not used by Serve)
    available_cpus = int(
        ray.available_resources().get("CPU", 1)
    )  # Default to 1 if not found
    logfire.info(f"Available (unused) cluster CPUs: {available_cpus}")
    # Use available CPUs, leaving some headroom for other operations
    num_workers = max(1, available_cpus - 1)
    workers = [TestWorker.remote(f"worker_{i}") for i in range(num_workers)]
    logfire.info(f"Initialized workers: {len(workers)}")
    for worker in workers:
        worker.run.remote()  # type: ignore


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Rest of startup code
    await init_db()
    setup_retriever()
    await init_ray()
    await initialize_task_manager()
    await init_workers()

    try:
        yield
    finally:
        # Always perform cleanup when the app is shutting down
        await cleanup()


app = FastAPI(lifespan=lifespan)


# Ray Serve deployment
@serve.deployment(
    ray_actor_options={"num_cpus": 1},
    num_replicas=1,
    health_check_period_s=10,
    health_check_timeout_s=10,
    max_ongoing_requests=1000,
    max_queued_requests=-1,
    graceful_shutdown_timeout_s=240,
    name="xxxx",
    version="v1",
)
@serve.ingress(app)
class xxxxxApp:

    def __init__(self):
        print("Callisto backend is starting...")
        logfire.configure(
            service_name="callisto",
            scrubbing=logfire.ScrubbingOptions(
                extra_patterns=extra_patterns,
                callback=scrubbing_callback,
            ),
        )
        logfire.instrument_fastapi(app)
        app.include_router(table_query_template_router)

        if ENVIRONMENT == "local":
            app.include_router(admin_router)

        # Add any necessary middleware
        app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
        logfire.info("Callisto Backend initialized !!")

    def __del__(self):
        print("Callisto backend deleted")

    @app.get("/task/{nb_task}")
    async def test(self, nb_task: int):
        logfire.info(f"Test task {nb_task}")
        task = TestTask()
        task_manager = get_task_manager()
        for i in range(nb_task):
            task_id = await task_manager.submit_task(task)
            logfire.info(f"Task {i} submitted with ID: {task_id}")
        return {"task_id": task_id}

    @app.get("/health")
    async def health_check(self):
        return {"status": "ok !"}


xxxx = xxxxxApp.bind()

This way we init the workers and the API, and everything runs in the Ray cluster. And since we use RabbitMQ, tasks are distributed to the first available worker, thus maximizing load on the cluster.
Using Serve makes it much easier to tear down the workers and shut down FastAPI properly.

The only drawback is fast reload, but thanks to @volks73 we got something working: [Serve] Reloading not working with cluster · Issue #40553 · ray-project/ray

Next steps might be to add task persistence in the DB, with state tracking for crashes and retries…
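
For the persistence part, a minimal sketch could be a task table keyed by task_id that the workers update as tasks progress (illustrative only, none of this is implemented yet):


import datetime
from typing import Optional

from sqlalchemy import DateTime, Integer, String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class TaskRecord(Base):
    """One row per submitted task, updated by the worker as it progresses."""

    __tablename__ = "task_records"

    task_id: Mapped[str] = mapped_column(String, primary_key=True)
    status: Mapped[str] = mapped_column(String, default="pending")
    retries: Mapped[int] = mapped_column(Integer, default=0)
    error: Mapped[Optional[str]] = mapped_column(String, nullable=True)
    completed_at: Mapped[Optional[datetime.datetime]] = mapped_column(DateTime, nullable=True)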

Some questions:

  • As you can see, we need some global vars (workers, task_manager), but is it OK to have such global vars in the API replica? We also have global vars for the DB, which we then use with db: AsyncSession = Depends(get_db) in the FastAPI route functions.
  • Is this a good way of using Ray workers? Are we going to have issues down the line?
  • Do you see anything we could do better?

Thanks to all!

Hello! Thank you for joining the Ray community, Charles, and for sharing your use case with us :smiley:

I’ve done some research in the docs and I think I have the answers to your questions!

  1. As you can see, we need some global vars (workers or task_manager), but is it OK to have such global vars in the API replica?

Using global variables like task_manager or workers can work in simple setups, but as your system scales it can become tricky: globals can cause issues with state management, concurrency, and fault tolerance.

Your FastAPI setup seems correct though so I don’t think you necessarily need to change it.
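
If you ever want to move away from module-level globals, one common FastAPI pattern is to put the shared objects on app.state inside the lifespan and read them back through a dependency, something like this (just a sketch reusing the names from your post):


import asyncio
from contextlib import asynccontextmanager

from fastapi import FastAPI, Request


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Create the shared resources once per replica and hang them off the app
    app.state.workers = [TestWorker.remote(f"worker_{i}") for i in range(4)]
    for w in app.state.workers:
        w.run.remote()
    try:
        yield
    finally:
        # Graceful shutdown without relying on a module-level `workers` list
        await asyncio.gather(
            *[w.shutdown_worker.remote() for w in app.state.workers],
            return_exceptions=True,
        )


def get_workers(request: Request):
    # FastAPI dependency: fetch the shared worker pool from the app
    return request.app.state.workers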

  2. Is this a good way of using Ray workers? Are we going to have issues down the line?

It seems like a good use case for Ray workers! I don’t really see any issues here :slight_smile: Just make sure to stay on top of resource management.
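
For example, you can make the per-actor reservations explicit when you create the workers so the Ray scheduler can place them predictably (sketch, reusing the names from your snippet):


# Give each worker an explicit CPU (and, if needed, GPU) reservation
workers = [
    TestWorker.options(num_cpus=2, num_gpus=0).remote(f"worker_{i}")
    for i in range(num_workers)
]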

  3. Do you see things we could do better?

Try out some of the Ray debugging tools to help debug any potential issues that may arise. There are also some nice user guides out there that might be helpful. I’ll link them here!