Hi!
I’m attempting to fix existing unit tests for my service, which now integrates Ray Serve via FastAPI. My test is really simple, and it’s testing that the FastAPI deployment is using API keys;
import io
from fastapi.testclient import TestClient
from api import api
from config import get_settings, Settings
client = TestClient(api)
def test_check_api():
TEST_API_KEY = "test-api-key"
api.dependency_overrides[get_settings] = lambda: Settings(api_keys=[TEST_API_KEY])
# Assert that the call fails with no API key. HTTP 403 must
# be returned in that case.
response = client.post("/api/test")
assert response.status_code == 403
api.dependency_overrides = {}
My main deployment is defined like this:
api = FastAPI()
@serve.deployment()
@serve.ingress(api)
class FastAPIDeployment:
def __init__(self, model_deployment):
self._model_deployment = model_deployment
@api.post("/api/test")
async def check(
self,
api_key: APIKey = Depends(auth.get_api_key),
):
model_ref = await self._model_deployment.remote("some data")
result = await model_ref
return result
When running the test, without any modification, I’m getting the following output:
tests\test_api.py F [100%]
============================================================= FAILURES =============================================================
__________________________________________________________ test_check_api __________________________________________________________
def test_check_api():
TEST_API_KEY = "test-api-key"
api.dependency_overrides[get_settings] = lambda: Settings(api_keys=[TEST_API_KEY])
# Assert that the call fails with no API key. HTTP 403 must
# be returned in that case.
> response = client.post("/api/check")
tests\test_api.py:17:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.venv\lib\site-packages\requests\sessions.py:635: in post
return self.request("POST", url, data=data, json=json, **kwargs)
.venv\lib\site-packages\starlette\testclient.py:473: in request
return super().request(
.venv\lib\site-packages\requests\sessions.py:587: in request
resp = self.send(prep, **send_kwargs)
.venv\lib\site-packages\requests\sessions.py:701: in send
r = adapter.send(request, **kwargs)
.venv\lib\site-packages\starlette\testclient.py:267: in send
raise exc
.venv\lib\site-packages\starlette\testclient.py:264: in send
portal.call(self.app, scope, receive, send)
.venv\lib\site-packages\anyio\from_thread.py:283: in call
return cast(T_Retval, self.start_task_soon(func, *args).result())
C:\Users\dexte\AppData\Local\Programs\Python\Python310\lib\concurrent\futures\_base.py:458: in result
return self.__get_result()
C:\Users\dexte\AppData\Local\Programs\Python\Python310\lib\concurrent\futures\_base.py:403: in __get_result
raise self._exception
.venv\lib\site-packages\anyio\from_thread.py:219: in _call_func
retval = await retval
.venv\lib\site-packages\fastapi\applications.py:270: in __call__
await super().__call__(scope, receive, send)
.venv\lib\site-packages\starlette\applications.py:124: in __call__
await self.middleware_stack(scope, receive, send)
.venv\lib\site-packages\starlette\middleware\errors.py:184: in __call__
raise exc
.venv\lib\site-packages\starlette\middleware\errors.py:162: in __call__
await self.app(scope, receive, _send)
.venv\lib\site-packages\starlette\middleware\exceptions.py:75: in __call__
raise exc
.venv\lib\site-packages\starlette\middleware\exceptions.py:64: in __call__
await self.app(scope, receive, sender)
.venv\lib\site-packages\fastapi\middleware\asyncexitstack.py:21: in __call__
raise e
.venv\lib\site-packages\fastapi\middleware\asyncexitstack.py:18: in __call__
await self.app(scope, receive, send)
.venv\lib\site-packages\starlette\routing.py:680: in __call__
await route.handle(scope, receive, send)
.venv\lib\site-packages\starlette\routing.py:275: in handle
await self.app(scope, receive, send)
.venv\lib\site-packages\starlette\routing.py:65: in app
response = await func(request)
.venv\lib\site-packages\fastapi\routing.py:225: in app
solved_result = await solve_dependencies(
.venv\lib\site-packages\fastapi\dependencies\utils.py:535: in solve_dependencies
solved = await run_in_threadpool(call, **sub_values)
.venv\lib\site-packages\starlette\concurrency.py:41: in run_in_threadpool
return await anyio.to_thread.run_sync(func, *args)
.venv\lib\site-packages\anyio\to_thread.py:31: in run_sync
return await get_asynclib().run_sync_in_worker_thread(
.venv\lib\site-packages\anyio\_backends\_asyncio.py:937: in run_sync_in_worker_thread
return await future
.venv\lib\site-packages\anyio\_backends\_asyncio.py:867: in run
result = context.run(func, *args)
.venv\lib\site-packages\ray\serve\_private\http_util.py:186: in get_current_servable_instance
return serve.get_replica_context().servable_object
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
@PublicAPI(stability="beta")
def get_replica_context() -> ReplicaContext:
"""If called from a deployment, returns the deployment and replica tag.
A replica tag uniquely identifies a single replica for a Ray Serve
deployment at runtime. Replica tags are of the form
`<deployment_name>#<random letters>`.
Raises:
RayServeException: if not called from within a Ray Serve deployment.
Example:
>>> from ray import serve
>>> # deployment_name
>>> serve.get_replica_context().deployment # doctest: +SKIP
>>> # deployment_name#krcwoa
>>> serve.get_replica_context().replica_tag # doctest: +SKIP
"""
internal_replica_context = get_internal_replica_context()
if internal_replica_context is None:
> raise RayServeException(
"`serve.get_replica_context()` "
"may only be called from within a "
"Ray Serve deployment."
)
E ray.serve.exceptions.RayServeException: `serve.get_replica_context()` may only be called from within a Ray Serve deployment.
.venv\lib\site-packages\ray\serve\api.py:149: RayServeException
What am I doing wrong? Is there any guide/documentation on how to write unit tests when using Ray + pytest + FastAPI?