Is it possible to use multi-input models with BatchPredictor?
Consider a model whose forward() takes two inputs. When I run the code below, the model complains that it never receives its second input (x2 in this case).
from ray.train.batch_predictor import BatchPredictor
from ray.train.torch import TorchCheckpoint, TorchPredictor

# best_ckpt and test_set come from earlier training code (not shown).
checkpoint = TorchCheckpoint.from_model(best_ckpt["model"])
predictor = BatchPredictor.from_checkpoint(checkpoint, TorchPredictor)

# Fails: the model's forward() never receives its second input, x2.
predictions = predictor.predict(data=test_set, feature_columns=["x1", "x2"])
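One possible workaround, sketched under an assumption: if your Ray version's TorchPredictor turns a batch with several feature columns into a dict of tensors keyed by column name, and its default call_model passes that whole dict to the model as a single argument, that would explain why x2 is reported missing. In that case, subclassing the predictor and overriding call_model to unpack the dict into separate arguments should help. The sketch below imports neither Ray nor PyTorch; StandInPredictor, TwoInputModel, MultiInputPredictor and input_order are hypothetical names that only mimic the assumed behavior.

```python
from typing import Callable, Dict, List, Sequence


class StandInPredictor:
    """Hypothetical stand-in for the assumed default predictor behavior (not Ray code):
    the whole batch, a dict keyed by column name, is passed to the model as ONE argument."""

    def __init__(self, model: Callable):
        self.model = model

    def call_model(self, tensor):
        return self.model(tensor)


class TwoInputModel:
    """Toy model whose forward() needs two separate inputs, like the model in the question."""

    def forward(self, x1: Sequence[float], x2: Sequence[float]) -> List[float]:
        return [a + b for a, b in zip(x1, x2)]

    # Mimic torch.nn.Module, where calling the model runs forward().
    __call__ = forward


class MultiInputPredictor(StandInPredictor):
    """Overrides call_model to unpack the column dict into positional arguments."""

    # Hypothetical: must match the order of forward()'s parameters.
    input_order = ("x1", "x2")

    def call_model(self, tensor: Dict[str, Sequence[float]]):
        return self.model(*(tensor[name] for name in self.input_order))


batch = {"x1": [1.0, 2.0], "x2": [10.0, 20.0]}

try:
    StandInPredictor(TwoInputModel()).call_model(batch)
except TypeError as err:
    # The dict lands in x1, so forward() is missing x2.
    print("default:", err)

print("override:", MultiInputPredictor(TwoInputModel()).call_model(batch))
# prints "override: [11.0, 22.0]"
```

With Ray itself, the corresponding (untested) change would be to subclass TorchPredictor, override call_model in the same way, and pass that subclass to BatchPredictor.from_checkpoint instead of TorchPredictor. An explicit input_order is used rather than self.model(**tensor) so the code keeps working even if the column names differ from forward()'s parameter names.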