BatchPredictor and multi-input models

Is it possible to use multi-input models with BatchPredictor?

Consider the case where a model's forward() takes two inputs. At prediction time, the model complains that it never receives its second input, x2 in this instance.


from ray.train.batch_predictor import BatchPredictor
from ray.train.torch import TorchCheckpoint, TorchPredictor

checkpoint = TorchCheckpoint.from_model(best_ckpt['model'])
predictor = BatchPredictor.from_checkpoint(checkpoint, TorchPredictor)
predictions = predictor.predict(data=test_set, feature_columns=["x1", "x2"])

Hey @localh,

The BatchPredictor interface will be deprecated in the upcoming Ray 2.7 release. We'd recommend implementing this logic directly with the Ray Data map_batches API, where you can express your model prediction logic yourself.
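For example, a rough sketch along these lines should work. Note the toy TwoInputModel, the x1/x2 column names, and the concurrency setting are placeholders for your own model, dataset schema, and cluster size:

import ray
import torch
import torch.nn as nn


class TwoInputModel(nn.Module):
    # Toy stand-in for a model whose forward() takes two inputs.
    def forward(self, x1, x2):
        return x1 + x2


class TwoInputPredictor:
    def __init__(self):
        # Load the model once per actor; in practice, restore your trained
        # weights here (e.g. from best_ckpt).
        self.model = TwoInputModel()
        self.model.eval()

    def __call__(self, batch: dict) -> dict:
        # With batch_format="numpy", `batch` is a dict of NumPy arrays keyed
        # by column name, so both inputs can be pulled out and passed to
        # forward() explicitly.
        x1 = torch.as_tensor(batch["x1"], dtype=torch.float32)
        x2 = torch.as_tensor(batch["x2"], dtype=torch.float32)
        with torch.no_grad():
            preds = self.model(x1, x2)
        batch["prediction"] = preds.numpy()
        return batch


# Replace from_items(...) with your actual test_set dataset.
ds = ray.data.from_items([{"x1": 1.0, "x2": 2.0}, {"x1": 3.0, "x2": 4.0}])
predictions = ds.map_batches(
    TwoInputPredictor,
    batch_format="numpy",
    concurrency=2,  # number of model replicas (older Ray versions use compute=ActorPoolStrategy instead)
)
print(predictions.take_all())

Because the callable class loads the model in __init__ and applies it per batch in __call__, you have full control over how many tensors get passed into forward().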

Can you take a look at End-to-end: Offline Batch Inference — Ray 3.0.0.dev0 and see if this helps?
