Hi @lihost, thanks for the high-quality context and repro steps for your question. I felt this could be a pattern other folks will run into as well, so it's valuable to document it in our forum.
I don't know the FlairModel you're using very well, but I spot an opportunity to better parallelize your self.tagger.predict calls, since they seem to happen in a loop with a 64-128 batch size. Assuming a single forward pass consumes most of the time, that loop is likely contributing most of your end-to-end latency, and parallelizing & scaling is exactly what Ray is good at. This is the classic "ray.get in a loop" anti-pattern:
https://docs.ray.io/en/master/ray-design-patterns/ray-get-loop.html
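To make that anti-pattern concrete, here is a tiny self-contained sketch, using a toy Ray task in place of your model (the names are just illustrative), showing the difference between calling ray.get inside the loop and submitting all calls first:

import ray

ray.init()

@ray.remote
def predict_one(sentence):
    # stand-in for a single model forward pass
    return sentence.upper()

sentences = ["hello world", "ray can fan these out"]

# Anti-pattern: ray.get inside the loop serializes the calls
serial_results = [ray.get(predict_one.remote(s)) for s in sentences]

# Better: submit everything first, then wait on all of them at once
refs = [predict_one.remote(s) for s in sentences]
parallel_results = ray.get(refs)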
I would consider using the "Tree of Actors" pattern (see "Pattern: Tree of actors" in the Ray design patterns docs): spawn more replicas of your SequenceTagger model that simply take each input and return the response, something like:
@serve.deployment(name="FlairModel", num_replicas=10)
class FlairModel:
def __init__(self) -> None:
self.tagger = SequenceTagger.load("flair/ner-english-large")
async def __call__(self, request):
return self.tagger.predict(request, mini_batch_size=64)
Meanwhile, keep fewer instances of the tokenizer to preprocess your input and split it into multiple parallelizable calls, one per sentence:
@serve.deployment(name="Tokenizer", num_replicas=2)
class FlairTokenizer:
def __init__(self) -> None:
self.flair_thresholding = 0.80
self.tokenizer = SpaceTokenizer()
self.model_handle = FlairModel.get_handle() # Call this only after you've deployed all models and this handle is valid
async def __call__(self, request):
sent_list = sent_tokenize(request)
sentences = [Sentence(each, use_tokenizer=self.tokenizer) for each in sent_list]
refs = []
for each in sentences:
refs.append(self.model_handle.remote(each))
returns = ray.get(refs)
# Do your filtering logic afterwards
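For completeness, this is roughly how the two deployments would be wired up and called end to end (a minimal sketch against the Serve deployment API from that version; the sample text is just for illustration):

import ray
from ray import serve

ray.init()
serve.start()

# Deploy the model replicas first so the handle created in
# FlairTokenizer.__init__ points at a live deployment
FlairModel.deploy()
FlairTokenizer.deploy()

# Query the whole pipeline through the tokenizer's handle
handle = FlairTokenizer.get_handle()
annotated = ray.get(handle.remote("Barack Obama was born in Hawaii. He moved to Chicago later."))
print(annotated)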
Depending on the workload, you might also find Serve's built-in request batching (serve.batch) helpful; see the Performance Tuning page in the Ray Serve docs.
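A rough sketch of what that could look like for the FlairModel deployment above, assuming the serve.batch decorator (the max_batch_size and batch_wait_timeout_s values are placeholders to tune for your workload):

from ray import serve
from flair.models import SequenceTagger

@serve.deployment(name="FlairModel", num_replicas=10)
class FlairModel:
    def __init__(self) -> None:
        self.tagger = SequenceTagger.load("flair/ner-english-large")

    @serve.batch(max_batch_size=64, batch_wait_timeout_s=0.05)
    async def batched_predict(self, sentences):
        # `sentences` is the list of individual requests Serve collected
        # from concurrent callers; one predict() call handles the whole batch
        self.tagger.predict(sentences, mini_batch_size=64)
        return sentences

    async def __call__(self, request):
        return await self.batched_predict(request)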
There might be some syntax nits in the code above, but the point is to split sentence tokenizing and prediction into separate deployments: the replicas running predict can be horizontally scaled, while the tokenize + aggregate + filter step is neither IO- nor CPU-intensive and just fans out parallelized calls and waits on them in batch.