Dear @kai,
unfortunately, I haven’t had the time to reproduce the errors yet. However, I want to give you a better picture of my current approach to parallelizing with Ray and Tune. On my local machine I run something similar to the following with PyTorch:
import numpy as np
import torch
import ray
from ray import tune
from ray.tune.schedulers import ASHAScheduler
from functools import partial
from itertools import product


def fit_hyper_parameter(model, optimizer, objective, device, idx):
    """Function containing the actual training loop"""
    # Get specific training and test set based on some idx I specified in advance
    train_loader, test_loader = get_inner_data_loaders(outer_inner_idx=idx)
    for itr in range(1, (batches_per_epoch * 100) + 1):
        # Perform the training step on one GPU (actual training code omitted here)
        with torch.no_grad():
            if itr % batches_per_epoch == 0:
                # Compute some accuracy metric on the test set and report it to Tune
                tune.report(val_loss=accuracy_metric)


def init_model_selection(config, idx=None):
    """Function instantiating model, objective, optimizer"""
    device = torch.device('cuda')
    model = SomeModel(
        hyper_param_01=config['hyper_param_01'],
        hyper_param_02=config['hyper_param_02']
    ).to(device=device)  # Send model to GPU for training
    objective = objective_to_minimize.to(device)
    optimizer = some_optimization_method()
    # Call the training loop
    fit_hyper_parameter(
        model=model,
        optimizer=optimizer,
        objective=objective,
        device=device,
        idx=idx
    )


@ray.remote
def run_tune(num_trials, idx):
    """Run a single Tune experiment"""
    scheduler = ASHAScheduler(
        metric='val_loss',
        mode='min',
        max_t=5,
        grace_period=3,
        reduction_factor=2
    )
    rng = np.random.default_rng(seed=12345)
    config = {
        'hyper_param_01': tune.sample_from(lambda _: rng.integers(a, b)),
        'hyper_param_02': tune.sample_from(lambda _: rng.integers(a, b))
    }
    analysis = tune.run(
        partial(init_model_selection, idx=idx),
        scheduler=scheduler,
        resources_per_trial={'cpu': 1, 'gpu': 1},  # One GPU per trial
        num_samples=num_trials,
        config=config,
    )
    # Return the idx for each specific data set in order to differentiate
    # between them later on when evaluating the results
    return [analysis.results_df, idx]


def main():
    outer_num_folds = 10
    inner_num_folds = 10
    data, target = generate_data()  # Full data set
    split_serialize(
        data=data,
        target=target,
        outer_num_folds=outer_num_folds,
        inner_num_folds=inner_num_folds
    )
    idx = [tuple(i) for i in product(range(outer_num_folds), range(inner_num_folds))]
    num_trials = 25  # Number of hyper-parameter configurations
    select_ids = []
    eval_ids = []
    # A validation metric that each Tune experiment over a validation set returns
    val_metric = np.zeros((num_trials, outer_num_folds))
    # Start several Tune experiments in parallel
    for i in range(len(idx)):
        future_id = run_tune.remote(num_trials, idx[i])
        select_ids.append(future_id)
    # Wait until an experiment is finished and process the results
    while select_ids:
        done_ids, select_ids = ray.wait(select_ids)
        val_metric = process_results(ray.get(done_ids[0]), val_metric)
This runs as expected on my local machine (32-core CPU, 3 GPUs). I get a lot of warnings, but in the end I get the results I would expect.
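For reference, a minimal entry point to launch this locally would just be a plain ray.init() followed by main() — a sketch only; I am not passing any explicit resource arguments here, since ray.init() without arguments auto-detects the available CPUs and GPUs:

if __name__ == '__main__':
    # Start Ray on the local machine; with no arguments it auto-detects
    # the available resources (32 CPUs / 3 GPUs in my case)
    ray.init()
    main()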