I have a tuning task that reads a parameter, parsed with argparse in an imported module, inside the trainable function. The task crashes, complaining that the argument is not provided. It works fine if I use the parameter outside the trainable function. Any help is appreciated. The imported script is called `input_param.py`:
```python
import sys, argparse
parser = argparse.ArgumentParser(description='')
parser.add_argument('--ttt', type=int, required=True, help='anything > 1')
args = parser.parse_args()
ttt = args.ttt
```
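Because `parse_args()` sits at module top level, it runs every time `input_param` is imported, and with no explicit argument list argparse reads the `sys.argv` of whichever process does the importing. A minimal sketch of a side-effect-free variant, assuming the module is only needed to supply `ttt`, moves the parsing into a function. `build_parser` and `parse_ttt` are hypothetical names, not part of the original script.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Same --ttt definition as input_param.py; nothing is parsed at import time.
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("--ttt", type=int, required=True, help="anything > 1")
    return parser


def parse_ttt(argv=None) -> int:
    # With argv=None, argparse reads sys.argv[1:] of the process that calls this;
    # passing an explicit list makes the result independent of that process.
    return build_parser().parse_args(argv).ttt


if __name__ == "__main__":
    # Only runs when the file is executed directly, never on "import input_param".
    print("ttt =", parse_ttt())
```

`parse_ttt(["--ttt", "99"])` returns `99`, while `parse_ttt([])` makes argparse raise `SystemExit` with code 2. Under this assumption, `example.py` would call `parse_ttt()` once in the driver and hand the value to the trainable (for example through `tune.with_parameters`) instead of reading `input_param.ttt` inside it.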
The tuning script is named `example.py`:
```python
import os
from ray import tune, air
from hyperopt import hp
from ray.tune.search.hyperopt import HyperOptSearch
import input_param


def trainable(config):
    # Uncommenting this line (reading input_param.ttt inside the trainable) triggers the crash:
    # print('!! ttt = ', input_param.ttt)
    score = config["a"] ** 2 + config["b"]
    tune.report(SCORE=score)


search_space = {
    "a": hp.uniform("a", 0, 1),
    "b": hp.uniform("b", 0, 1),
}
raw_log_dir = "./ray_log"
raw_log_name = "example"
algorithm = HyperOptSearch(search_space, metric="SCORE", mode="max", n_initial_points=1)
tuner = tune.Tuner(
    trainable,
    tune_config=tune.TuneConfig(
        num_samples=10,
        search_alg=algorithm,
    ),
    param_space=search_space,
    run_config=air.RunConfig(local_dir=raw_log_dir, name=raw_log_name),
)
# Reading input_param.ttt here, in the driver, works:
print('!! ttt = ', input_param.ttt)
results = tuner.fit()
print(results.get_best_result(metric="SCORE", mode="max").config)
```
I run the task with the following command:
```
py example.py --ttt 99
```
Here is the relevant part of the error output:
```
(pid=19560) default_worker.py: error: the following arguments are required: --ttt
(pid=19560) 2022-11-23 20:45:01,769 ERROR worker.py:763 -- Worker exits with an exit code 2.
```
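The program name in that error, `default_worker.py`, suggests that `parse_args()` ran inside a Ray worker process rather than in the driver: argparse reads the `sys.argv` of whichever process imports `input_param`, and the worker's command line has no `--ttt`. Exit code 2 is also argparse's exit status for usage errors. The stdlib-only sketch below reproduces that presumed mechanism without Ray. It writes a copy of `input_param.py` to a temporary directory and imports it twice, once with a driver-like `sys.argv` and once with a worker-like one. `import_with_argv` is a hypothetical helper for this demonstration only.

```python
import importlib
import sys
import tempfile
import textwrap
from pathlib import Path

# Same top-level code as input_param.py: it parses sys.argv as soon as it is imported.
MODULE_SOURCE = textwrap.dedent("""\
    import argparse
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--ttt', type=int, required=True, help='anything > 1')
    args = parser.parse_args()
    ttt = args.ttt
""")


def import_with_argv(argv):
    """Import a fresh copy of input_param.py while sys.argv is set to argv.

    Returns ("ttt", value) on success or ("exit", code) if argparse calls sys.exit.
    """
    saved_argv, saved_path = sys.argv[:], sys.path[:]
    with tempfile.TemporaryDirectory() as tmp:
        Path(tmp, "input_param.py").write_text(MODULE_SOURCE)
        sys.path.insert(0, tmp)
        sys.argv = list(argv)
        sys.modules.pop("input_param", None)  # force the top-level code to run again
        importlib.invalidate_caches()
        try:
            return ("ttt", importlib.import_module("input_param").ttt)
        except SystemExit as exc:
            return ("exit", exc.code)
        finally:
            # Restore the interpreter state before the temp directory is deleted.
            sys.argv = saved_argv
            sys.path[:] = saved_path
            sys.modules.pop("input_param", None)


if __name__ == "__main__":
    # Driver-like command line: --ttt is present, so the import succeeds.
    print(import_with_argv(["example.py", "--ttt", "99"]))
    # Worker-like command line without --ttt: argparse exits with status 2.
    print(import_with_argv(["default_worker.py"]))
```

The first call returns `("ttt", 99)` and the second returns `("exit", 2)`, with argparse printing a "the following arguments are required: --ttt" usage error to stderr, which mirrors the worker log above.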