Ray Serve changes behavior of Flair Predictions

  1. Flair without Ray
from flair.data import Sentence
from flair.models import SequenceTagger

class FlairModel:
    """Thin callable wrapper around a Flair sequence tagger."""

    def __init__(self):
        # Load the tagger at construction time so each call can reuse it.
        ner_tagger = SequenceTagger.load("flair/ner-english-large")
        self.tagger = ner_tagger

    def __call__(self, request):
        """Run the tagger on ``request`` and return whatever ``predict`` returns."""
        prediction = self.tagger.predict(request)
        return prediction
    

# Example input for the tagger.
text = 'Germany and Portugal are among nations announcing post-Christmas curbs and greater social distancing measures.'

# Build the Sentence, run the model on it in this process, then print the
# Sentence itself (``each``) rather than the call's return value (``resp``).
each = Sentence(text, use_tokenizer=True)
fm = FlairModel()
resp = fm(each)
print(each)

Output -

Sentence: "Germany and Portugal are among nations announcing post-Christmas curbs and greater social distancing measures ."   
[− Tokens: 15  − Token-Labels: "Germany <S-LOC> and Portugal <S-LOC> are among nations announcing post-Christmas <S-MISC> curbs and greater social distancing measures ."]



  2. Flair with Ray
from flair.data import Sentence
from flair.models import SequenceTagger
import ray
from ray import serve

# Initialise Ray with address='auto' in the "serve" namespace.
# NOTE(review): presumably this attaches to an already-running cluster, and
# ignore_reinit_error=True makes a repeated init harmless — confirm against the Ray docs.
ray.init(address='auto', namespace="serve", ignore_reinit_error=True)
# Start Ray Serve.
# NOTE(review): detached=True presumably keeps Serve alive after this script exits — confirm.
serve.start(detached=True)



@serve.deployment(name="FlairModel", num_replicas=1)
class FlairModel:
    """Ray Serve deployment (one replica) wrapping a Flair sequence tagger."""

    def __init__(self):
        # Load the tagger at construction time so each call can reuse it.
        ner_tagger = SequenceTagger.load("flair/ner-english-large")
        self.tagger = ner_tagger

    def __call__(self, request):
        """Run the tagger on ``request`` and return whatever ``predict`` returns."""
        # NOTE(review): the caller receives only predict()'s return value,
        # not ``request`` itself.
        prediction = self.tagger.predict(request)
        return prediction
    
# Deploy the FlairModel deployment declared above.
FlairModel.deploy()


# Obtain a handle used to send requests to the deployment.
flair_model = FlairModel.get_handle()

text = 'Germany and Portugal are among nations announcing post-Christmas curbs and greater social distancing measures.'

# Build the Sentence locally, send it through the handle, and wait for the result.
each = Sentence(text, use_tokenizer=True)
fut = flair_model.remote(each)
ret = ray.get(fut)
# Prints the local Sentence (``each``), not the value returned by the deployment (``ret``).
print(each)

Output -

Sentence: "Germany and Portugal are among nations announcing post-Christmas curbs and greater social distancing measures ."   
[− Tokens: 15]



We can see that the **NER labels** are missing when the prediction goes through the Ray Serve deployment of the Flair model. Please let me know if I am missing something here.

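Got it resolved; it was a trivial one. `tagger.predict()` attaches the labels to the `Sentence` it is given instead of returning them, and with Ray Serve that `Sentence` is a copy inside the replica process, so the caller's local `each` is never updated. Returning the annotated `request` from `__call__` fixes it; the labels are then on the object returned by `ray.get(fut)` (`ret`), not on the local `each`: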

@serve.deployment(name="FlairModel", num_replicas=1)
class FlairModel:
    """Ray Serve deployment (one replica) that tags a request and returns it."""

    def __init__(self):
        # Load the tagger at construction time so each call can reuse it.
        ner_tagger = SequenceTagger.load("flair/ner-english-large")
        self.tagger = ner_tagger

    async def __call__(self, request):
        """Run the tagger on ``request`` and return that same ``request`` object."""
        # predict()'s own return value is discarded; the caller gets ``request`` back.
        self.tagger.predict(request)
        return request


# Deploy the FlairModel deployment declared above.
FlairModel.deploy()