Bug in Ray TransformersPredictor.from_checkpoint

TL;DR: TransformersPredictor.from_checkpoint calls the Pipeline class with checkpoint_path instead of an actual model instance.

I created a custom HF Transformer Pipeline, and it is working correctly on its own.

I then tried creating a TransformersPredictor on top of it by passing a TransformersCheckpoint and the custom Pipeline class:

from ray.train.huggingface import TransformersCheckpoint, TransformersPredictor

# Loading a trained HF transformer model
checkpoint = TransformersCheckpoint.from_directory("model_base_100_pages_10_epochs_3_classes_best/")

predictor = TransformersPredictor.from_checkpoint(checkpoint, pipeline_cls=PageTypeClassificationPipeline)

But this raises an error. The traceback makes it clear that from_checkpoint passes checkpoint_path as the model argument to the Pipeline class constructor, which expects an actual model instance (check the HF source code here), not a path.
The HF Pipeline class is different from the HF pipeline() wrapper function, which can take model names as strings and construct the actual pipeline.

Traceback (most recent call last)

in <module>:1

❱ 1 TransformersPredictor.from_checkpoint(checkpoint, pipeline_cls=PageTypeClassificationPip
  2

/usr/local/lib/python3.8/site-packages/ray/train/huggingface/transformers/transformers_predictor.py:148 in from_checkpoint

  145         with checkpoint.as_directory() as checkpoint_path:
  146             # Tokenizer will be loaded automatically (no need to specify
  147             # `tokenizer=checkpoint_path`)
❱ 148             pipeline = pipeline_cls(model=checkpoint_path, **pipeline_kwargs)
  149         return cls(
  150             pipeline=pipeline,
  151             preprocessor=preprocessor,

in __init__:6

   3
   4 class PageTypeClassificationPipeline(Pipeline):
   5     def __init__(self, *args, **kwargs):
❱  6         super().__init__(*args, **kwargs)
   7         self.processor = MarkupLMProcessor.from_pretrained("microsoft/markuplm-base")
   8
   9     def _sanitize_parameters(self, top_k=None):

/usr/local/lib/python3.8/site-packages/transformers/pipelines/base.py:756 in __init__

   753         **kwargs,
   754     ):
   755         if framework is None:
❱  756             framework, model = infer_framework_load_model(model, config=model.config)
   757
   758         self.task = task
   759         self.model = model

AttributeError: 'str' object has no attribute 'config' 
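
The failure can be reproduced without Ray or transformers installed. The sketch below is only an illustration: StubConfig, StubModel, StubPipeline and try_build are hypothetical stand-ins, not part of either library, and StubPipeline copies nothing from the real Pipeline except the `model.config` access shown at base.py:756.

```python
class StubConfig:
    """Stand-in for a Hugging Face model config (hypothetical)."""
    num_labels = 3


class StubModel:
    """Stand-in for a loaded model instance; real models expose `.config`."""
    config = StubConfig()


class StubPipeline:
    """Stand-in that mimics only the `model.config` access from the traceback."""

    def __init__(self, model):
        # Same attribute access as `config=model.config` in the traceback.
        self.config = model.config
        self.model = model


def try_build(model):
    """Return the AttributeError message from building the pipeline, or None."""
    try:
        StubPipeline(model=model)
    except AttributeError as exc:
        return str(exc)
    return None


# Passing a path string fails; passing a model instance succeeds.
path_error = try_build("model_base_100_pages_10_epochs_3_classes_best/")
model_error = try_build(StubModel())
print(path_error)
print(model_error)
```

Any constructor that reads `model.config` before checking the argument type will fail this way when given a path string.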

To prove the point, I created the TransformersPredictor through its constructor instead of from_checkpoint. That worked, and I could run predictions. However, BatchPredictor forces the use of TransformersPredictor.from_checkpoint under the hood, and I can't get around it.

# This code works
from ray.train.huggingface import TransformersCheckpoint, TransformersPredictor
from transformers import MarkupLMForSequenceClassification

checkpoint = TransformersCheckpoint.from_directory("model_base_100_pages_10_epochs_3_classes_best/")
# Load the model instance first, then pass it to the custom Pipeline class
model = checkpoint.get_model(MarkupLMForSequenceClassification)
predictor = TransformersPredictor(pipeline=PageTypeClassificationPipeline(model=model))

I also created a custom MyTransformersPredictor that overrides TransformersPredictor's from_checkpoint as follows, and it worked.

Original:

        with checkpoint.as_directory() as checkpoint_path:
            # Tokenizer will be loaded automatically (no need to specify
            # `tokenizer=checkpoint_path`)
            pipeline = pipeline_cls(model=checkpoint_path, **pipeline_kwargs)
        return cls(
            pipeline=pipeline,
            preprocessor=preprocessor,
            use_gpu=use_gpu,
        )

Overridden:

        # Replaces `with checkpoint.as_directory() as checkpoint_path:`.
        # The model is loaded from the checkpoint first and passed as an
        # instance instead of a path.
        model = checkpoint.get_model(MarkupLMForSequenceClassification)
        pipeline = pipeline_cls(model=model, **pipeline_kwargs)
        return cls(
            pipeline=pipeline,
            preprocessor=preprocessor,
            use_gpu=use_gpu,
        )
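
The override hardcodes MarkupLMForSequenceClassification. One possible way to make it reusable is to take the model class as an argument. The sketch below shows only that control flow, using stand-in classes so it runs on its own: StubCheckpoint, StubModel, StubPipeline and StubPredictor are hypothetical, and `model_cls` is an assumed extra parameter, not part of the real TransformersPredictor.from_checkpoint signature.

```python
class StubCheckpoint:
    """Stand-in for a checkpoint that knows where the weights live."""

    def __init__(self, path):
        self.path = path

    def get_model(self, model_cls):
        # A real checkpoint would load weights from `self.path`; this stub
        # just instantiates the class so the control flow can be exercised.
        return model_cls()


class StubModel:
    """Stand-in for a loaded model instance."""
    config = {"num_labels": 3}


class StubPipeline:
    """Stand-in pipeline that insists on a model instance, not a path."""

    def __init__(self, model, **kwargs):
        if isinstance(model, str):
            raise TypeError("expected a model instance, got a path")
        self.model = model
        self.kwargs = kwargs


class StubPredictor:
    """Stand-in predictor whose from_checkpoint loads the model first."""

    def __init__(self, pipeline):
        self.pipeline = pipeline

    @classmethod
    def from_checkpoint(cls, checkpoint, *, pipeline_cls, model_cls, **pipeline_kwargs):
        # Load the model, then hand the instance to the pipeline class.
        model = checkpoint.get_model(model_cls)
        return cls(pipeline=pipeline_cls(model=model, **pipeline_kwargs))


predictor = StubPredictor.from_checkpoint(
    StubCheckpoint("some/dir"),
    pipeline_cls=StubPipeline,
    model_cls=StubModel,
    top_k=2,
)
```

Whether BatchPredictor would forward an extra keyword such as `model_cls` to a real subclass is not confirmed in this thread, so treat that part as an assumption to verify.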

Pipeline code for reference:

from transformers import Pipeline
from transformers import MarkupLMProcessor

class PageTypeClassificationPipeline(Pipeline):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.processor = MarkupLMProcessor.from_pretrained("microsoft/markuplm-base")

    def _sanitize_parameters(self, top_k=None):
        postprocess_params = {}
        if top_k is not None:
            postprocess_params["top_k"] = top_k
        return {}, {}, postprocess_params

    def preprocess(self, inputs):
        return self.processor(inputs, padding="max_length", max_length=512, truncation=True, return_tensors="pt")

    def _forward(self, model_inputs):
        return self.model(**model_inputs)

    def postprocess(self, model_outputs, top_k=1):
        if top_k > self.model.config.num_labels:
            top_k = self.model.config.num_labels
        # Unnest batch
        probs = model_outputs.logits[0].softmax(-1)
        scores, ids = probs.topk(top_k)
        scores = scores.tolist()
        ids = ids.tolist()

        return [{"score": score, "label": self.model.config.id2label[_id]} for score, _id in zip(scores, ids)]

Hi @MichaelAzmy,

Thanks for bringing this up! This looks like a bug, and we'll address it ASAP.

Can you use your workaround in the meantime?

@MichaelAzmy I've filed a fix here:

Could you try this out to see if it solves your problem?

Thanks for the quick fix. I checked the code and it looks good to me.
Once it lands, I will check that everything works as expected.