[AIR] [TUNE] Custom hyperparameter constraints

How severely does this issue affect your experience of using Ray?

  • High: It blocks me from completing my task.

Hi,

We are trying to use Ray Tune for Neural Architecture Search. The hyperparameters are the number of convolutional, FC, batchnorm, etc. layers, the kernel sizes, and so on; all in all we have 10-20 hyperparameters. Because the resulting models must run efficiently on mobile phones, they need to be small, so we need to be able to constrain the hyperparameters such that the overall model size stays below a certain threshold.

How can we implement the following logic?

  1. Let Ray Tune select the hyperparameters from the given hyperparameter space (this is pretty straightforward)
  2. Check with a custom-defined function whether the number of parameters of the assembled model is below the threshold, e.g.:
model = MyModel(hyperparameters)
if model.parameter_count() < threshold: ...
  • Case-A: If the model size is below the threshold → start training the model (the trial)
  • Case-B: Otherwise → stop the trial immediately without training the model

In Case-B, the following two conditions should also hold:

  1. These trials should not count toward num_samples (the number of times to sample from the hyperparameter space) that we set in TuneConfig. This is important because if Ray Tune selects hyperparameters that make the model too large in 99 out of 100 cases (num_samples=100), we would end up with only 1 trained model.
  2. The search algorithm should ignore these trials when it next samples the hyperparameter space for upcoming trials.
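The size check in step 2 only needs the parameter count, and for standard layers that count follows directly from the hyperparameters, so it can be evaluated before any model is built. Below is a minimal sketch, assuming a plain stack of 2D convolutions (each followed by batchnorm), global average pooling, and one fully connected classifier. The config keys (`in_channels`, `conv_channels`, `kernel_size`, `num_classes`) and all helper functions are hypothetical and not part of Ray Tune or the original post.

```python
def conv2d_params(in_ch: int, out_ch: int, k: int, bias: bool = True) -> int:
    # Each output channel has an in_ch x k x k kernel, plus one bias term.
    return out_ch * in_ch * k * k + (out_ch if bias else 0)


def batchnorm_params(ch: int) -> int:
    # Learnable scale (gamma) and shift (beta) per channel;
    # running mean/variance are buffers, not trainable parameters.
    return 2 * ch


def linear_params(in_f: int, out_f: int, bias: bool = True) -> int:
    # Weight matrix plus one bias per output feature.
    return in_f * out_f + (out_f if bias else 0)


def count_params(config: dict) -> int:
    # Hypothetical config layout: a list of conv widths, one shared kernel
    # size, and a classifier head. Adapt this to the real search space.
    total = 0
    in_ch = config["in_channels"]
    for out_ch in config["conv_channels"]:
        total += conv2d_params(in_ch, out_ch, config["kernel_size"])
        total += batchnorm_params(out_ch)
        in_ch = out_ch
    # Assumes global average pooling, so the FC input is the last conv width.
    total += linear_params(in_ch, config["num_classes"])
    return total


def is_small_enough(config: dict, threshold: int) -> bool:
    # Case-A if True (train), Case-B if False (discard the trial).
    return count_params(config) < threshold


config = {"in_channels": 3, "conv_channels": [16, 32], "kernel_size": 3, "num_classes": 10}
print(count_params(config))  # 448 + 32 + 4640 + 64 + 330 = 5514
```

If the model is cheap to construct, counting the parameters of the real model instead (in PyTorch, `sum(p.numel() for p in model.parameters())`) avoids having to keep a formula like this in sync with the architecture.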

Hi @pezosanta,

You’ll want to subclass the Stopper interface and use it together with TuneConfig(num_samples=-1) to achieve the desired behavior. See the TuneConfig docs on the "Tune Execution (tune.Tuner)" page (Ray 3.0.0.dev0).

  • The searcher will continuously generate model configurations, and if you set num_samples=-1, it will keep generating configs indefinitely until a stopping criterion is met.

  • The stopping criterion can be implemented in many ways, but one possibility is a Stopper that terminates all trials once a certain number of trials have completed with the desired model size.
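The counting semantics of this setup can be checked without Ray: configurations are drawn indefinitely, oversized ones are discarded without training, and only trials that actually train count toward the target. This is a plain-Python simulation, not Ray Tune code; `sample_config`, `run_until_n_valid` and the size rule are hypothetical stand-ins for the searcher, the Tune run loop, and the model-size check.

```python
import random


def sample_config(rng: random.Random) -> dict:
    # Stand-in for the searcher: draws one point from the search space.
    return {"model_size": rng.randrange(10)}


def run_until_n_valid(n_valid: int, max_size: int, seed: int = 0):
    """Keep sampling (like num_samples=-1) until n_valid trials have trained."""
    rng = random.Random(seed)
    valid, discarded = [], []
    # This loop condition plays the role of the Stopper's stop_all().
    while len(valid) < n_valid:
        config = sample_config(rng)
        if config["model_size"] >= max_size:
            # Case-B: too large, reported as NaN and not counted.
            discarded.append(config)
            continue
        # Case-A: small enough, so the trial trains and counts toward the target.
        valid.append(config)
    return valid, discarded


valid, discarded = run_until_n_valid(n_valid=10, max_size=3)
print(len(valid), len(discarded))
```

However many configurations get rejected, the run ends with exactly `n_valid` trained trials, which is the behavior requested in point 1 of the question. In a real Tune run, trials execute concurrently, so some trials may already be queued or running when the stopper fires; that is why the results are filtered after `tuner.fit()`.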

Here’s some code to get you started:

import numpy as np
import time

from ray import air, tune
from ray.air import session


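def train_fn(config):
    # Case-B: the model is too large, so throw this trial away.
    # The NaN loss marks it as invalid for the stopper and the searcher.
    if config["model_size"] >= 5:
        session.report({"loss": np.nan, "done": True})
        return

    # Case-A: do model training (placeholder loop with a dummy loss)
    loss = 1.0
    for i in range(10):
        time.sleep(1)
        loss = 1.0 / (i + 1)
        session.report({"loss": loss, "done": False})

    # Final report marks the trial as successfully completed
    session.report({"loss": loss, "done": True})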


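class NumValidTrialsStopper(tune.Stopper):
    """Stops the whole experiment once `num_trials` valid trials have finished."""

    def __init__(self, num_trials=100):
        self.valid_trials = set()
        self.num_trials = num_trials

    def __call__(self, trial_id: str, result: dict) -> bool:
        # Record trials that finished training with a real (non-NaN) loss
        if not np.isnan(result["loss"]) and result.get("done", False):
            self.valid_trials.add(trial_id)
        # Never stop an individual trial here; stop_all() ends the experiment
        return False

    def stop_all(self) -> bool:
        return len(self.valid_trials) >= self.num_trials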


# You can also use a custom searcher
tuner = tune.Tuner(
    tune.with_resources(train_fn, {"CPU": 2.0}),
    param_space={
        "grid_search_param": tune.grid_search([1, 2, 3]),
        "model_size": tune.choice(list(range(0, 10))),
    },
    # num_samples=-1 will generate infinitely many samples,
    # until a stopping criterion is met
    tune_config=tune.TuneConfig(num_samples=-1),
    run_config=air.RunConfig(stop=NumValidTrialsStopper(10)),
)
results = tuner.fit()

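# Tune will queue up more trials that don't get to run, and trials still
# running when stop_all() fires are cut off mid-training; filter both out
valid_results = [
    result
    for result in results
    if result.metrics.get("done")
    and not np.isnan(result.metrics.get("loss", np.nan))
]

assert len(valid_results) == 10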

For reference, here’s the conversation on the Ray Slack.

For anyone who is interested, this solution works for AxSearch and any other search algorithm that abandons trials reporting np.nan (or another default value) as the metric.

For now, this behavior depends on the external library. Here are the algorithms that may work the same way as AxSearch with np.nan metric reports:

@justinvyu, thanks once again.