Writing unit tests: Serve + FastAPI + pytest

Hi!

I’m trying to fix the existing unit tests for my service, which now integrates Ray Serve via FastAPI. My test is very simple: it checks that the FastAPI deployment requires an API key:

import io

from fastapi.testclient import TestClient

from api import api 
from config import get_settings, Settings

# Module-level test client shared by the tests in this file. It calls the
# FastAPI ``api`` app in-process (no separate HTTP server is started).
client = TestClient(api)


def test_check_api():
    """Check that POST /api/test is rejected with HTTP 403 when no API key is sent.

    The ``get_settings`` dependency is overridden so the app is configured with
    a single known test key. The override is always removed afterwards, even if
    the request or the assertion fails, so it cannot leak into other tests that
    share the module-level ``api`` app.

    NOTE(review): ``check`` is a method of a ``@serve.ingress`` class, so FastAPI
    resolves its ``self`` argument through Ray Serve's
    ``get_current_servable_instance`` dependency. That dependency raises
    ``RayServeException`` when it is not running inside a Serve replica, as the
    traceback shows. Calling the app through ``TestClient`` therefore needs
    either a running Serve instance or an override for that dependency.
    """
    TEST_API_KEY = "test-api-key"
    api.dependency_overrides[get_settings] = lambda: Settings(api_keys=[TEST_API_KEY])
    try:
        # Assert that the call fails with no API key. HTTP 403 must
        # be returned in that case.
        response = client.post("/api/test")
        assert response.status_code == 403
    finally:
        # Reset in ``finally`` so a failing request or assertion does not
        # leave the override installed for later tests.
        api.dependency_overrides = {}

My main deployment is defined like this:

api = FastAPI()


@serve.deployment()
@serve.ingress(api)
class FastAPIDeployment:
    """Ray Serve deployment that serves the FastAPI ``api`` app as its ingress.

    Incoming requests on the routes below are forwarded to the
    ``model_deployment`` object supplied at construction time (presumably a
    Serve deployment handle; confirm against the code that binds this class).
    """

    def __init__(self, model_deployment):
        # Downstream deployment that does the actual work for each request.
        self._model_deployment = model_deployment

    @api.post("/api/test")
    async def check(
        self,
        api_key: APIKey = Depends(auth.get_api_key),
    ):
        # Handles POST /api/test by forwarding a fixed payload downstream.
        # ``api_key`` is unused in the body. Declaring it makes FastAPI run
        # ``auth.get_api_key`` before this handler executes.
        # Two awaits: ``remote`` yields a reference, which is then awaited
        # for the actual result.
        pending = await self._model_deployment.remote("some data")
        return await pending

When I run the test without any modification, I get the following output:

tests\test_api.py F                                                                                                           [100%]

============================================================= FAILURES =============================================================
__________________________________________________________ test_check_api __________________________________________________________

    def test_check_api():
        TEST_API_KEY = "test-api-key"
        api.dependency_overrides[get_settings] = lambda: Settings(api_keys=[TEST_API_KEY])

        # Assert that the call fails with no API key. HTTP 403 must
        # be returned in that case.
>       response = client.post("/api/check")

tests\test_api.py:17:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.venv\lib\site-packages\requests\sessions.py:635: in post
    return self.request("POST", url, data=data, json=json, **kwargs)
.venv\lib\site-packages\starlette\testclient.py:473: in request
    return super().request(
.venv\lib\site-packages\requests\sessions.py:587: in request
    resp = self.send(prep, **send_kwargs)
.venv\lib\site-packages\requests\sessions.py:701: in send
    r = adapter.send(request, **kwargs)
.venv\lib\site-packages\starlette\testclient.py:267: in send
    raise exc
.venv\lib\site-packages\starlette\testclient.py:264: in send
    portal.call(self.app, scope, receive, send)
.venv\lib\site-packages\anyio\from_thread.py:283: in call
    return cast(T_Retval, self.start_task_soon(func, *args).result())
C:\Users\dexte\AppData\Local\Programs\Python\Python310\lib\concurrent\futures\_base.py:458: in result
    return self.__get_result()
C:\Users\dexte\AppData\Local\Programs\Python\Python310\lib\concurrent\futures\_base.py:403: in __get_result
    raise self._exception
.venv\lib\site-packages\anyio\from_thread.py:219: in _call_func
    retval = await retval
.venv\lib\site-packages\fastapi\applications.py:270: in __call__
    await super().__call__(scope, receive, send)
.venv\lib\site-packages\starlette\applications.py:124: in __call__
    await self.middleware_stack(scope, receive, send)
.venv\lib\site-packages\starlette\middleware\errors.py:184: in __call__
    raise exc
.venv\lib\site-packages\starlette\middleware\errors.py:162: in __call__
    await self.app(scope, receive, _send)
.venv\lib\site-packages\starlette\middleware\exceptions.py:75: in __call__
    raise exc
.venv\lib\site-packages\starlette\middleware\exceptions.py:64: in __call__
    await self.app(scope, receive, sender)
.venv\lib\site-packages\fastapi\middleware\asyncexitstack.py:21: in __call__
    raise e
.venv\lib\site-packages\fastapi\middleware\asyncexitstack.py:18: in __call__
    await self.app(scope, receive, send)
.venv\lib\site-packages\starlette\routing.py:680: in __call__
    await route.handle(scope, receive, send)
.venv\lib\site-packages\starlette\routing.py:275: in handle
    await self.app(scope, receive, send)
.venv\lib\site-packages\starlette\routing.py:65: in app
    response = await func(request)
.venv\lib\site-packages\fastapi\routing.py:225: in app
    solved_result = await solve_dependencies(
.venv\lib\site-packages\fastapi\dependencies\utils.py:535: in solve_dependencies
    solved = await run_in_threadpool(call, **sub_values)
.venv\lib\site-packages\starlette\concurrency.py:41: in run_in_threadpool
    return await anyio.to_thread.run_sync(func, *args)
.venv\lib\site-packages\anyio\to_thread.py:31: in run_sync
    return await get_asynclib().run_sync_in_worker_thread(
.venv\lib\site-packages\anyio\_backends\_asyncio.py:937: in run_sync_in_worker_thread
    return await future
.venv\lib\site-packages\anyio\_backends\_asyncio.py:867: in run
    result = context.run(func, *args)
.venv\lib\site-packages\ray\serve\_private\http_util.py:186: in get_current_servable_instance
    return serve.get_replica_context().servable_object
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    @PublicAPI(stability="beta")
    def get_replica_context() -> ReplicaContext:
        """If called from a deployment, returns the deployment and replica tag.

        A replica tag uniquely identifies a single replica for a Ray Serve
        deployment at runtime.  Replica tags are of the form
        `<deployment_name>#<random letters>`.

        Raises:
            RayServeException: if not called from within a Ray Serve deployment.

        Example:
            >>> from ray import serve
            >>> # deployment_name
            >>> serve.get_replica_context().deployment # doctest: +SKIP
            >>> # deployment_name#krcwoa
            >>> serve.get_replica_context().replica_tag # doctest: +SKIP
        """
        internal_replica_context = get_internal_replica_context()
        if internal_replica_context is None:
>           raise RayServeException(
                "`serve.get_replica_context()` "
                "may only be called from within a "
                "Ray Serve deployment."
            )
E           ray.serve.exceptions.RayServeException: `serve.get_replica_context()` may only be called from within a Ray Serve deployment.

.venv\lib\site-packages\ray\serve\api.py:149: RayServeException

What am I doing wrong? Is there a guide or any documentation on how to write unit tests for Ray Serve + FastAPI with pytest?