Concurrency using ray.tune, Slurm and BOHB

Hello, I’m using ray.tune on a server to optimize the performance of a PyTorch model, and I have a few questions. My search algorithm is BOHB.
I’m using 8 nodes with 4 GPUs each, which I managed to set up as described in Deploying on Slurm — Ray 1.11.0.
I have 576 CPUs and 32 GPUs in total, and I’d like to run 8 concurrent trials.
My first question is: does Ray automatically use all the GPUs within a single node for a trial, rather than GPUs split across different nodes?

My second question is: when I check the log, one trial is marked as RUNNING, another is marked as PENDING, and all the other trials are marked as PAUSED/TERMINATED. Does this mean that I’m running only one concurrent trial? How can I change that so that all 8 nodes are used at the same time?

Best,
Pedro

EDIT: I checked the log, and it initially used all 8 nodes and then started using only one.

Hi,
I am not super familiar with Ray on Slurm, but could you run ray.cluster_resources()? Just a sanity check on which resources are visible to Ray in this setup.
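For example, something like this run on the cluster should print them (a sketch; it assumes your Slurm script has already started the Ray head and worker nodes):

import ray

# Attach to the cluster started by the Slurm job.
ray.init(address="auto")

# With 8 nodes of 4 GPUs each, this should report roughly 576 CPUs and 32 GPUs.
print(ray.cluster_resources())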
As for your first question, maybe take a look at the PlacementGroup concept in Ray. It dictates how tasks/actors are assigned resources and whether those resources should preferably be spread across multiple nodes or packed onto the same node when possible.
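For example, a sketch along these lines (Ray 1.x; the CPU/GPU numbers and your_trainable are placeholders for whatever one trial needs):

from ray import tune

# A single bundle always lands on one node, so a trial's 4 GPUs cannot be
# split across nodes. strategy="PACK" (the default) also keeps multiple
# bundles together on one node when possible.
resources_per_trial = tune.PlacementGroupFactory(
    [{"CPU": 72, "GPU": 4}],
    strategy="PACK",
)

# Then: tune.run(your_trainable, resources_per_trial=resources_per_trial, ...)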

As for the 2nd question, I suspect the algorithm pauses all the trials in the same rung while waiting for straggler trials to catch up before making the successive-halving decision for that rung. This could explain why only one trial is running while the rest are either TERMINATED or PAUSED (waiting for the stragglers to catch up).
Maybe you could try ASHA scheduling, which is supposed to handle the straggler issue better.

cc @Yard1


@xwjiang2010 is right, this is how BOHB works. It will pause trials in a rung from time to time. Using ASHA with Optuna should give you similar performance to BOHB but without pausing.
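A minimal sketch of that combination (assuming Ray 1.x import paths, Optuna installed, and a metric named "score"; adjust the names to your setup):

from ray.tune.schedulers import ASHAScheduler
from ray.tune.suggest.optuna import OptunaSearch

# ASHA promotes/stops trials asynchronously instead of pausing a whole rung.
scheduler = ASHAScheduler(metric="score", mode="max")  # grace_period / max_t can also be set
algo = OptunaSearch(metric="score", mode="max")

# Then: tune.run(your_trainable, scheduler=scheduler, search_alg=algo, ...)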

Hello. Thanks a lot for your answers. You are right: I changed to ASHA and now it makes use of all the nodes. But now I have two new problems. A small one is that some trials (around 2%) fail with an NCCL error, probably because the trial was using hyperparameters that resulted in a model that was too big, I assume.
And, more importantly:
All trials are stopped after just one iteration. I assume this is not the expected behavior.

This is the Ray part of my code:

from ray import tune
from ray.tune.schedulers import AsyncHyperBandScheduler
from ray.tune.suggest.skopt import SkOptSearch

config = {
    "lr": tune.loguniform(1e-7, 1e-1),
    "p_dropout": tune.uniform(0, 0.45),
    "weight_decay": tune.loguniform(1e-9, 1e-1),
    "n_attn_blocks": tune.choice([4, 8, 16, 32]),
    "hidden_dim": tune.choice([1024, 2048, 4096]),
    "n_res": tune.choice([2, 3, 4, 5]),
    "batchn_1": tune.choice([True, False]),
    "batchn_2": tune.choice([True, False]),
    "batch_acc": tune.choice([2, 4, 8, 16]),
    "res_factor": tune.choice([2, 4, 8, 16]),
    "test_data_1": test_id_1,
    "test_data_2": test_id_2,
    "test_data_3": test_id_3,
    "train_data": train_id,
    "merged_data": merged_id,
    "uniprot": uniprot_id,
    "mets": mets_id,
    "trans_id": trans_id,
}

scheduler = AsyncHyperBandScheduler(metric="score", mode="max", grace_period=3, max_t=512)
algo = SkOptSearch(metric="score", mode="max")
result = tune.run(
    train_model,
    resources_per_trial={"cpu": 40, "gpu": 4},
    config=config,
    num_samples=64,
    scheduler=scheduler,
    search_alg=algo,
)

My training loop consists of just one iteration; at the end I checkpoint and report. Should it instead consist of several iterations, with a report and a checkpoint for each one?

Best,
Pedro

EDIT: I’m going to change this in my code

Hello,
Yes, you would need to structure the train_model function in a way that covers all the iterations.
Sorry, this may be a little confusing. We have two ways of supplying training logic:

  1. through a function trainable (what you are using)
  2. through a class trainable (user-supplied per-step behavior)

In other words, you need to write something like:

for i in range(num_epoch):
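    # Rough sketch of the loop body; train_one_epoch, evaluate, model, optimizer
    # and the data loaders are placeholders for your own training/validation code,
    # and this assumes `import os`, `import torch`, and `from ray import tune`.
    train_one_epoch(model, optimizer, train_loader)
    score = evaluate(model, val_loader)

    # Checkpoint and report once per epoch so the scheduler sees intermediate
    # results and can stop or pause underperforming trials.
    with tune.checkpoint_dir(step=i) as checkpoint_dir:
        path = os.path.join(checkpoint_dir, "checkpoint")
        torch.save((model.state_dict(), optimizer.state_dict()), path)
    tune.report(score=score)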