Need help running a tuning job on a SLURM cluster with PyTorch Lightning

Hi, I have a bit of experience running simple SLURM jobs on my school’s HPCC. I’m starting to use Ray Tune with my PyTorch Lightning code, and even though I’ve been reading the documentation I’m still having a lot of trouble wrapping my head around things. I know the essence of Ray is that, given n nodes, you designate a single “head” node and n-1 “worker” nodes, and then supposedly Ray takes care of the rest. However, I’m not sure how the parallelization in Ray is supposed to interact with the parallelization in PyTorch and Lightning.

Here is my SLURM script:

#!/bin/bash
#SBATCH -N 2 -n 8
#SBATCH --mem=32gb
#SBATCH --time=2-00:00:00
#SBATCH -p class -C gpu2080
#SBATCH --gres=gpu:2
#SBATCH --mail-user=email@email
#SBATCH --mail-type=ALL
#SBATCH --job-name="myjob"
#SBATCH --output=myjob.out

echo "Loading modules..."

module swap intel gcc
module load cuda/10.1

source ~/miniconda3/etc/profile.d/conda.sh
conda activate my_conda_env

cd ../../scripts
echo "Running python script..."
python my_python_script.py --epochs 200 --num_workers 4

and then the (simplified) python script:

import ...

# Keep Tune from attaching its default logger callbacks, and spoof the job name
# so PyTorch Lightning's SLURM auto-detection doesn't take over cluster handling
os.environ['TUNE_DISABLE_AUTO_CALLBACK_LOGGERS'] = '1'
os.environ["SLURM_JOB_NAME"] = "bash"

parser = argparse.ArgumentParser()
parser.add_argument('--epochs', nargs='?', default=200, type=int)
parser.add_argument('--num_workers', nargs='?', default=0, type=int)
args = parser.parse_args()

epochs = args.epochs
num_workers = args.num_workers

data_dir = os.path.abspath('../data')
log_dir = os.path.abspath('../logs/k-slope-betas-mnist')
mnist_class_dir = os.path.abspath('../model_caches/classifiers/mnist_val_net.ckpt')

config = {
    'k': tune.choice([0.2, 0.5, 1., 2., 5.]),
    'relu_slope': tune.choice([0., 1e-2, 0.1, 0.2]),
    'beta_1': tune.uniform(0.4, 0.99),
    'beta_2': tune.uniform(0.99, 0.9999)
}


def train_fn(config):
    # setup code, callbacks, logging, etc.

    trainer_args = {
        'max_epochs': epochs,
        'callbacks': [...],
        'progress_bar_refresh_rate': 0,
        'default_root_dir': log_dir,
        'automatic_optimization': False,
        'logger': logger
    }

    if cuda.device_count():
        trainer_args['gpus'] = 1

    t = Trainer(**trainer_args)

    t.fit(model)


gpus = 0
if cuda.device_count():
    gpus = 1

tune.run(
    tune.with_parameters(
        train_fn
    ),
    name='my_job',
    metric='primary_metric',
    config=config,
    local_dir=log_dir,
    resources_per_trial={'gpu': gpus, 'cpu': num_workers if num_workers > 0 else 1},
    num_samples=32
)

Any help would be greatly appreciated, I’m just really inexperienced with distributed computing. Thanks!

Hmm, I think you should probably disable PyTorch Lightning’s built-in SLURM handling / auto-parallelism and use the Ray Lightning accelerators (ray-project/ray_lightning_accelerators on GitHub) instead.

This should make it so that Ray handles the parallelism within the SLURM allocation. There is a tuning example in that repo too!
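Roughly, the training function and the tune.run call end up looking something like this. This is just a sketch, not tested against your setup: the plugin class has been renamed across ray_lightning versions (RayAccelerator vs. RayPlugin), and build_model, the worker counts, and the metric mapping are placeholders you’d swap for your own code:

import pytorch_lightning as pl
from ray import tune
from ray_lightning import RayPlugin
from ray_lightning.tune import TuneReportCallback, get_tune_resources

num_ray_workers = 4  # Ray workers per trial, spread across the SLURM allocation

def train_fn(config):
    model = build_model(config)  # hypothetical helper that builds your LightningModule
    trainer = pl.Trainer(
        max_epochs=200,
        progress_bar_refresh_rate=0,
        # report the logged metric back to Tune after each validation epoch
        callbacks=[TuneReportCallback({'primary_metric': 'primary_metric'}, on='validation_end')],
        # Ray (not Lightning's own SLURM/DDP logic) handles the distribution
        plugins=[RayPlugin(num_workers=num_ray_workers, use_gpu=True)],
    )
    trainer.fit(model)

tune.run(
    train_fn,
    config=config,
    metric='primary_metric',
    mode='max',  # or 'min', whichever your metric needs
    num_samples=32,
    # reserve enough resources for each trial's Ray workers
    resources_per_trial=get_tune_resources(num_workers=num_ray_workers, use_gpu=True),
)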

cc @amogkam


Thanks, that’s helpful! If I use a SLURM script like the one here, will that accomplish what I want?

Yeah! That should work fine.
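For reference, the launch part of that example looks roughly like this. It’s a sketch based on the slurm-basic.sh pattern in the Ray docs, so the #SBATCH resource lines, ports, and sleep times will need adjusting for your cluster:

#!/bin/bash
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:2

# load modules / activate the environment once here; the srun steps inherit it
module load cuda/10.1
source ~/miniconda3/etc/profile.d/conda.sh
conda activate my_conda_env

# figure out which node is the head and start Ray there
nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
nodes_array=($nodes)
head_node=${nodes_array[0]}
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)

port=6379
srun --nodes=1 --ntasks=1 -w "$head_node" \
    ray start --head --node-ip-address="$head_node_ip" --port=$port --block &
sleep 10

# start a Ray worker on every remaining node in the allocation
worker_num=$((SLURM_JOB_NUM_NODES - 1))
for ((i = 1; i <= worker_num; i++)); do
    srun --nodes=1 --ntasks=1 -w "${nodes_array[$i]}" \
        ray start --address="$head_node_ip:$port" --block &
    sleep 5
done

# the driver script then connects to the running cluster, e.g. ray.init(address="auto")
python -u simple-trainer.py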

So would I do something like this, and then at the bottom, where it says python -u simple-trainer.py, I would run my own Python script instead?

Also, if I’m loading specific modules and activating my conda environment, do I have to do that on every worker node?

I apologize for being pedantic, I just don’t want to mess something up… I’m just a humble graduate student :sweat_smile:

It looks like it doesn’t work with the current version of Lightning. I made an issue: Does not appear to be compatible with the current version of Lightning · Issue #8 · ray-project/ray_lightning_accelerators · GitHub


So would I do something like this, and then at the bottom, where it says python -u simple-trainer.py, I would run my own Python script instead?

Yeah, I think that should be fine.

Hmm, I think you only need to do that once, at the top of the batch script before the srun / ray start commands (as demonstrated in the documentation). The worker steps inherit that environment. Though please let me know if it doesn’t work.

Sorry, I must have missed that part of the documentation. Everything appears to be running fine now, though unfortunately PyTorch Lightning’s SLURM auto-detection prevents me from using more than one node at a time :frowning: I opened an issue, we’ll see if it gets fixed soon.
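(For anyone who lands here later: the workaround usually suggested for that detection is the job-name spoof that was already near the top of my script, set before the Trainer is created. Whether it is enough seems to depend on the Lightning version and on which process actually constructs the Trainer.)

import os

# Make Lightning think it is not running under a SLURM-managed job, so its own
# cluster environment / multi-node handling does not kick in
os.environ["SLURM_JOB_NAME"] = "bash"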