How to get predicted class probabilities when using SklearnTrainer

I’m trying to train a RandomForestClassifier in Ray with the following code and then run inference:

import ray
from ray.air.config import ScalingConfig
from ray.train.sklearn import SklearnTrainer
from sklearn.ensemble import RandomForestClassifier
from ray.train.batch_predictor import BatchPredictor
from ray.train.sklearn import SklearnPredictor

# Read the breast-cancer example CSV into a Ray dataset.
dataset = ray.data.read_csv("s3://anonymous@air-example-data/breast_cancer.csv")

# Hold out 30% of the rows; the rest is used for training.
train_dataset, valid_dataset = dataset.train_test_split(test_size=0.3)

# The test set is the validation set with the label column removed.
test_dataset = valid_dataset.drop_columns(cols=["target"])

# Scale two of the feature columns before they reach the estimator.
from ray.data.preprocessors import StandardScaler

preprocessor = StandardScaler(columns=["mean radius", "mean texture"])

# Build each trainer argument separately so the constructor call stays short.
estimator = RandomForestClassifier()
datasets = {"train": train_dataset, "valid": valid_dataset}
scaling_config = ScalingConfig(trainer_resources={"CPU": 1})

trainer = SklearnTrainer(
    estimator=estimator,
    label_column="target",
    datasets=datasets,
    preprocessor=preprocessor,
    cv=5,
    scaling_config=scaling_config,
)

result = trainer.fit()

# Rebuild a predictor from the training checkpoint and run it over the test set.
batch_predictor = BatchPredictor.from_checkpoint(result.checkpoint, SklearnPredictor)

predicted_output = batch_predictor.predict(test_dataset)
predicted_output.show()

But when I call batch_predictor.predict(test_dataset), it outputs the predicted class, like this:

{'predictions': 1}
{'predictions': 1}
{'predictions': 0}
{'predictions': 1}

I wonder whether there is any way to get class probabilities, like predict_proba(X) in sklearn?

Hi @Xinbei_G, thanks for the question!

One naive way of doing this would be to override SklearnPredictor such that predict (or more specifically _predict_pandas) calls predict_proba instead of predict.

import ray
from ray.air.config import ScalingConfig
from ray.train.sklearn import SklearnTrainer
from sklearn.ensemble import RandomForestClassifier
from ray.train.batch_predictor import BatchPredictor
from ray.train.sklearn import SklearnPredictor


from typing import List, Optional, Union
import pandas as pd
from joblib import parallel_backend
from ray.air.constants import TENSOR_COLUMN_NAME
from ray.air.util.data_batch_conversion import _unwrap_ndarray_object_type_if_needed
from ray.train.sklearn._sklearn_utils import _set_cpu_params
from ray.util.joblib import register_ray

class SklearnProbabilityPredictor(SklearnPredictor):
    """SklearnPredictor that reports ``predict_proba`` output instead of labels.

    Only ``_predict_pandas`` is overridden; everything else is inherited from
    ``SklearnPredictor``.
    """

    def _predict_pandas(
        self,
        data: "pd.DataFrame",
        feature_columns: Optional[Union[List[str], List[int]]] = None,
        num_estimator_cpus: Optional[int] = 1,
        **predict_kwargs,
    ) -> "pd.DataFrame":
        """Run ``self.estimator.predict_proba`` on a pandas batch.

        Args:
            data: Batch to score. If it contains ``TENSOR_COLUMN_NAME``, that
                column is converted to a NumPy array and used as the input;
                otherwise the DataFrame itself is used.
            feature_columns: Optional subset of columns to feed the estimator.
                Applied as a column index on the array in the tensor case and
                as a DataFrame column selection otherwise.
            num_estimator_cpus: Value passed to ``_set_cpu_params`` (when
                truthy) and used as ``n_jobs`` for the joblib backend.
            **predict_kwargs: Forwarded verbatim to ``predict_proba``.

        Returns:
            A DataFrame with one column per ``predict_proba`` output column.
            A single column is named ``"predictions"``; otherwise the columns
            are ``"predictions_0"``, ``"predictions_1"``, ... in output order
            (presumably the order of ``estimator.classes_`` -- verify for
            your estimator).
        """
        # Must happen before the "ray" joblib backend is selected below.
        register_ray()

        if num_estimator_cpus:
            _set_cpu_params(self.estimator, num_estimator_cpus)

        # Work out what to hand to the estimator.
        if TENSOR_COLUMN_NAME in data:
            features = _unwrap_ndarray_object_type_if_needed(
                data[TENSOR_COLUMN_NAME].to_numpy()
            )
            if feature_columns:
                features = features[:, feature_columns]
        elif feature_columns:
            features = data[feature_columns]
        else:
            features = data

        with parallel_backend("ray", n_jobs=num_estimator_cpus):
            probabilities = self.estimator.predict_proba(features, **predict_kwargs)

        result = pd.DataFrame(probabilities)
        n_outputs = len(result.columns)
        if n_outputs == 1:
            result.columns = ["predictions"]
        else:
            result.columns = [f"predictions_{idx}" for idx in range(n_outputs)]
        return result


# Read the breast-cancer example CSV into a Ray dataset.
dataset = ray.data.read_csv("s3://anonymous@air-example-data/breast_cancer.csv")

# Hold out 30% of the rows; the rest is used for training.
train_dataset, valid_dataset = dataset.train_test_split(test_size=0.3)

# The test set is the validation set with the label column removed.
test_dataset = valid_dataset.drop_columns(cols=["target"])

# Scale two of the feature columns before they reach the estimator.
from ray.data.preprocessors import StandardScaler

preprocessor = StandardScaler(columns=["mean radius", "mean texture"])

# Build each trainer argument separately so the constructor call stays short.
estimator = RandomForestClassifier()
datasets = {"train": train_dataset, "valid": valid_dataset}
scaling_config = ScalingConfig(trainer_resources={"CPU": 1})

trainer = SklearnTrainer(
    estimator=estimator,
    label_column="target",
    datasets=datasets,
    preprocessor=preprocessor,
    cv=5,
    scaling_config=scaling_config,
)

result = trainer.fit()

# Use the probability-returning predictor rather than the stock SklearnPredictor.
batch_predictor = BatchPredictor.from_checkpoint(
    result.checkpoint, SklearnProbabilityPredictor
)

predicted_output = batch_predictor.predict(test_dataset)
predicted_output.show()

This would then output:

{'predictions_0': 0.47, 'predictions_1': 0.53}
{'predictions_0': 0.43, 'predictions_1': 0.57}
{'predictions_0': 0.01, 'predictions_1': 0.99}
{'predictions_0': 0.05, 'predictions_1': 0.95}
{'predictions_0': 1.0, 'predictions_1': 0.0}

Note that this was only run against the example provided, so it may need some additional post-processing of the returned values to be comprehensive!

cc @Yard1 for thoughts on how we can extend the API to support this.

I think we can either make it easier to subclass, or allow users to specify the name of the method used for prediction in the constructor (which is a pattern found in sklearn itself). I think the former would be more extensible, but I’d be happy with either.

Thanks for the solution! It would be great if you could extend the API to support this.