Context
Our production ML model serving platform currently serves ~5 models (TensorFlow, PyTorch, scikit-learn, etc.). While adding a couple more models we hit serious performance degradation. The bottleneck is one particular model from Flair (flair/ner-english-large · Hugging Face), specifically during its `tagger.predict(sentence)` call.
After simulating how Ray could be used within our present system, I found that I am not seeing much performance improvement.
I would definitely appreciate some expert suggestions on this.
Ray Served Model - I have 2 files: one for model serving and another one as the service interface and model consumer.
- FlairModel.py - model serving
```python
# FlairModel.py
import ray
from ray import serve
from flair.models import SequenceTagger
from flair.data import Sentence
from flair.tokenization import SpaceTokenizer
from nltk.tokenize import sent_tokenize

ray.init(address='auto', namespace="serve", ignore_reinit_error=True)
serve.start(detached=True)


@serve.deployment(name="FlairModel", num_replicas=1)
class FlairModel:
    def __init__(self) -> None:
        self.tagger = SequenceTagger.load("flair/ner-english-large")
        self.flair_thresholding = 0.80
        self.tokenizer = SpaceTokenizer()

    async def __call__(self, request):
        sent_list = sent_tokenize(request)
        sentences = [Sentence(each, use_tokenizer=self.tokenizer) for each in sent_list]
        [self.tagger.predict(each, mini_batch_size=64) for each in sentences]
        resp = []
        for each in sentences:
            for entity in each.get_spans('ner'):
                if entity.tag == 'LOC' and entity.score > self.flair_thresholding:
                    resp.append(entity.text)
        return {"predictions": list(set(resp))}


FlairModel.deploy()
```
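One thing worth noting about the `__call__` body above: `tagger.predict` is invoked once per sentence inside a list comprehension, so every forward pass sees a batch of one and `mini_batch_size=64` never takes effect. Flair's `SequenceTagger.predict` also accepts a list of sentences and batches internally. The sketch below demonstrates the difference using a hypothetical stub tagger (no Flair or model download needed), counting how many forward passes each calling style produces:

```python
# Stub stand-in for SequenceTagger that counts forward passes instead of
# running a model. Only the batching behavior is simulated here.
class StubTagger:
    def __init__(self):
        self.forward_passes = 0

    def predict(self, sentences, mini_batch_size=32):
        if not isinstance(sentences, list):
            sentences = [sentences]
        # One forward pass per mini-batch of input sentences.
        for _ in range(0, len(sentences), mini_batch_size):
            self.forward_passes += 1


def per_sentence(tagger, sentences):
    # Pattern used in the deployment above: predict() per sentence,
    # so each call processes a "batch" of exactly one.
    for each in sentences:
        tagger.predict(each, mini_batch_size=64)


def batched(tagger, sentences):
    # Passing the whole list lets predict() batch internally.
    tagger.predict(sentences, mini_batch_size=64)


sents = ["sentence %d" % i for i in range(128)]
t1, t2 = StubTagger(), StubTagger()
per_sentence(t1, sents)  # 128 forward passes of 1 sentence each
batched(t2, sents)       # 2 forward passes of 64 sentences each
print(t1.forward_passes, t2.forward_passes)
```

With the real tagger the same restructuring would be `self.tagger.predict(sentences, mini_batch_size=64)` in place of the list comprehension; whether that closes the gap here depends on how many sentences each request contains.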
- deploy_ray.py - Ray Serve + FastAPI web server consuming the deployed models.
```python
# deploy_ray.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import ray
from ray import serve

app = FastAPI()
ray.init(address="ray://127.0.0.1:10001", ignore_reinit_error=True)
serve.start(detached=True)


class InRequest(BaseModel):
    text: str


@app.get('/')
def api_root():
    return 'Welcome !!!'


@serve.deployment(route_prefix="/")
@serve.ingress(app)
class APIWrapper:
    pass


@serve.deployment(route_prefix='/api')
@serve.ingress(app)
class InspectFlair:
    @app.post('/entities')
    def inspect(self, request: InRequest):
        if len(request.text) == 0:
            raise HTTPException(status_code=400, detail="'text' not found!")
        try:
            data = request.text
            flair_handle = serve.get_deployment("FlairModel").get_handle()
            future = flair_handle.remote(data)
            response = ray.get(future)
            return {'response': response}
        except ValueError:
            return {'error': 'Got Err'}, 400


APIWrapper.deploy()
InspectFlair.deploy()
```
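A smaller per-request cost in `InspectFlair.inspect` is that `serve.get_deployment("FlairModel").get_handle()` resolves the handle on every call; a handle is reusable, so it can be resolved once and cached. A minimal sketch of the lazy-caching idea, with a hypothetical stub lookup standing in for `serve.get_deployment(...).get_handle()` so it runs without a Ray cluster:

```python
from functools import lru_cache

# Hypothetical stub for serve.get_deployment(name).get_handle();
# counts how often the (comparatively expensive) lookup actually runs.
lookups = {"count": 0}

def get_handle(name):
    lookups["count"] += 1
    return f"handle-for-{name}"  # a real handle would be a ServeHandle object

@lru_cache(maxsize=None)
def cached_handle(name):
    # Resolved at most once per deployment name, then reused on every request.
    return get_handle(name)

for _ in range(1000):  # simulate 1000 incoming requests
    handle = cached_handle("FlairModel")

print(lookups["count"])  # the underlying lookup ran only once
```

In the deployment class itself, the equivalent is storing the handle on `self` the first time `inspect` needs it. This trims per-request overhead but will not explain a multi-second gap on its own.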
- The Ray cluster runs on my local machine, started with:

```shell
ray start --head
```
Regular FastAPI Served Model
- FlairModel.py - This code is exactly the same as the one used in the Ray Serve example, the only change being that the model is preloaded on the startup event.
```python
# FlairModel.py - for FastAPI demo
from flair.data import Sentence
from flair.tokenization import SpaceTokenizer
from nltk.tokenize import sent_tokenize


class FlairModel:
    def __init__(self, tagger) -> None:
        self.tagger = tagger
        self.flair_thresholding = 0.80
        self.tokenizer = SpaceTokenizer()

    def predict(self, request):
        sent_list = sent_tokenize(request)
        sentences = [Sentence(each, use_tokenizer=self.tokenizer) for each in sent_list]
        [self.tagger.predict(each, mini_batch_size=128) for each in sentences]
        resp = []
        for each in sentences:
            for entity in each.get_spans('ner'):
                if entity.tag == 'LOC' and entity.score > self.flair_thresholding:
                    resp.append(entity.text)
        return {"predictions": list(set(resp))}
```
- fastapi_demo.py - This consumes individual model instances.
```python
# fastapi_demo.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from FlairModel import FlairModel
from flair.models import SequenceTagger

app = FastAPI()
models = {}


class InRequest(BaseModel):
    text: str


@app.get('/')
def api_root():
    return 'Welcome !!!'


@app.on_event("startup")
def load_flair_model():
    models['tagger'] = SequenceTagger.load("flair/ner-english-large")


@app.post('/api/entities')
def inspect(request: InRequest):
    if len(request.text) == 0:
        raise HTTPException(status_code=400, detail="'text' not found!")
    try:
        data = request.text
        flair_handle = FlairModel(models['tagger'])
        response = flair_handle.predict(data)
        return {'response': response}
    except ValueError:
        return {'error': 'Got Err'}, 400


def app_init():
    return app
```
- FastAPI is deployed with Gunicorn using Uvicorn worker processes:

```shell
gunicorn --workers=1 --worker-class=uvicorn.workers.UvicornWorker --bind 0.0.0.0:3000 'fastapi_demo:app_init()' --timeout 1000
```
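With `--workers=1`, a single process handles every request, and since `inspect` is a sync endpoint, requests queue behind each in-flight inference. Worker count is one knob for throughput, with a trade-off I'd flag: each Uvicorn worker runs the startup event and loads its own copy of the model. A sketch (the worker count of 2 is illustrative, not a recommendation):

```shell
# Each worker process loads its own copy of flair/ner-english-large on
# startup, so memory use grows roughly linearly with --workers.
gunicorn --workers=2 \
  --worker-class=uvicorn.workers.UvicornWorker \
  --bind 0.0.0.0:3000 'fastapi_demo:app_init()' --timeout 1000
```

The Ray Serve equivalent of this knob is `num_replicas` on the `FlairModel` deployment, which is pinned to 1 above, so both stacks as configured here are effectively single-copy.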
Performance Check - For a quick evaluation I am using a test sentence of 491 characters (I also performed more elaborate load tests with multiple sentences randomly selected in each call). I ran regular curl calls via the httpstat utility to showcase performance.
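Since the first call includes one-time warm-up cost, I separate cold and warm latencies when comparing the two stacks. A hypothetical helper for summarizing a run of measured latencies (the example numbers are the Ray Serve timings measured below):

```python
from statistics import mean

def summarize(latencies_ms):
    """Split a run of request latencies into the cold first call
    (warm-up) and steady-state warm calls, and summarize the warm ones."""
    cold, warm = latencies_ms[0], latencies_ms[1:]
    return {
        "cold_ms": cold,
        "warm_mean_ms": mean(warm),
        "warm_max_ms": max(warm),
    }

# Ray Serve run below: 5676 ms cold, 1203 ms warm.
print(summarize([5676, 1203]))
```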
- Ray served models - the first call is generally slow, but subsequent calls improve.
```shell
❯ httpstat http://0.0.0.0:8000/api/entities -X POST -d '{"text":"Germany and Portugal are among nations announcing post-Christmas curbs and greater social distancing measures. Spain has reported its highest number of daily cases since the start of the pandemic and France has warned daily cases there could soon pass 100,000. French Health Minister Olivier Véran said the increase in daily infections in the country, currently at about 70,000, would be driven by the Omicron variant, which he said was likely to become the dominant variant by early January"}' -H "Content-Type:application/json"
Connected to 127.0.0.1:8000 from 127.0.0.1:54065

HTTP/1.1 200 OK
date: Thu, 23 Dec 2021 01:07:18 GMT
server: uvicorn
content-length: 68
content-type: application/json

Body stored in: /var/folders/_h/5_3695q9343ctmsgltf3grq80000gp/T/tmpkm791wr1

  DNS Lookup   TCP Connection   Server Processing   Content Transfer
[     1ms    |      0ms       |      5675ms       |       0ms        ]
             |                |                   |                  |
    namelookup:1ms            |                   |                  |
                        connect:1ms               |                  |
                                      starttransfer:5676ms           |
                                                                 total:5676ms
```
```shell
❯ httpstat http://0.0.0.0:8000/api/entities -X POST -d '{"text":"Germany and Portugal are among nations announcing post-Christmas curbs and greater social distancing measures. Spain has reported its highest number of daily cases since the start of the pandemic and France has warned daily cases there could soon pass 100,000. French Health Minister Olivier Véran said the increase in daily infections in the country, currently at about 70,000, would be driven by the Omicron variant, which he said was likely to become the dominant variant by early January"}' -H "Content-Type:application/json"
Connected to 127.0.0.1:8000 from 127.0.0.1:54071

HTTP/1.1 200 OK
date: Thu, 23 Dec 2021 01:07:29 GMT
server: uvicorn
content-length: 68
content-type: application/json

Body stored in: /var/folders/_h/5_3695q9343ctmsgltf3grq80000gp/T/tmpxc57mp7_

  DNS Lookup   TCP Connection   Server Processing   Content Transfer
[     1ms    |      1ms       |      1201ms       |       0ms        ]
             |                |                   |                  |
    namelookup:1ms            |                   |                  |
                        connect:2ms               |                  |
                                      starttransfer:1203ms           |
                                                                 total:1203ms
```
- FastAPI served models - the first call is generally slow here too, but subsequent calls improve.
```shell
❯ httpstat http://0.0.0.0:3000/api/entities -X POST -d '{"text":"Germany and Portugal are among nations announcing post-Christmas curbs and greater social distancing measures. Spain has reported its highest number of daily cases since the start of the pandemic and France has warned daily cases there could soon pass 100,000. French Health Minister Olivier Véran said the increase in daily infections in the country, currently at about 70,000, would be driven by the Omicron variant, which he said was likely to become the dominant variant by early January"}' -H "Content-Type:application/json"
Connected to 127.0.0.1:3000 from 127.0.0.1:55317

HTTP/1.1 200 OK
date: Thu, 23 Dec 2021 01:36:57 GMT
server: uvicorn
content-length: 68
content-type: application/json

Body stored in: /var/folders/_h/5_3695q9343ctmsgltf3grq80000gp/T/tmpliub11ip

  DNS Lookup   TCP Connection   Server Processing   Content Transfer
[     1ms    |      0ms       |      5368ms       |       0ms        ]
             |                |                   |                  |
    namelookup:1ms            |                   |                  |
                        connect:1ms               |                  |
                                      starttransfer:5369ms           |
                                                                 total:5369ms
```
```shell
❯ httpstat http://0.0.0.0:3000/api/entities -X POST -d '{"text":"Germany and Portugal are among nations announcing post-Christmas curbs and greater social distancing measures. Spain has reported its highest number of daily cases since the start of the pandemic and France has warned daily cases there could soon pass 100,000. French Health Minister Olivier Véran said the increase in daily infections in the country, currently at about 70,000, would be driven by the Omicron variant, which he said was likely to become the dominant variant by early January"}' -H "Content-Type:application/json"
Connected to 127.0.0.1:3000 from 127.0.0.1:55335

HTTP/1.1 200 OK
date: Thu, 23 Dec 2021 01:37:30 GMT
server: uvicorn
content-length: 68
content-type: application/json

Body stored in: /var/folders/_h/5_3695q9343ctmsgltf3grq80000gp/T/tmpai6n7wve

  DNS Lookup   TCP Connection   Server Processing   Content Transfer
[     1ms    |      0ms       |      1184ms       |       0ms        ]
             |                |                   |                  |
    namelookup:1ms            |                   |                  |
                        connect:1ms               |                  |
                                      starttransfer:1185ms           |
                                                                 total:1185ms
```
So the two stacks land at essentially the same warm latency (~1.2 s per request). I am sure there must be ways to improve this situation; experts, please guide me.