Passing fixed arguments to the objective function

Hi,
My code is as follows:


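from ray import tune
from ray.tune.search.bayesopt import BayesOptSearch

# train must be defined before it is handed to Tune
def train(config):
    ...  # data loading and training; must report 'val_loss' back to Tune

# search_space is defined elsewhere, e.g. {"lr": tune.uniform(1e-4, 1e-1)}
algo = BayesOptSearch(utility_kwargs={'kind': 'ucb', 'kappa': 2.5, 'xi': 0.0})
num_samples = 5

# tune_config and param_space are arguments of tune.Tuner, not tune.run
tuner = tune.Tuner(
    train,
    tune_config=tune.TuneConfig(
        metric='val_loss',
        mode='min',
        search_alg=algo,
        num_samples=num_samples,
    ),
    param_space=search_space,
)
results = tuner.fit()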

I also want to pass my data location to the function ‘train’, because it contains the data loading logic. How do I do that?
I tried the following, but it didn’t work:

from functools import partial

def train(config, data_dir=None):
    ...  # load data from data_dir, then train

tuner = tune.Tuner(
    # "<data location>" stands for the actual data directory
    partial(train, data_dir="<data location>"),
    tune_config=tune.TuneConfig(
        metric='val_loss',
        mode='min',
        search_alg=algo,
        num_samples=num_samples,
    ),
    param_space=search_space,
)
results = tuner.fit()

I found the solution.
I simply changed

def train(config, data_dir=None)

to just

def train(config, data_dir)
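For reference, the binding that functools.partial performs can be checked in plain Python, without Ray. In this minimal sketch, train is a hypothetical stand-in that just echoes its inputs, and the path is made up:

```python
from functools import partial

def train(config, data_dir):
    # Hypothetical stand-in for the real training function:
    # it only returns what it received, so the binding is visible.
    return {"lr": config["lr"], "data_dir": data_dir}

# Bind the fixed argument once; the resulting callable needs only `config`.
train_fn = partial(train, data_dir="/tmp/data")  # made-up path

print(train_fn({"lr": 0.01}))  # prints {'lr': 0.01, 'data_dir': '/tmp/data'}
```

For a function trainable, Tune then only has to supply the config dict for each trial.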

Hey circa,

Thanks for posting to the forum!

Looks like you already figured it out, but just wanted to add two other solutions:

  1. Use tune.with_parameters

def train(config, data_dir=None):
    ...  # load data from data_dir, then train

tuner = tune.Tuner(
    tune.with_parameters(train, data_dir=...),
    tune_config=tune.TuneConfig(
        metric='val_loss',
        mode='min',
        search_alg=algo,
        num_samples=num_samples,
    ),
    param_space=search_space,
)
results = tuner.fit()
  2. Use param_space
search_space["data_dir"] = ...  # fixed value, passed unchanged to every trial's config

def train(config):
    data_dir = config["data_dir"]
    ...  # load data from data_dir, then train

tuner = tune.Tuner(
    train,
    tune_config=tune.TuneConfig(
        metric='val_loss',
        mode='min',
        search_alg=algo,
        num_samples=num_samples,
    ),
    param_space=search_space,
)
results = tuner.fit()
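The two suggestions differ mainly in where data_dir shows up inside train: as a separate keyword argument (option 1) or as an entry of config (option 2). The plain-Python sketch below simulates a few trials for each, without Ray. with_fixed_kwargs is a hypothetical helper that mimics only the calling convention of tune.with_parameters; it is not Ray's implementation (the Ray docs note that the real function also stores the passed values in the Ray object store, which helps with large objects such as datasets). The random lr values stand in for the search algorithm, and the path is made up.

```python
import functools
import random

DATA_DIR = "/tmp/data"  # made-up path for illustration

def with_fixed_kwargs(trainable, **fixed):
    # Hypothetical stand-in for tune.with_parameters: mimics only the
    # calling convention trainable(config, **fixed), nothing else.
    @functools.wraps(trainable)
    def wrapped(config):
        return trainable(config, **fixed)
    return wrapped

def train_option1(config, data_dir=None):
    # Option 1: data_dir arrives as a keyword argument.
    return data_dir, config["lr"]

def train_option2(config):
    # Option 2: data_dir arrives inside config.
    return config["data_dir"], config["lr"]

# Stand-in for the search algorithm: sample only the tunable value.
rng = random.Random(0)
sampled = [{"lr": rng.uniform(1e-4, 1e-1)} for _ in range(3)]

# Option 1: bind data_dir once, then call with each sampled config.
trainable = with_fixed_kwargs(train_option1, data_dir=DATA_DIR)
option1 = [trainable(c) for c in sampled]

# Option 2: the constant is merged into every trial's config.
option2 = [train_option2({**c, "data_dir": DATA_DIR}) for c in sampled]
```

Both options give every trial the same data_dir. One practical difference: a value placed in param_space becomes part of each trial's config, so it is recorded alongside the sampled hyperparameters.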