Hi @matthewdeng, thanks for the reply. Could you explain your answer in more detail? (Sorry for my limited English.)
For reference, here is my code:
- Main code
def run(self, num_samples=10, max_num_epochs=10, gpus_per_trial=2):
    """Run the hyperparameter search, then clean up.

    The defaults reproduce the previously hard-coded settings, so ``run()``
    with no arguments behaves exactly as before.

    Args:
        num_samples: Forwarded to ``self._gridsearch`` (number of trials).
        max_num_epochs: Forwarded to ``self._gridsearch`` (per-trial
            iteration limit).
        gpus_per_trial: Forwarded to ``self._gridsearch`` (GPUs reserved
            for each trial).
    """
    self._gridsearch(
        num_samples=num_samples,
        max_num_epochs=max_num_epochs,
        gpus_per_trial=gpus_per_trial,
    )
    # Clean up after the search has finished.
    self._clean_up()
def _gridsearch(self, num_samples, max_num_epochs, gpus_per_trial, cpus_per_trial=2):
    """Build data, models and learner, then launch a Ray Tune search.

    Args:
        num_samples: Passed to ``tune.run`` as ``num_samples``.
        max_num_epochs: Used as the ASHA scheduler's ``max_t``.
        gpus_per_trial: GPUs reserved for each trial.
        cpus_per_trial: CPUs reserved for each trial (default 2, the value
            that used to be hard-coded).

    Returns:
        Whatever ``tune.run`` returns, or ``None`` when the method exits
        early (checkpoint load failure or test mode).
    """
    # Load the dataset; the reader provides the model's input dimension.
    self._load_dataset()
    feature_dim = self.stock_data_reader.get_input_shape()

    # Model
    self.model = model_REGISTRY[self.config.model](
        feature_dim, self.config, self.tune_config
    ).to(self.config.device)

    # Denoising
    self.denoising_model = denoising_REGISTRY[self.config.denoising](
        self.config
    ).to(self.config.device)

    # A checkpoint path was given but loading it failed: stop here.
    if self.config.checkpoint_path != "" and self._load_model() is False:
        return None

    # Test mode: evaluate only, no search.
    if not self.config.training_mode:
        self.test()
        return None

    # Learner
    self.learner = learner_REGISTRY[self.config.learner](
        self.denoising_model, self.model, self.logger, self.config, self.tune_config
    )
    if self.config.use_cuda:
        self.learner.cuda()

    scheduler = ASHAScheduler(
        metric="loss",
        mode="min",
        max_t=max_num_epochs,
        grace_period=1,
        reduction_factor=2,
    )
    reporter = CLIReporter(metric_columns=["loss", "training_iteration"])

    # tune.run accepts ONE trainable as its first positional argument.
    # Extra objects the trainable needs are attached with a single
    # tune.with_parameters(trainable, **kwargs) call. Passing several
    # with_parameters(...) results positionally binds them to tune.run's
    # other positional parameters instead.
    # NOTE(review): assumes self.train accepts
    # (config, learner=..., model=..., stock_data_reader=...) - TODO confirm.
    # NOTE(review): self.tune_config is also the search space given to
    # tune.run below. If it contains sampling objects (tune.choice, ...),
    # the model/learner/reader built above saw those objects, not per-trial
    # samples. Consider building them inside self.train from its `config`.
    trainable = tune.with_parameters(
        self.train,
        learner=self.learner,
        model=self.model,
        stock_data_reader=self.stock_data_reader,
    )
    result = tune.run(
        trainable,
        resources_per_trial={"cpu": cpus_per_trial, "gpu": gpus_per_trial},
        config=self.tune_config,
        num_samples=num_samples,
        scheduler=scheduler,
        progress_reporter=reporter,
    )
    return result
~~~
~~~
- The code where the error occurs (`stock_data_reader.py`)
class PriceDataReader:
    """Reader for price data, configured from ``config`` with optional tune overrides.

    ``batch_size`` and ``sequence_length`` are first taken from ``config``.
    When ``config.tune`` is truthy, both are replaced with ``int(...)`` of
    the matching entries in ``tune_config``.

    NOTE(review): ``tune_config`` must hold concrete values (e.g. the sampled
    per-trial ``config`` that Ray Tune passes to a trainable). If it still
    holds search-space objects such as ``tune.choice(...)``, the ``int()``
    conversion will presumably fail. Confirm where this reader is constructed.
    """

    def __init__(self, logger, config, tune_config):
        self.logger = logger
        self.config = config
        # pct_change lag
        self.lag = config.lag

        # Config-file values are the baseline...
        self.batch_size, self.sequence_length = (
            config.batch_size,
            config.sequence_length,
        )
        # ...and the tuning config wins only during a tune run.
        if not config.tune:
            return
        self.batch_size = int(tune_config["batch_size"])
        self.sequence_length = int(tune_config["sequence_length"])
~~~
~~~
Currently, the default parameters are set from config files.
However, when running with Ray Tune, I want to override them with the values in `tune_config`, as in the code above.
Please let me know if you need more information about my code.
Thanks again for replying and for reading my comment.