Hello everyone! I will be presenting our solution for replacing Celery with Ray workers, along with some of our code. Comments and discussion are welcome!
Our objective
We have a lot of async (I/O) intensive work, but with some CPU usage as well, mainly:
- ETL jobs on files: taking a PDF, then OCR, chunk, embed, and index, you know… the classic pipeline
- User threads or workflow runs
- Ultra-high I/O-intensive work, mainly LLM API calls
Right now we use Celery, but we have these problems:
- Not well suited to async I/O jobs
- Can't use GPUs
- Can't specify CPU requirements
- Not great at scaling
But we do like the queuing with RabbitMQ (persistent and reliable).
Our use of Ray
After running into a drawback with Ray Core alone ([CORE] actor atexit is not called when SIGTERM receive · Issue #50004 · ray-project/ray), where I couldn't get proper cleanup to work, we went with Serve + Core.
So our API receives a user file, saves it to S3, and sends a task with the file ID to RabbitMQ.
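For context, the publish side looks roughly like this (a simplified sketch; FileTask and the queue name are just placeholders for illustration, not our exact code):

import aio_pika
from pydantic import BaseModel

class FileTask(BaseModel):
    # Hypothetical task payload for illustration
    task_id: str
    file_id: str

async def publish_task(task: FileTask, rabbitmq_url: str, queue_name: str) -> None:
    # One connection per call is shown for brevity; in practice the connection is reused
    connection = await aio_pika.connect_robust(rabbitmq_url)
    async with connection:
        channel = await connection.channel()
        await channel.declare_queue(queue_name, durable=True)
        await channel.default_exchange.publish(
            aio_pika.Message(
                body=task.model_dump_json().encode(),
                delivery_mode=aio_pika.DeliveryMode.PERSISTENT,
            ),
            routing_key=queue_name,
        )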
Then we have a pool of workers that poll the queue and process the jobs. We created two types of workers:
- High I/O-bound workers (mainly for LLM calls)
- Mixed async and CPU workloads
For example, the worker for the mixed workload looks something like this:
class AbstractWorkerDB(ABC, Generic[T]):
    """
    Abstract base worker class that processes individual tasks
    You need to define the task_class attribute in the concrete implementation
    """

    task_class: Type[T]

    def __init__(self, worker_name: str):
        self.name = worker_name
        self.running = True
        if not jpype.isJVMStarted():
            jpype.startJVM()
        atexit.register(self.__del__)

    async def _async_init(self):
        """Internal async initialization method"""
        self.running = True
        self.logfire = logfire.configure(
            service_name=f"worker_{self.name}",
            scrubbing=logfire.ScrubbingOptions(
                extra_patterns=extra_patterns,
                callback=scrubbing_callback,
            ),
        )
        with self.logfire.span(f"initialize_worker_{self.name}_resources"):
            self._task_adapter = TypeAdapter(self.task_class)
            # Create semaphore to limit the number of concurrent tasks
            self.semaphore = asyncio.Semaphore(RAY_WORKER_CONCURENT_TASKS)
            # RabbitMQ connection for each worker
            self.logfire.info(
                f"Connecting to RabbitMQ at {RAY_RABBITMQ_ADDRESS} and using queue {self.task_class.queue_name()}"
            )
            self.connection = await aio_pika.connect_robust(RAY_RABBITMQ_ADDRESS)
            self.channel = await self.connection.channel()
            await self.channel.set_qos(prefetch_count=RAY_WORKER_CONCURENT_TASKS)
            self.queue = await self.channel.declare_queue(
                str(self.task_class.queue_name()), durable=True
            )
            # Declare dead letter queue for failed tasks
            self.dead_letter_queue = await self.channel.declare_queue(
                FAILED_TASKS_QUEUE_NAME, durable=True
            )
            # Set up DB access
            self.db_engine: AsyncEngine = create_async_engine(
                DATABASE_URL,
                pool_size=DB_SESSION_POOL_SIZE,
                max_overflow=DB_SESSION_MAX_OVERFLOW,
                pool_timeout=DB_SESSION_POOL_TIMEOUT,
            )
            self.db_session: async_sessionmaker[AsyncSession] = async_sessionmaker(
                self.db_engine, expire_on_commit=False, class_=AsyncSession
            )
            self.logfire.info("Database pool initialized for worker")
            # TEST
            import jpype

            if not jpype.isJVMStarted():
                jpype.startJVM()
            self.logfire.info(f"Worker {self.name} initialized")
    @abstractmethod
    async def process_task_logic(self, task: T) -> None:
        """
        Abstract method that should be implemented by concrete workers
        to define specific task processing logic
        """
        pass

    async def process_task(self, task: T) -> None:
        """Process a single message with semaphore control and task lifecycle management"""
        task_id = task.task_id
        try:
            async with self.semaphore:
                # Update task status and start time
                task.status = TaskStatus.RUNNING
                task.started_at = datetime.datetime.now()
                # Execute concrete implementation
                await self.process_task_logic(task)
                # Update task completion
                task.status = TaskStatus.COMPLETED
                task.completed_at = datetime.datetime.now()
        except Exception as e:
            error_message = str(e)
            self.logfire.error(
                f"{task_id} : Task failed: {error_message}, sending to failed queue",
                exc_info=True,
            )
            task.status = TaskStatus.FAILED
            task.completed_at = datetime.datetime.now()
            failure_data = FailedTask(
                task_id=task_id,
                original_message=task,
                error=error_message,
            )
            await self.channel.default_exchange.publish(
                aio_pika.Message(
                    body=failure_data.model_dump_json().encode(),
                    delivery_mode=aio_pika.DeliveryMode.PERSISTENT,
                ),
                routing_key=FAILED_TASKS_QUEUE_NAME,
            )
    async def run(self):
        """Main worker loop"""
        self.active_tasks = set()
        await self._async_init()
        try:
            async with self.queue.iterator() as queue_iter:
                while True:
                    if not self.running:
                        break
                    try:
                        message = await queue_iter.__anext__()
                    except StopAsyncIteration:
                        break
                    async with message.process():
                        try:
                            task = self._task_adapter.validate_json(
                                message.body.decode()
                            )
                            task_obj = asyncio.create_task(self.process_task(task))
                            self.active_tasks.add(task_obj)
                            task_obj.add_done_callback(self.active_tasks.discard)
                        except Exception as e:
                            self.logfire.error(
                                f"Error processing message: {str(e)}", exc_info=True
                            )
        finally:
            with self.logfire.span(f"worker_{self.name}_run_thread"):
                self.logfire.info(f"Worker {self.name} run thread finished")
    async def shutdown_worker(self):
        """Wait for all tasks to complete and clean up the worker; you still need to call actor.__ray_terminate__.remote() afterwards"""
        with self.logfire.span(f"shutdown_worker_{self.name}"):
            self.running = False
            if self.active_tasks:
                self.logfire.info(
                    f"Waiting for {len(self.active_tasks)} tasks to complete"
                )
                try:
                    await asyncio.gather(*self.active_tasks, return_exceptions=True)
                except Exception as e:
                    self.logfire.error(
                        f"Error during task cleanup: {str(e)}", exc_info=True
                    )
            if self.connection:
                await self.connection.close()
                self.logfire.info("RabbitMQ connection disposed for worker")
            if self.db_engine:
                await self.db_engine.dispose()
                self.logfire.info("Database pool disposed for worker")
            return
As you can see, we have some particularities:
- We need to maintain a DB connection pool to avoid recreating it for every task we process
- Logfire for observability, which we also need to keep running across tasks
- Async jobs, so we prefetch up to 30 tasks (the RabbitMQ prefetch count) and process them concurrently
This way we get an async, task-specific worker.
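A concrete worker then just declares its task_class, implements process_task_logic, and gets wrapped as a Ray actor. Roughly like this (a simplified sketch: the TestTask fields and the num_cpus value are illustrative, and it assumes TaskStatus and AbstractWorkerDB from above are importable):

import asyncio
import datetime
import uuid

import ray
from pydantic import BaseModel, Field

class TestTask(BaseModel):
    # Illustrative task model; mirrors the fields that process_task touches
    task_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    status: TaskStatus | None = None
    started_at: datetime.datetime | None = None
    completed_at: datetime.datetime | None = None

    @classmethod
    def queue_name(cls) -> str:
        return "test_tasks"

@ray.remote(num_cpus=1)
class TestWorker(AbstractWorkerDB[TestTask]):
    task_class = TestTask

    async def process_task_logic(self, task: TestTask) -> None:
        # A mix of async I/O and a little CPU work would go here
        await asyncio.sleep(1)
        self.logfire.info(f"Processed task {task.task_id}")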
Then, when we initialize the app:
async def cleanup():
    logfire.info(f"Cleanup triggered for FastAPI")
    try:
        logfire.info("Starting Ray workers cleanup...")
        # First shutdown all workers in parallel
        logfire.info(f"Shutting down workers : {len(workers)} : {workers}")
        await asyncio.gather(*[worker.shutdown_worker.remote() for worker in workers])  # type: ignore
        for worker in workers:
            worker.__ray_terminate__.remote()  # type: ignore
        logfire.info("All workers cleaned up")
        await cleanup_db()
    except Exception as e:
        logfire.error(f"Error during cleanup: {str(e)}")
        raise

async def init_ray():
    if not ray.is_initialized() and ENVIRONMENT == "prod":
        ray.init(address="auto", namespace="xxx")
    elif not ray.is_initialized() and ENVIRONMENT == "local":
        ray.init(address="auto", ignore_reinit_error=True, namespace="xxx")

async def init_workers():
    global workers
    # Get remaining available CPUs in the cluster (not used by Serve)
    available_cpus = int(
        ray.available_resources().get("CPU", 1)
    )  # Default to 1 if not found
    logfire.info(f"Available (unused) cluster CPUs: {available_cpus}")
    # Use available CPUs, leaving some headroom for other operations
    num_workers = max(1, available_cpus - 1)
    workers = [TestWorker.remote(f"worker_{i}") for i in range(num_workers)]
    logfire.info(f"Initialized workers: {len(workers)}")
    for worker in workers:
        worker.run.remote()  # type: ignore

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Rest of startup code
    await init_db()
    setup_retriever()
    await init_ray()
    await initialize_task_manager()
    await init_workers()
    try:
        yield
    finally:
        # Always perform cleanup when the app is shutting down
        await cleanup()

app = FastAPI(lifespan=lifespan)
# Ray Serve deployment
@serve.deployment(
    ray_actor_options={"num_cpus": 1},
    num_replicas=1,
    health_check_period_s=10,
    health_check_timeout_s=10,
    max_ongoing_requests=1000,
    max_queued_requests=-1,
    graceful_shutdown_timeout_s=240,
    name="xxxx",
    version="v1",
)
@serve.ingress(app)
class xxxxxApp:
    def __init__(self):
        print("Callisto backend is starting...")
        logfire.configure(
            service_name="callisto",
            scrubbing=logfire.ScrubbingOptions(
                extra_patterns=extra_patterns,
                callback=scrubbing_callback,
            ),
        )
        logfire.instrument_fastapi(app)
        app.include_router(table_query_template_router)
        if ENVIRONMENT == "local":
            app.include_router(admin_router)
        # Add any necessary middleware
        app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
        logfire.info("Callisto Backend initialized !!")

    def __del__(self):
        print("Callisto backend deleted")

    @app.get("/task/{nb_task}")
    async def test(self, nb_task: int):
        logfire.info(f"Test task {nb_task}")
        task = TestTask()
        task_manager = get_task_manager()
        for i in range(nb_task):
            task_id = await task_manager.submit_task(task)
            logfire.info(f"Task {i} submitted with ID: {task_id}")
        return {"task_id": task_id}

    @app.get("/health")
    async def health_check(self):
        return {"status": "ok !"}

xxxx = xxxxxApp.bind()
This way we initialize both the workers and the API, and everything runs in the Ray cluster. And since we use RabbitMQ, tasks are distributed to the first available worker, thus maximizing load on the cluster.
Using Serve makes it much easier to tear down workers and shut down FastAPI properly.
The only drawback is fast reload, but thanks to @volks73 we got something working: [Serve] Reloading not working with cluster · Issue #40553 · ray-project/ray
Next steps: maybe add task persistence in the DB, with state tracking for crashes and retries…
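Something in that direction, for example (a rough SQLAlchemy sketch; the table and column names are just ideas, nothing is implemented yet):

import datetime

from sqlalchemy import DateTime, Integer, String, Text
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
    pass

class TaskRecord(Base):
    # Hypothetical table for persisting task state across crashes and retries
    __tablename__ = "task_records"

    task_id: Mapped[str] = mapped_column(String, primary_key=True)
    status: Mapped[str] = mapped_column(String, default="queued")
    retries: Mapped[int] = mapped_column(Integer, default=0)
    error: Mapped[str | None] = mapped_column(Text, nullable=True)
    started_at: Mapped[datetime.datetime | None] = mapped_column(DateTime, nullable=True)
    completed_at: Mapped[datetime.datetime | None] = mapped_column(DateTime, nullable=True)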
Some questions:
- As you can see, we need some global vars, workers and task_manager, but is it OK to have such global vars in the API replica? We also have global vars for the DB, and then we use db: AsyncSession = Depends(get_db) in the FastAPI route functions (get_db is roughly the standard session-per-request pattern, sketched after these questions).
- Is this a good way of using Ray workers? Are we going to run into issues down the line?
- Do you see things we could do better?
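For reference, get_db is roughly the usual FastAPI dependency over the global engine/sessionmaker (a simplified sketch; our actual helper may differ a bit):

from collections.abc import AsyncGenerator

from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

engine = create_async_engine(DATABASE_URL)
SessionLocal = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)

async def get_db() -> AsyncGenerator[AsyncSession, None]:
    # One session per request; FastAPI closes it when the request is done
    async with SessionLocal() as session:
        yield session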
Thanks to all!