Hello, I have a PyTorch Lightning model whose hyperparameters are handled by a Hydra config.
These configs are organised in different folders, as Hydra makes them easy to manage. This is the template for my main config:
defaults:
  - _self_
  - trainer: default_trainer
  - training: default_training
  - model: default_model
  - data: default_data
  - augmentation: default_augmentation
  - transformation: default_transformation

general:
  log_dir: outputs
  inference_checkpoint: path/to/checkpoint
For each entry in the defaults list there is a corresponding YAML file containing the various hyperparameters. For example, the training one lives in conf/training/default_training.yaml and reads:
a_w: 1.0
b_w: 1.0
c_w: 1.0
d_w: 1.0
# etc...
My (ultra-simplified) training script looks like:
import hydra
from omegaconf import DictConfig
from pytorch_lightning import Trainer


def train_model(cfg):
    model = MyLightningModel(cfg)
    trainer = Trainer(**cfg.trainer)
    trainer.fit(model)


@hydra.main(config_path='conf', config_name='config')
def main(cfg: DictConfig):
    train_model(cfg)


if __name__ == '__main__':
    main()
Now, what I'd love to do is to use Ray Tune to find optimal values for a_w, b_w, c_w, and d_w. By reading this guide, I managed to change the training script to:
import hydra
from omegaconf import DictConfig
from pytorch_lightning import Trainer
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler
from ray.tune.integration.pytorch_lightning import TuneReportCallback


def train_model(cfg, num_epochs=10, num_gpus=0):
    model = MyLightningModel(cfg)
    tune_callback = TuneReportCallback({"loss": "val/avg_loss"}, on="validation_end")
    trainer = Trainer(**cfg.trainer, callbacks=[tune_callback])
    trainer.fit(model)


@hydra.main(config_path='conf', config_name='config')
def main(cfg: DictConfig):
    scheduler = ASHAScheduler(max_t=10, grace_period=1, reduction_factor=2)
    reporter = CLIReporter(metric_columns=["loss", "training_iteration"],
                           parameter_columns=["a_w", "b_w", "c_w", "d_w"])
    cfg.training.a_w = tune.uniform(0.0, 1.0)
    cfg.training.b_w = tune.uniform(0.0, 1.0)
    cfg.training.c_w = tune.uniform(0.0, 1.0)
    cfg.training.d_w = tune.uniform(0.0, 1.0)
    analysis = tune.run(tune.with_parameters(train_model,
                                             num_epochs=10,
                                             num_gpus=0),
                        resources_per_trial={"cpu": 1, "gpu": 0},
                        metric="loss",
                        mode="min",
                        config=cfg,
                        num_samples=10,
                        scheduler=scheduler,
                        progress_reporter=reporter,
                        name="tune_mymodel")
    print("Best hyperparameters found were: ", analysis.best_config)


if __name__ == '__main__':
    main()
This, however, raises an exception:

Traceback (most recent call last):
    cfg.training.a_w = tune.uniform(0.0, 1.0)
omegaconf.errors.UnsupportedValueType: Value 'Float' is not a supported primitive type
    full_key: training.a_w
    object_type=dict

This is not super clear to me: is it telling me that I cannot assign a Float value to a dict?
What would be the proper way of using Ray Tune with Hydra? Changing the way configs are handled is not really an option, as we heavily rely on Hydra at this point (I might otherwise need to rewrite most of the Lightning module).
Thanks a lot!