Ray Serve: Ray Serve vs. Regular Web Server Performance?

Context

Our production ML model serving platform currently serves ~5 models (TensorFlow, PyTorch, scikit-learn, etc.). While adding a couple more models we are hitting serious performance degradation. The bottleneck is one particular model from Flair [flair/ner-english-large · Hugging Face], specifically during its tagger.predict(sentence) call.

After simulating exactly how Ray would be used within our present system, I am somehow not seeing much performance improvement.

I could definitely use some expert suggestions on this.

Ray Served Model - I have two files: one for model serving and another as the service interface and model consumer.

  1. FlairModel.py - Model serving
# FlairModel.py
import ray
from ray import serve
from flair.models import SequenceTagger
from flair.data import Sentence
from flair.tokenization import SpaceTokenizer
from nltk.tokenize import sent_tokenize


ray.init(address='auto', namespace="serve", ignore_reinit_error=True)
serve.start(detached=True)


@serve.deployment(name="FlairModel", num_replicas=1)
class FlairModel:
    def __init__(self) -> None:
        self.tagger = SequenceTagger.load("flair/ner-english-large")
        self.flair_thresholding = 0.80
        self.tokenizer = SpaceTokenizer()


    async def __call__(self, request):
        sent_list = sent_tokenize(request)
        sentences = [Sentence(each, use_tokenizer=self.tokenizer) for each in sent_list]

        [self.tagger.predict(each, mini_batch_size=64) for each in sentences]

        resp = []
        for each in sentences:
            for entity in each.get_spans('ner'):
                if entity.tag == 'LOC' and entity.score > self.flair_thresholding:
                    resp.append(entity.text)

        return {"predictions": list(set(resp))}


FlairModel.deploy()




  1. deploy_ray.py - Ray Serve + FastAPI web server consuming the deployed models.

# deploy_ray.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import ray
from ray import serve


app = FastAPI()
ray.init(address="ray://127.0.0.1:10001", ignore_reinit_error=True)
serve.start(detached=True)



class InRequest(BaseModel):
    text: str


@app.get('/')
def api_root():
    return 'Welcome !!!'


@serve.deployment(route_prefix="/")
@serve.ingress(app)
class APIWrapper:
    pass

@serve.deployment(route_prefix='/api')
@serve.ingress(app)
class InspectFlair:
    @app.post('/entities')
    def inspect(self, request: InRequest):
        if len(request.text) == 0:
            raise HTTPException(status_code=400, detail="'text' not found!")
        try:

            data = request.text
            flair_handle = serve.get_deployment("FlairModel").get_handle()
            future = flair_handle.remote(data)
            response = ray.get(future)
            return {'response': response}
        except ValueError:
            # Raise instead of returning a (body, status) tuple, which FastAPI would not treat as a 400
            raise HTTPException(status_code=400, detail='Got Err')


APIWrapper.deploy()
InspectFlair.deploy()

  1. The Ray cluster is on my local machine, started with ray start --head



Regular FastAPI Served Model

  1. FlairModel.py - This code is exactly the same as the one used in the Ray Serve example, the only change being that the model is preloaded in a startup event.
#FlairModel.py - for FastAPI demo
from flair.data import Sentence
from flair.tokenization import SpaceTokenizer
from nltk.tokenize import sent_tokenize


class FlairModel:
    def __init__(self, tagger) -> None:
        self.tagger = tagger
        self.flair_thresholding = 0.80
        self.tokenizer = SpaceTokenizer()


    def predict(self, request):
        sent_list = sent_tokenize(request)
        sentences = [Sentence(each, use_tokenizer=self.tokenizer) for each in sent_list]
        [self.tagger.predict(each, mini_batch_size=128) for each in sentences]

        resp = []
        for each in sentences:
            for entity in each.get_spans('ner'):
                if entity.tag == 'LOC' and entity.score > self.flair_thresholding:
                    resp.append(entity.text)
        return {"predictions": list(set(resp))}




  1. fastapi_demo.py - This consumes individual Model instances.
# fastapi_demo.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from FlairModel import FlairModel
from flair.models import SequenceTagger

app = FastAPI()
models = {}


class InRequest(BaseModel):
    text: str


@app.get('/')
def api_root():
    return 'Welcome !!!'


@app.on_event("startup")
def load_flair_model():
    models['tagger'] = SequenceTagger.load("flair/ner-english-large")


@app.post('/api/entities')
def inspect(request: InRequest):
    if len(request.text) == 0:
        raise HTTPException(status_code=400, detail="'text' not found!")
    try:
        data = request.text
        flair_handle = FlairModel(models['tagger'])
        response = flair_handle.predict(data)
        return {'response': response}
    except ValueError:
        # Raise instead of returning a (body, status) tuple, which FastAPI would not treat as a 400
        raise HTTPException(status_code=400, detail='Got Err')


def app_init():
    return app

  1. FastAPI is deployed with gunicorn-uvicorn worker processes.
    gunicorn --workers=1 --worker-class=uvicorn.workers.UvicornWorker --bind 0.0.0.0:3000 'fastapi_demo:app_init()' --timeout 1000








Performance Check - For a quick evaluation I am using a test sentence with a character length of 491 (I also performed some more elaborate load tests with multiple sentences randomly selected in each call).
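Those load tests were roughly of this shape - a simplified sketch, not the exact harness; the endpoint URL, sentence pool, and concurrency numbers here are only placeholders:

# loadtest_sketch.py - simplified concurrent load test (placeholders, not the real harness)
import random
import time
from concurrent.futures import ThreadPoolExecutor

import requests

URL = "http://0.0.0.0:8000/api/entities"  # or :3000 for the plain FastAPI server
SENTENCES = [
    "Germany and Portugal are among nations announcing post-Christmas curbs.",
    "Spain has reported its highest number of daily cases.",
]  # placeholder sentence pool


def one_call(_):
    # Build a request from a random selection of sentences and time the round trip
    text = " ".join(random.choices(SENTENCES, k=random.randint(1, 5)))
    start = time.perf_counter()
    resp = requests.post(URL, json={"text": text})
    resp.raise_for_status()
    return time.perf_counter() - start


with ThreadPoolExecutor(max_workers=8) as pool:
    latencies = list(pool.map(one_call, range(50)))

print(f"mean: {sum(latencies) / len(latencies):.3f}s, max: {max(latencies):.3f}s")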

I ran regular curl-style calls via the httpstat utility to show the per-request timing.

  1. Ray-served model - The first call is generally slow, but subsequent calls improve.
❯ httpstat http://0.0.0.0:8000/api/entities -X POST -d '{"text":"Germany and Portugal are among nations announcing post-Christmas curbs and greater social distancing measures. Spain has reported its highest number of daily cases since the start of the pandemic and France has warned daily cases there could soon pass 100,000. French Health Minister Olivier Véran said the increase in daily infections in the country, currently at about 70,000, would be driven by the Omicron variant, which he said was likely to become the dominant variant by early January"}' -H "Content-Type:application/json"
Connected to 127.0.0.1:8000 from 127.0.0.1:54065

HTTP/1.1 200 OK
date: Thu, 23 Dec 2021 01:07:18 GMT
server: uvicorn
content-length: 68
content-type: application/json

Body stored in: /var/folders/_h/5_3695q9343ctmsgltf3grq80000gp/T/tmpkm791wr1

  DNS Lookup   TCP Connection   Server Processing   Content Transfer
[     1ms    |       0ms      |      5675ms       |        0ms       ]
             |                |                   |                  |
    namelookup:1ms            |                   |                  |
                        connect:1ms               |                  |
                                      starttransfer:5676ms           |
                                                                 total:5676ms


❯ httpstat http://0.0.0.0:8000/api/entities -X POST -d '{"text":"Germany and Portugal are among nations announcing post-Christmas curbs and greater social distancing measures. Spain has reported its highest number of daily cases since the start of the pandemic and France has warned daily cases there could soon pass 100,000. French Health Minister Olivier Véran said the increase in daily infections in the country, currently at about 70,000, would be driven by the Omicron variant, which he said was likely to become the dominant variant by early January"}' -H "Content-Type:application/json"
Connected to 127.0.0.1:8000 from 127.0.0.1:54071

HTTP/1.1 200 OK
date: Thu, 23 Dec 2021 01:07:29 GMT
server: uvicorn
content-length: 68
content-type: application/json

Body stored in: /var/folders/_h/5_3695q9343ctmsgltf3grq80000gp/T/tmpxc57mp7_

  DNS Lookup   TCP Connection   Server Processing   Content Transfer
[     1ms    |       1ms      |      1201ms       |        0ms       ]
             |                |                   |                  |
    namelookup:1ms            |                   |                  |
                        connect:2ms               |                  |
                                      starttransfer:1203ms           |
                                                                 total:1203ms



  1. FastAPI-served model - The first call is generally slow here too, but subsequent calls improve.
❯ httpstat http://0.0.0.0:3000/api/entities -X POST -d '{"text":"Germany and Portugal are among nations announcing post-Christmas curbs and greater social distancing measures. Spain has reported its highest number of daily cases since the start of the pandemic and France has warned daily cases there could soon pass 100,000. French Health Minister Olivier Véran said the increase in daily infections in the country, currently at about 70,000, would be driven by the Omicron variant, which he said was likely to become the dominant variant by early January"}' -H "Content-Type:application/json"
Connected to 127.0.0.1:3000 from 127.0.0.1:55317

HTTP/1.1 200 OK
date: Thu, 23 Dec 2021 01:36:57 GMT
server: uvicorn
content-length: 68
content-type: application/json

Body stored in: /var/folders/_h/5_3695q9343ctmsgltf3grq80000gp/T/tmpliub11ip

  DNS Lookup   TCP Connection   Server Processing   Content Transfer
[     1ms    |       0ms      |      5368ms       |        0ms       ]
             |                |                   |                  |
    namelookup:1ms            |                   |                  |
                        connect:1ms               |                  |
                                      starttransfer:5369ms           |
                                                                 total:5369ms

❯ httpstat http://0.0.0.0:3000/api/entities -X POST -d '{"text":"Germany and Portugal are among nations announcing post-Christmas curbs and greater social distancing measures. Spain has reported its highest number of daily cases since the start of the pandemic and France has warned daily cases there could soon pass 100,000. French Health Minister Olivier Véran said the increase in daily infections in the country, currently at about 70,000, would be driven by the Omicron variant, which he said was likely to become the dominant variant by early January"}' -H "Content-Type:application/json"
Connected to 127.0.0.1:3000 from 127.0.0.1:55335

HTTP/1.1 200 OK
date: Thu, 23 Dec 2021 01:37:30 GMT
server: uvicorn
content-length: 68
content-type: application/json

Body stored in: /var/folders/_h/5_3695q9343ctmsgltf3grq80000gp/T/tmpai6n7wve

  DNS Lookup   TCP Connection   Server Processing   Content Transfer
[     1ms    |       0ms      |      1184ms       |        0ms       ]
             |                |                   |                  |
    namelookup:1ms            |                   |                  |
                        connect:1ms               |                  |
                                      starttransfer:1185ms           |
                                                                 total:1185ms


I am sure there must be ways to improve this situation. Experts, please guide me in improving it.

Hi @lihost, thanks for the high-quality context and repro steps for your question :slight_smile: I felt this could be a pattern other folks will run into as well, and thus valuable to document in our forum.

I don’t know the FlairModel you’re using very well, but I spot an opportunity: you can try to better parallelize your self.tagger.predict calls, since they seem to happen in a loop with a 64-128 mini-batch size. Assuming a single forward pass consumes most of the time, that loop is likely contributing most of your end-to-end latency, and Ray is really good at parallelizing and scaling this kind of work.

https://docs.ray.io/en/master/ray-design-patterns/ray-get-loop.html
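For context, that page describes the shape of the issue in plain Ray terms - a generic sketch with a toy do_work task (my placeholder, not your model code), separate from the Serve snippets below:

import ray

ray.init(ignore_reinit_error=True)


@ray.remote
def do_work(item):
    # Stand-in for a per-sentence predict call
    return item * 2


items = list(range(8))

# Anti-pattern: calling ray.get inside the loop blocks on each task, so nothing runs in parallel
results_serial = [ray.get(do_work.remote(i)) for i in items]

# Better: launch all tasks first, then ray.get the whole batch of references at once
refs = [do_work.remote(i) for i in items]
results_parallel = ray.get(refs)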

I would consider using the “Tree of Actors” pattern (Pattern: Tree of actors — Ray v2.0.0.dev0) to spawn more replicas of your SequenceTagger model that simply take each input and return a response, something like:

@serve.deployment(name="FlairModel", num_replicas=10)
class FlairModel:
    def __init__(self) -> None:
        self.tagger = SequenceTagger.load("flair/ner-english-large")

    async def __call__(self, request):
        # predict() annotates the Sentence in place, so return the annotated input
        self.tagger.predict(request, mini_batch_size=64)
        return request

While keeping fewer instances of the tokenizer to preprocess your input and split it into multiple, parallelizable calls, one per sentence:

@serve.deployment(name="Tokenizer", num_replicas=2)
class FlairTokenizer:
    def __init__(self) -> None:
        self.flair_thresholding = 0.80
        self.tokenizer = SpaceTokenizer()
        # Call this only after you've deployed all models and this handle is valid
        self.model_handle = FlairModel.get_handle()

    async def __call__(self, request):
        sent_list = sent_tokenize(request)
        sentences = [Sentence(each, use_tokenizer=self.tokenizer) for each in sent_list]

        # Fan out one call per sentence ...
        refs = []
        for each in sentences:
            refs.append(self.model_handle.remote(each))

        # ... then wait for all of them at once
        returns = ray.get(refs)
        # Do your filtering logic afterwards

Depending on the workload, you might also find serve batching helpful: Performance Tuning — Ray v2.0.0.dev0
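As a rough illustration of what that could look like on the model deployment - the decorator values are just illustrative, and callers would still send one Sentence at a time:

@serve.deployment(name="FlairModel", num_replicas=2)
class FlairModel:
    def __init__(self) -> None:
        self.tagger = SequenceTagger.load("flair/ner-english-large")

    # Groups concurrent single-sentence calls into one forward pass
    @serve.batch(max_batch_size=32, batch_wait_timeout_s=0.05)
    async def batched_predict(self, sentences):
        # `sentences` is the list of individual Sentence objects sent by concurrent callers
        self.tagger.predict(sentences, mini_batch_size=32)
        return sentences  # one annotated Sentence is returned to each caller

    async def __call__(self, request):
        return await self.batched_predict(request)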

There might be some syntax nits in the code above, but the point is to split sentence tokenizing and prediction into separate deployments: the replicas running predict can be horizontally scaled, while tokenize + aggregate + filter can be treated as non-I/O, comparatively light work that just sends the parallelized calls in a batch and waits on them.
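To tie it back to your deploy_ray.py, the ingress endpoint would then forward to the Tokenizer deployment instead of calling the model directly - again just a sketch, reusing the names from the snippets above:

@serve.deployment(route_prefix='/api')
@serve.ingress(app)
class InspectFlair:
    def __init__(self) -> None:
        # Valid once the Tokenizer deployment has been deployed
        self.tokenizer_handle = serve.get_deployment("Tokenizer").get_handle()

    @app.post('/entities')
    def inspect(self, request: InRequest):
        if len(request.text) == 0:
            raise HTTPException(status_code=400, detail="'text' not found!")
        ref = self.tokenizer_handle.remote(request.text)
        # Assumes FlairTokenizer returns the filtered predictions
        return {'response': ray.get(ref)}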

Hi @jiaodong, thank you for such a succinct explanation. This indeed serves the purpose.