Population Based Training does not train

I am performing a hyperparameter search on roberta-large using the transformers trainer.hyperparameter_search API.

Here is my code snippet:

# imports assumed by the snippet (model/tokenizer/dataset setup omitted)
import math

import torch
from torch.nn.functional import cross_entropy
from transformers import Trainer, TrainingArguments
from ray import tune
from ray.tune.schedulers import PopulationBasedTraining

num_cpus = 1
num_gpus = 1
keep_checkpoints_num = 1
perturbation_interval = 10
n_trials = 16

def compute_perplexity(pred):
    logits = torch.from_numpy(pred.predictions)
    labels = torch.from_numpy(pred.label_ids)
    loss = cross_entropy(logits.view(-1, tokenizer.vocab_size), labels.view(-1))
    try:
        perplexity = math.exp(loss)
    except OverflowError:
        perplexity = float("inf")
    # .item() so the logged metric is a plain float, not a tensor
    return {'perplexity': perplexity, 'calculated_loss': loss.item()}

training_args = TrainingArguments(
    output_dir="/home/ubuntu/train/results",
    overwrite_output_dir=True,
    learning_rate=1e-5,                             # config
    do_train=True,
    do_eval=True,
    evaluation_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model='eval_perplexity',
    greater_is_better=False,
    num_train_epochs=2,                             # config
    max_steps=-1,                                   # config
    per_device_train_batch_size=32,                 # config
    per_device_eval_batch_size=32,                  # config
    warmup_steps=0,                                 # config
    weight_decay=0.1,                               # config
    logging_dir="/home/ubuntu/train/logs",
    logging_steps=10,
    skip_memory_metrics=True,
    report_to="none",
    fp16=True,
    seed=12,                                        # config
)
trainer = Trainer(
    model_init=get_model,
    args=training_args,
    train_dataset=lm_datasets['train'].shard(index=1, num_shards=10),
    eval_dataset=lm_datasets['valid'].shard(index=1, num_shards=10),
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_perplexity
)
tune_config = {
    "max_steps": -1,
    "per_device_eval_batch_size": 32,
}
scheduler = PopulationBasedTraining(
    time_attr="training_iteration",
    metric="eval_perplexity",
    mode="min",
    perturbation_interval=perturbation_interval,
    require_attrs=True,
    hyperparam_mutations={
        "learning_rate": tune.loguniform(1e-6, 1e-1),
        "num_train_epochs": tune.randint(1, 11),
        "per_device_train_batch_size": tune.randint(1, 33),
        "warmup_steps": tune.randint(1, 10001),
        "weight_decay": tune.uniform(1e-1, 0.6),
        "seed": tune.randint(1, 1000),
    })
def objective_perplexity(metrics):
    return metrics['eval_perplexity']
best_run = trainer.hyperparameter_search(
    hp_space=lambda _: tune_config,
    compute_objective=objective_perplexity,
    direction='minimize',
    backend="ray",
    n_trials=n_trials,
    resources_per_trial={
        "cpu": num_cpus,
        "gpu": num_gpus
    },
    scheduler=scheduler,
    keep_checkpoints_num=keep_checkpoints_num,
    checkpoint_score_attr="training_iteration",
    stop=None,
    local_dir="/home/ubuntu/train/ray_results",
    name="tune-clrp-mlm",
    log_to_file=True
)
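
Once the search finishes, the returned best_run can be inspected. A minimal sketch below; the run_id / objective / hyperparameters fields are what transformers' BestRun exposes, and the retraining follow-up is my own assumption, not part of the original post.

print(best_run.run_id)           # Ray trial id of the winning run
print(best_run.objective)        # best eval_perplexity it reached
print(best_run.hyperparameters)  # dict of the winning hyperparameter values

# Hypothetical follow-up: fold the winning values back into the
# TrainingArguments and retrain on the full (unsharded) dataset.
for name, value in best_run.hyperparameters.items():
    setattr(trainer.args, name, value)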

Output of the tuning run:

+------------------------+------------+----------------------+-----------------+--------------------+-------------------------------+--------+----------------+----------------+-------------+
| Trial name             | status     | loc                  |   learning_rate |   num_train_epochs |   per_device_train_batch_size |   seed |   warmup_steps |   weight_decay |   objective |
|------------------------+------------+----------------------+-----------------+--------------------+-------------------------------+--------+----------------+----------------+-------------|
| _objective_ea634_00000 | TERMINATED | 104.171.200.69:56320 |     0.00687667  |                  1 |                            11 |    964 |           3271 |       0.300316 |     8.50084 |
| _objective_ea634_00001 | TERMINATED |                      |     1.03111e-05 |                 10 |                            23 |    481 |           9442 |       0.126476 |     8.92374 |
| _objective_ea634_00002 | TERMINATED |                      |     0.0430234   |                  5 |                             4 |    533 |           4588 |       0.568076 |     8.00639 |
| _objective_ea634_00003 | TERMINATED |                      |     1.49914e-06 |                  6 |                            20 |    787 |           5932 |       0.4181   |     8.77923 |
| _objective_ea634_00004 | TERMINATED |                      |     1.18166e-06 |                  7 |                            21 |     92 |            504 |       0.12751  |     8.40113 |
| _objective_ea634_00005 | TERMINATED |                      |     0.0105669   |                  7 |                             5 |    113 |           2870 |       0.386202 |     7.49462 |
| _objective_ea634_00006 | TERMINATED |                      |     2.09139e-05 |                  1 |                             4 |    951 |           1006 |       0.386697 |     6.14532 |
| _objective_ea634_00007 | TERMINATED |                      |     0.00034102  |                  2 |                             5 |    838 |           3738 |       0.402889 |     8.15622 |
| _objective_ea634_00008 | TERMINATED |                      |     1.46586e-05 |                  3 |                             5 |    949 |           2481 |       0.524008 |     8.56621 |
| _objective_ea634_00009 | TERMINATED |                      |     5.89949e-06 |                  7 |                            30 |    964 |           7410 |       0.143217 |     5.45729 |
| _objective_ea634_00010 | TERMINATED |                      |     1.1827e-06  |                  7 |                            18 |    460 |           3726 |       0.141262 | 1953.19     |
| _objective_ea634_00011 | TERMINATED |                      |     4.85341e-06 |                  1 |                            22 |     83 |           1129 |       0.591824 |     8.9625  |
| _objective_ea634_00012 | TERMINATED |                      |     0.0183518   |                  5 |                             4 |    448 |           7586 |       0.376019 |  1954.01    |
| _objective_ea634_00013 | TERMINATED |                      |     0.00264013  |                  7 |                            19 |    636 |            209 |       0.48228  |  1871.01    |
| _objective_ea634_00014 | TERMINATED |                      |     0.0101467   |                  1 |                             1 |    986 |           7279 |       0.435726 |  2231.21    |
| _objective_ea634_00015 | TERMINATED |                      |     4.34358e-05 |                  6 |                            12 |    694 |           7414 |       0.593281 |     8.38733 |
+------------------------+------------+----------------------+-----------------+--------------------+-------------------------------+--------+----------------+----------------+-------------+

All 16 trials ran only once and then got terminated. Is the tuning even happening? It looks like it fired off 16 random searches and picked the best out of those 16.

=========================
Discussion on the Slack channel

@rliaw
What is your perturbation_interval set to? It seems like you're tuning num_train_epochs, and the training is probably finishing before the interval is hit.

@rliaw
BTW, I would recommend posting to discuss.ray.io

@vinay_ethiraj
@rliaw Thank you for the reply; I will move the message to discuss.ray.io

=========================
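
To spell out @rliaw's point with a toy calculation (my own sketch, not from the thread): with evaluation_strategy="epoch" the Trainer reports one result to Ray per epoch, each report increments training_iteration, and PBT only perturbs a trial once training_iteration reaches perturbation_interval.

perturbation_interval = 10  # the value used in the run above

for num_train_epochs in (1, 2, 7, 10):
    reports = num_train_epochs  # one training_iteration per epoch
    perturbations = reports // perturbation_interval
    print(f"{num_train_epochs} epoch(s): {reports} report(s), "
          f"{perturbations} PBT perturbation(s)")

Every trial with fewer than 10 epochs finishes with zero perturbations, so PBT never gets to exploit/explore, and the search degenerates into 16 independent random samples, which is exactly what the table above shows.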

Should I decrease the interval value and try tuning again?

@vinay_ethiraj Yeah, try decreasing the interval to 1? And also, I wouldn't tune the num_train_epochs variable.

I will update my findings here and move num_train_epochs out of the search space. Thank you for the reply :slight_smile:
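
For reference, a minimal sketch of the updated scheduler with those two changes applied (everything else from the snippet above is assumed unchanged):

scheduler = PopulationBasedTraining(
    time_attr="training_iteration",
    metric="eval_perplexity",
    mode="min",
    perturbation_interval=1,  # perturb after every reported evaluation
    require_attrs=True,
    hyperparam_mutations={
        # num_train_epochs removed from the mutation space
        "learning_rate": tune.loguniform(1e-6, 1e-1),
        "per_device_train_batch_size": tune.randint(1, 33),
        "warmup_steps": tune.randint(1, 10001),
        "weight_decay": tune.uniform(1e-1, 0.6),
        "seed": tune.randint(1, 1000),
    })

With the shorter interval, the trials now pause and resume as PBT exploits the better performers: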

Number of trials: 16/16 (14 PAUSED, 1 PENDING, 1 RUNNING)
+------------------------+----------+-------+-----------------+-------------------------------+--------+----------------+----------------+-------------+
| Trial name             | status   | loc   |   learning_rate |   per_device_train_batch_size |   seed |   warmup_steps |   weight_decay |   objective |
|------------------------+----------+-------+-----------------+-------------------------------+--------+----------------+----------------+-------------|
| _objective_39f1e_00000 | RUNNING  |       |     5.89949e-06 |                             7 |    260 |           3326 |       0.231658 |     9.5388  |
| _objective_39f1e_00002 | PAUSED   |       |     1.46929e-06 |                             3 |    385 |           8652 |       0.168605 |     9.37472 |
| _objective_39f1e_00003 | PAUSED   |       |     0.00239877  |                            13 |    959 |           4179 |       0.526368 |     8.79337 |
| _objective_39f1e_00004 | PAUSED   |       |     8.89983e-06 |                             2 |    353 |           5791 |       0.553634 |     7.36863 |
| _objective_39f1e_00005 | PAUSED   |       |     1.06798e-05 |                             1 |    423 |           1137 |       0.442907 |  1996.67    |
| _objective_39f1e_00006 | PAUSED   |       |     0.0101467   |                            14 |    111 |           4511 |       0.21168  |    13.1061  |
| _objective_39f1e_00007 | PAUSED   |       |     0.00227648  |                            26 |     26 |           4796 |       0.508084 |     8.17222 |
| _objective_39f1e_00008 | PAUSED   |       |     5.26616e-06 |                            25 |    578 |           3845 |       0.182184 |     8.51736 |
| _objective_39f1e_00009 | PAUSED   |       |     4.71235e-05 |                             6 |    837 |           7372 |       0.412291 |     7.48557 |
| _objective_39f1e_00010 | PAUSED   |       |     7.31494e-06 |                            22 |    481 |           3692 |       0.372732 |     8.94498 |
| _objective_39f1e_00011 | PAUSED   |       |     0.000180527 |                             7 |    157 |           6766 |       0.597569 |     6.96075 |
| _objective_39f1e_00012 | PAUSED   |       |     2.20231e-05 |                            28 |    533 |           4588 |       0.568076 |     8.52336 |
| _objective_39f1e_00013 | PAUSED   |       |     1.49914e-06 |                             6 |    161 |           8372 |       0.46626  |     8.61623 |
| _objective_39f1e_00014 | PAUSED   |       |     0.004259    |                            26 |    583 |           1783 |       0.210426 |     8.97854 |
| _objective_39f1e_00015 | PAUSED   |       |     0.000318775 |                            20 |    571 |            487 |       0.124109 |     7.70769 |
| _objective_39f1e_00001 | PENDING  |       |     0.00414963  |                            17 |    902 |            279 |       0.116982 |  1901.44    |
+------------------------+----------+-------+-----------------+-------------------------------+--------+----------------+----------------+-------------+

Now it trains beautifully! @rliaw, thank you for the valuable inputs :slight_smile:
