Argparse, import and trainable

I have a tuning task that uses a parameter imported via argparse inside the trainable function. The task crashes, complaining that the argument is not provided. It works fine if I use the parameter outside the trainable function. Any help is appreciated. The script being imported is called “input_param.py”:

import sys, argparse

parser = argparse.ArgumentParser(description='')
parser.add_argument('--ttt', type=int, required=True, help='anything > 1')
args = parser.parse_args()

ttt = args.ttt

The tuning task code is named ‘example.py’:

import os
from ray import tune, air
from hyperopt import hp
from ray.tune.search.hyperopt import HyperOptSearch
import input_param as input_param

def trainable(config):
    #print('!! ttt = ', input_param.ttt)
    score = config["a"] ** 2 + config["b"]
    tune.report(SCORE=score)


search_space = {
    "a": hp.uniform("a", 0, 1),
    "b": hp.uniform("b", 0, 1)
    }

raw_log_dir = "./ray_log"
raw_log_name = "example"

algorithm = HyperOptSearch(search_space, metric="SCORE", mode="max", n_initial_points=1)


tuner = tune.Tuner(trainable,
        tune_config = tune.TuneConfig(
            num_samples = 10,
            search_alg=algorithm,
            ),
        param_space=search_space,
        run_config = air.RunConfig(local_dir = raw_log_dir, name = raw_log_name) #
        )

print('!! ttt = ', input_param.ttt)
results = tuner.fit()
print(results.get_best_result(metric="SCORE", mode="max").config)

I run the task via the following command:

 py example.py --ttt 99

The following is part of the error:

(pid=19560) default_worker.py: error: the following arguments are required: --ttt
(pid=19560) 2022-11-23 20:45:01,769     ERROR worker.py:763 -- Worker exits with an exit code 2.

Seems more like a tune problem than RLlib :smiley:

cc @kai

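@wxie2013
I would recommend keeping the parser only in the driver process and referring to the variable ttt inside the trainable definition. The error is raised by default_worker.py, which suggests that the worker process imports input_param and re-runs parse_args() against its own command line, where --ttt is absent. So it would be something like: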

ttt = args.ttt  # parsed once, in the driver script
def trainable(config):
    print(f"ttt = {ttt}")
    ...
tuner = tune.Tuner(trainable, ...)
tuner.fit()
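
A minimal sketch of that layout, written so it runs without Ray. The helper names parse_driver_args and make_trainable are hypothetical (they are not part of Ray or of the original scripts), and the trainable returns a dict instead of calling tune.report only to keep the sketch self-contained. The point is that parsing happens inside a function, so importing the module has no side effects, and the parsed value reaches the trainable as ordinary data:

```python
import argparse


def parse_driver_args(argv=None):
    # Hypothetical helper: building the parser inside a function means that
    # merely importing this module never reads sys.argv.
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("--ttt", type=int, required=True, help="anything > 1")
    return parser.parse_args(argv)


def make_trainable(ttt):
    # Hypothetical helper: bind the already-parsed value in a closure so the
    # trainable carries it along and nothing is parsed in the worker.
    def trainable(config):
        score = config["a"] ** 2 + config["b"]
        # Returned as a dict here so the sketch runs without Ray installed.
        return {"SCORE": score, "ttt": ttt}

    return trainable


# In the real driver script this would be parse_driver_args() with no
# argument, which reads the actual command line (e.g. --ttt 99).
args = parse_driver_args(["--ttt", "99"])
trainable = make_trainable(args.ttt)
print(trainable({"a": 0.5, "b": 0.25}))
```

In example.py the resulting trainable would be handed to tune.Tuner as before. As far as I know, tune.with_parameters(trainable, ttt=args.ttt) is another way to pass such a value, with the trainable declared as def trainable(config, ttt).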