How can I use Ray Tune + DDP PyTorch (not PyTorch Lightning)? I don’t see an example. Please review my code if you need context; everything works, but I’m unable to verify whether it’s actually doing DDP training. I test it in Kaggle/Colab; I don’t have GPUs.
import os
import socket
import subprocess
import sys

# --------------------------------------------------------------
# Environment setup: install the notebook-only dependencies when
# running on a remote host (Kaggle/Colab), skip on the local machine.
# Rewritten from `!pip` notebook magics to subprocess calls so this
# file is valid Python both as a script and inside a notebook.
# --------------------------------------------------------------
hostname = socket.gethostname()
print(f"Current hostname: {hostname}")
if hostname != "macbook-air-m2":
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "-q", "ray[tune, train]"],
        check=True,
    )
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "-q", "optuna", "torchmetrics"],
        check=True,
    )
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "-qU", "tensorboardx"],
        check=True,
    )
else:
    print("Running on local machine, skipping package installation.")
# Standard library
import json
import os
from itertools import count
from time import time

# Third-party
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd
import ray
import ray.tune
import tabulate  # module import kept for compatibility; shadowed by the function import below
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from filelock import FileLock
from ray import tune
from ray.air.result import Result as AirResult
from ray.tune.schedulers import ASHAScheduler
from ray.tune.search.optuna import OptunaSearch
from tabulate import tabulate
from torch.utils.data import DataLoader, random_split
from torchmetrics import MetricCollection
from torchmetrics.classification import (AUROC, EER, ROC, Accuracy,
                                         AveragePrecision, CalibrationError,
                                         CohenKappa, ConfusionMatrix,
                                         ExactMatch, F1Score, MatthewsCorrCoef,
                                         NegativePredictiveValue, Precision,
                                         PrecisionRecallCurve, Recall,
                                         Specificity)
from torchvision import datasets
from torchvision.transforms import v2

# These flags must be set before Ray / PyTorch read them.
os.environ["RAY_AIR_NEW_OUTPUT"] = "0"   # disable the new AIR console output engine
os.environ["RAY_TRAIN_V2_ENABLED"] = "1"  # opt in to the Ray Train v2 API
# Allow unsupported MPS ops to fall back to CPU on Apple Silicon.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# ==============================================================
# Model
# ==============================================================
class Net(nn.Module):
    """Small LeNet-style CNN for CIFAR-10: (N, 3, 32, 32) -> (N, 10) logits.

    Args:
        l1: width of the first fully-connected hidden layer.
        l2: width of the second fully-connected hidden layer.
    """

    def __init__(self, l1, l2):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # After two conv(k=5) + pool stages on a 32x32 input: 16 channels of 5x5.
        self.fc1 = nn.Linear(16 * 5 * 5, l1)
        self.fc2 = nn.Linear(l1, l2)
        self.fc3 = nn.Linear(l2, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


# ==============================================================
# Custom Loss Functions
# ==============================================================
class CustomLoss(nn.Module):
    """Base class for the loss wrappers below; subclasses implement forward."""

    def forward(self, outputs, targets):
        raise NotImplementedError


class NLLLossV2(CustomLoss):
    """NLLLoss applied to log-softmax logits — equivalent to CrossEntropyLoss."""

    def __init__(self):
        super().__init__()
        self.criterion = nn.NLLLoss()
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, outputs, targets):
        return self.criterion(self.log_softmax(outputs), targets)


class KLDivLossV2(CustomLoss):
    """KL divergence between log-softmax logits and (normalised) one-hot targets."""

    def __init__(self):
        super().__init__()
        self.criterion = nn.KLDivLoss(reduction="batchmean")
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, outputs, targets):
        outputs = self.log_softmax(outputs)
        targets = F.one_hot(targets, num_classes=outputs.shape[1]).float()
        # One-hot rows already sum to 1; kept as a safety net for soft targets.
        targets = targets / targets.sum(dim=1, keepdim=True)
        return self.criterion(outputs, targets)


class MSELossV2(CustomLoss):
    """MSE between sigmoid(logits) and one-hot targets (hard-coded to 10 classes)."""

    def __init__(self):
        super().__init__()
        self.criterion = nn.MSELoss()

    def forward(self, outputs, targets):
        outputs = torch.sigmoid(outputs)
        targets = F.one_hot(targets, num_classes=10).float()
        return self.criterion(outputs, targets)


class MultiMarginLossV2(CustomLoss):
    """Thin wrapper around nn.MultiMarginLoss for registry uniformity."""

    def __init__(self):
        super().__init__()
        self.criterion = nn.MultiMarginLoss()

    def forward(self, outputs, targets):
        return self.criterion(outputs, targets)


class FocalLossV2(CustomLoss):
    """Focal loss built on per-sample cross-entropy.

    Args:
        alpha: scalar weighting factor.
        gamma: focusing exponent; larger values down-weight easy examples.
    """

    def __init__(self, alpha=1, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        # Per-sample cross-entropy (not BCE, despite the original name).
        ce_loss = F.cross_entropy(inputs, targets, reduction="none")
        pt = torch.exp(-ce_loss)  # model probability of the true class
        focal = self.alpha * (1 - pt) ** self.gamma * ce_loss
        return focal.mean()


# ==============================================================
# Loaders
# ==============================================================
def load_device() -> torch.device:
    """Return the best available torch device: MPS > CUDA > CPU."""
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def load_model() -> dict:
    """Registry of model constructors selectable by name from the Tune config."""
    return {
        "Net": Net,
    }


def load_loss() -> dict:
    """Registry of loss instances selectable by name from the Tune config."""
    return {
        "CrossEntropyLoss": nn.CrossEntropyLoss(),
        "NLLLoss": NLLLossV2(),
        "KLDivLoss": KLDivLossV2(),
        "MSELoss": MSELossV2(),
        "MultiMarginLoss": MultiMarginLossV2(),
        "FocalLoss": FocalLossV2(),
    }


def load_optimizer() -> dict:
    """Registry of optimizer classes selectable by name from the Tune config."""
    return {
        "Adam": optim.Adam,
        "AdamW": optim.AdamW,
        "SGD": optim.SGD,
        "Adadelta": optim.Adadelta,
        "Adagrad": optim.Adagrad,
        "Adamax": optim.Adamax,
        "RMSprop": optim.RMSprop,
        "Rprop": optim.Rprop,
    }


def load_scheduler() -> dict:
    """Registry of LR scheduler classes (not referenced by TrainCnn in this file)."""
    return {
        "StepLR": optim.lr_scheduler.StepLR,
        "MultiStepLR": optim.lr_scheduler.MultiStepLR,
        "ExponentialLR": optim.lr_scheduler.ExponentialLR,
        "CosineAnnealingLR": optim.lr_scheduler.CosineAnnealingLR,
        "ReduceLROnPlateau": optim.lr_scheduler.ReduceLROnPlateau,
    }


def metric_collection(task: str, num_classes: int) -> "MetricCollection":
    """Build the torchmetrics collection used for final test evaluation.

    Args:
        task: torchmetrics task string ("multiclass" or "binary").
        num_classes: number of target classes.
    """
    return MetricCollection(
        {
            "Accuracy": Accuracy(task=task, num_classes=num_classes),
            "AUROC": AUROC(task=task, num_classes=num_classes),
            "AveragePrecision": AveragePrecision(task=task, num_classes=num_classes),
            "CalibrationError": CalibrationError(task=task, num_classes=num_classes),
            "CohenKappa": CohenKappa(task=task, num_classes=num_classes),
            "ConfusionMatrix": ConfusionMatrix(task=task, num_classes=num_classes),
            "EER": EER(task=task, num_classes=num_classes),
            "ExactMatch": ExactMatch(task=task, num_classes=num_classes),
            "Precision_macro": Precision(
                average="macro", task=task, num_classes=num_classes
            ),
            "Recall_macro": Recall(average="macro", task=task, num_classes=num_classes),
            "F1Score_macro": F1Score(
                average="macro", task=task, num_classes=num_classes
            ),
            "NegativePredictiveValue": NegativePredictiveValue(
                task=task, num_classes=num_classes
            ),
            "MatthewsCorrCoef": MatthewsCorrCoef(task=task, num_classes=num_classes),
            "ROC": ROC(task=task, num_classes=num_classes),
            "Specificity": Specificity(task=task, num_classes=num_classes),
            "PrecisionRecallCurve": PrecisionRecallCurve(
                task=task, num_classes=num_classes
            ),
        }
    )


# ==============================================================
# Data Loaders
# ==============================================================
def get_data_loaders(config):
    """Build CIFAR-10 train/val/test DataLoaders from the Tune config.

    The configured batch size is divided by the visible "world size" so the
    global batch matches config["batch_size"].
    NOTE(review): without Ray Train / DDP there is only one process, so this
    split merely shrinks the per-step batch. Real DDP would also require a
    DistributedSampler per worker — confirm intent.
    """
    data_dir = config.get("data_dir")
    batch_size = config.get("batch_size")
    world_size = (
        torch.cuda.device_count()
        if torch.cuda.is_available()
        else (os.cpu_count() or 1)
    )
    worker_batch_size = max(1, batch_size // world_size)

    os.makedirs(data_dir, exist_ok=True)
    transform = v2.Compose(
        [
            v2.ToImage(),
            v2.ToDtype(dtype=torch.float32, scale=True),
            v2.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
    # FileLock prevents concurrent trials from racing on the dataset download.
    with FileLock(os.path.join(data_dir, ".data.lock")):
        train_dataset = datasets.CIFAR10(
            root=data_dir, train=True, download=True, transform=transform
        )
        test_val_dataset = datasets.CIFAR10(
            root=data_dir, train=False, download=True, transform=transform
        )
    # Split the official test set 50/50 into validation and test subsets.
    val_dataset, test_dataset = random_split(test_val_dataset, [0.5, 0.5])
    train_loader = DataLoader(train_dataset, batch_size=worker_batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=worker_batch_size)
    test_loader = DataLoader(test_dataset, batch_size=worker_batch_size)
    return (train_loader, val_loader, test_loader)


# ==============================================================
# Training / Validation / Testing functions
# ==============================================================
def train_epoch(model, criterion, optimizer, train_loader, device, use_amp):
batch_train_loss, batch_train_accuracy = 0.0, 0
if use_amp:
scaler = torch.amp.GradScaler(device=device.type)
model.train()
for data, target in train_loader:
data, target = data.to(device, non_blocking=True), target.to(
device, non_blocking=True
)
optimizer.zero_grad()
if use_amp:
with torch.amp.autocast(device_type=device.type):
output = model(data)
loss = criterion(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
else:
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()batch_train_loss += loss.item() pred = output.argmax(dim=1) batch_train_accuracy += pred.eq(target).sum().item() batch_train_loss /= len(train_loader) batch_train_accuracy *= 100 / len(train_loader.dataset) train_metrics = { "train_loss": batch_train_loss, "train_accuracy": batch_train_accuracy, } return train_metrics@torch.inference_mode()
def val_epoch(model, criterion, val_loader, device, use_amp):
batch_val_loss, batch_val_accuracy = 0.0, 0
model.eval()
for data, target in val_loader:
data, target = data.to(device, non_blocking=True), target.to(
device, non_blocking=True
)
if use_amp:
with torch.amp.autocast(device_type=device.type):
output = model(data)
loss = criterion(output, target)
else:
output = model(data)
loss = criterion(output, target)batch_val_loss += loss.item() pred = output.argmax(dim=1) batch_val_accuracy += pred.eq(target).sum().item() batch_val_loss /= len(val_loader) batch_val_accuracy *= 100 / len(val_loader.dataset) val_metrics = { "val_loss": batch_val_loss, "val_accuracy": batch_val_accuracy, } return val_metrics@torch.inference_mode()
def test_model(best_result: AirResult, plot_dir: str):
device = load_device()
models = load_model()
losses = load_loss()
best_config = best_result.config
checkpoint_path = os.path.join(best_result.checkpoint.path, “checkpoint.pt”)
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)model = models[best_config["model"]]( best_config.get("l1"), best_config.get("l2"), ).to(device) model_state = checkpoint["model"] model.load_state_dict(model_state) print(f"Model device: {next(model.parameters()).device}") _, _, test_loader = get_data_loaders(config=best_result.config) labels = test_loader.dataset.dataset.classes num_classes = len(labels) task = "multiclass" if num_classes > 2 else "binary" metrics = metric_collection(task=task, num_classes=num_classes).to(device) criterion = losses[best_config["loss"]] test_loss, test_accuracy = 0.0, 0 use_amp = best_config.get("use_amp", False) metrics.reset() model.eval() start_time = time() for data, target in test_loader: data, target = data.to(device, non_blocking=True), target.to( device, non_blocking=True ) if use_amp: with torch.amp.autocast(device_type=device.type): output = model(data) loss = criterion(output, target) else: output = model(data) loss = criterion(output, target) test_loss += loss.item() pred = output.argmax(dim=1) test_accuracy += pred.eq(target).sum().item() metrics.update(output, target) end_time = time() - start_time test_loss /= len(test_loader) test_accuracy *= 100 / len(test_loader.dataset) keys_plot = ["ConfusionMatrix", "ROC", "PrecisionRecallCurve"] test_metrics = { "Loss": round(test_loss, 6), "Accuracy (%)": round(test_accuracy, 2), "Time (s)": round(end_time, 2), } scores = metrics.compute() eer = scores.get("EER").detach().cpu().numpy().tolist() scores["EER"] = [round(v, 6) for v in eer] filtered_results = { k: ( v if k == "EER" else round(v.item(), 6) ) for k, v in scores.items() if k not in keys_plot } test_metrics.update(filtered_results) metric_plot = MetricCollection({k: metrics[k] for k in keys_plot}).to(device) for k, v in metric_plot.items(): fig, ax = plt.subplots(figsize=(8, 8)) if k == "ROC": v.plot(score=True, labels=labels, ax=ax) elif k == "PrecisionRecallCurve": v.plot(score=True, 
ax=ax) else: v.plot(labels=labels, ax=ax) ax.set_aspect("equal", adjustable="box") ax.set_title(k, fontsize=10) handles, labels_ = ax.get_legend_handles_labels() if handles: ax.legend(loc="center left", bbox_to_anchor=(1.05, 0.5), fontsize=8) fig.subplots_adjust(right=0.75) plt.savefig(os.path.join(plot_dir, f"{k}.svg"), bbox_inches="tight") plt.show() return test_metrics=============================================================
Utils
=============================================================
def format_train_metrics(metrics: dict, path: str) -> str:
    """Format training metrics as a table; also dump them to *path* as JSON.

    The four core train/val metrics come first (renamed for display), then a
    derived average-time-per-epoch, then every remaining reported field.

    Returns:
        A fancy-grid table string.
    """
    rename_map = {
        "train_loss": "Train Loss",
        "train_accuracy": "Train Accuracy (%)",
        "val_loss": "Validation Loss",
        "val_accuracy": "Validation Accuracy (%)",
    }

    formatted = {}
    ordered_keys = []
    for key in ["train_loss", "train_accuracy", "val_loss", "val_accuracy"]:
        if key in metrics:
            name = rename_map[key]
            val = metrics[key]
            if isinstance(val, float):
                # Losses need more precision than accuracies.
                if "loss" in key:
                    val = round(val, 6)
                elif "accuracy" in key:
                    val = round(val, 2)
            formatted[name] = val
            ordered_keys.append(name)
    if "time_total_s" in metrics and "training_iteration" in metrics:
        avg_time = metrics["time_total_s"] / metrics["training_iteration"]
        formatted["Average time per epoch (s)"] = round(avg_time, 2)
        ordered_keys.append("Average time per epoch (s)")
    # Append all remaining Tune-reported fields except bulky/duplicate ones.
    for key, val in metrics.items():
        if key in ["config", "experiment_tag"] or key in rename_map:
            continue
        if isinstance(val, float):
            val = round(val, 2)
        formatted[key] = val
        ordered_keys.append(key)
    with open(path, "w") as f:
        json.dump(formatted, f, indent=4)
    rows = [(k, formatted[k]) for k in ordered_keys]
    return tabulate(rows, headers=["Metric", "Value"], tablefmt="fancy_grid")


def format_test_metrics(metrics: dict, save_path: str) -> str:
    """Format test metrics as a table; also dump them to *save_path* as JSON.

    Returns:
        A fancy-grid table string.
    """
    formatted = {}
    for key, val in metrics.items():
        if key.lower() == "accuracy":
            # Drop the raw torchmetrics "Accuracy" entry; "Accuracy (%)" is kept.
            continue
        elif key == "Accuracy (%)":
            formatted["Accuracy (%)"] = round(val, 2)
        elif isinstance(val, float):
            formatted[key] = round(val, 6)
        else:
            formatted[key] = val
    with open(save_path, "w") as f:
        json.dump(formatted, f, indent=4)
    rows = list(formatted.items())
    return tabulate(rows, headers=["Metric", "Value"], tablefmt="fancy_grid")


# ==============================================================
# Plotting
# ==============================================================
def plot_loss_curves(df: pd.DataFrame, loss_path: str, title: str = "Loss curves"):
    """Plot train/val loss vs training iteration and save the figure to loss_path."""
    figsize = (16, 10)
    col_loss = ["train_loss", "val_loss"]

    fig, ax = plt.subplots(figsize=figsize)
    df.plot(
        x="training_iteration",
        y=col_loss,
        marker="v",
        markersize=6,
        linestyle="-",
        linewidth=1,
        ax=ax,
        title=title,
        ylabel="Loss",
        xlabel="Training Iteration",
    )
    xmin = int(df["training_iteration"].min())
    xmax = int(df["training_iteration"].max())
    ax.set_xlim(left=max(1, xmin), right=xmax)
    locator = ticker.MaxNLocator(nbins=12, integer=True)
    ax.xaxis.set_major_locator(locator)
    # Force the first iteration to appear on the axis.
    xticks = list(ax.get_xticks())
    if 1 not in xticks:
        xticks = [1] + [t for t in xticks if t > 1]
    ax.set_xticks(xticks)
    ymin = 0
    ymax = df[col_loss].max().max()
    ax.set_ylim(bottom=ymin, top=ymax)
    yticks = np.linspace(ymin, ymax, num=10).tolist()
    if ymax not in yticks:
        yticks.append(ymax)
    yticks = sorted(set(yticks))
    ax.set_yticks(yticks)
    ax.yaxis.set_major_formatter(
        ticker.FuncFormatter(lambda x, _: f"{int(x)}" if x.is_integer() else f"{x:.2f}")
    )
    ax.legend(col_loss)
    fig.tight_layout()
    fig.savefig(loss_path)


def plot_accuracy_curves(
    df: pd.DataFrame, acc_path: str, title: str = "Accuracy curves"
):
    """Plot train/val accuracy vs training iteration and save to acc_path."""
    figsize = (16, 10)
    col_acc = ["train_accuracy", "val_accuracy"]

    fig, ax = plt.subplots(figsize=figsize)
    df.plot(
        x="training_iteration",
        y=col_acc,
        marker="^",
        markersize=6,
        linestyle="-",
        linewidth=1,
        ax=ax,
        title=title,
        ylabel="Accuracy (%)",
        xlabel="Training Iteration",
    )
    xmin = int(df["training_iteration"].min())
    xmax = int(df["training_iteration"].max())
    ax.set_xlim(left=max(1, xmin), right=xmax)
    locator = ticker.MaxNLocator(nbins=12, integer=True)
    ax.xaxis.set_major_locator(locator)
    # Force both the first and the last iteration onto the axis.
    xticks = list(ax.get_xticks())
    if 1 not in xticks:
        xticks = [1] + [t for t in xticks if t > 1]
    if xmax not in xticks:
        xticks = [t for t in xticks if t < xmax] + [xmax]
    ax.set_xticks(xticks)
    ymin = df[col_acc].min().min()
    ymax = 100
    ax.set_ylim(ymin, ymax)
    step = 5
    yticks = list(range(int(np.floor(ymin)), ymax + 1, step))
    if ymax not in yticks:
        yticks.append(ymax)
    ax.set_yticks(yticks)
    ax.legend(col_acc)
    fig.tight_layout()
    fig.savefig(acc_path)


def plot_all_trials_accuracy(result_grid: ray.tune.ResultGrid, trials_acc_path: str):
    """Plot val_accuracy vs iteration for every trial.

    The (potentially huge) legend is saved to a separate `*_legend.svg`, and
    each trial's config is written to a companion CSV next to the SVG.
    """
    figsize = (16, 10)
    fig, ax = plt.subplots(figsize=figsize)

    all_vals = (
        np.concatenate(
            [
                r.metrics_dataframe["val_accuracy"].values  # type: ignore
                for r in result_grid
                if r.metrics_dataframe is not None
                and "val_accuracy" in r.metrics_dataframe
            ]
        )
        if result_grid
        else np.array([])
    )
    # Guard against trials that produced no metrics dataframe (errored/early-stopped);
    # the original crashed here on a None dataframe or an empty grid.
    iter_maxes = [
        r.metrics_dataframe["training_iteration"].max()
        for r in result_grid
        if r.metrics_dataframe is not None
    ]
    xmax = max(iter_maxes) if iter_maxes else 1
    xmin = 1
    xlim = (xmin, xmax)
    xticks = np.linspace(xmin, xmax, num=12, dtype=int)
    if 1 not in xticks:
        xticks = np.insert(xticks, 0, 1)
    if xmax not in xticks:
        xticks = np.append(xticks, xmax)
    ymin = int(all_vals.min()) if len(all_vals) > 0 else 0
    ymax = 100
    ylim = (ymin, ymax)
    yticks = list(range(ymin, ymax + 1, 5))
    if ymax not in yticks:
        yticks.append(ymax)

    configs = []
    for idx, result in enumerate(result_grid):
        df = result.metrics_dataframe
        if df is not None and "val_accuracy" in df:
            label = "".join([f"{k}={v}, \n" for k, v in result.config.items()])
            df.plot(
                x="training_iteration",
                y="val_accuracy",
                ax=ax,
                label=label,
                marker="o",
                markersize=3,
                linestyle="-",
                linewidth=1,
            )
            configs.append({"trial": f"Trial {idx+1}", **result.config})
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.set_xticks(xticks)
    ax.set_yticks(yticks)
    ax.set_ylabel("Test Accuracy")
    ax.set_xlabel("Training Iteration")
    ax.set_title("Accuracy for All Trials")

    # Render the legend as its own figure so it doesn't crowd the plot.
    handles, labels = ax.get_legend_handles_labels()
    if handles:
        legend_fig, legend_ax = plt.subplots(figsize=(4, len(labels) * 0.3))
        legend_ax.axis("off")
        legend_ax.legend(
            handles,
            labels,
            loc="center",
            fontsize=7,
            ncol=1,
        )
        legend_fig.savefig(
            trials_acc_path.replace(".svg", "_legend.svg"),
            bbox_inches="tight",
        )
        plt.close(legend_fig)
        leg = ax.get_legend()
        if leg:
            leg.remove()
    fig.savefig(trials_acc_path, bbox_inches="tight")
    pd.DataFrame(configs).to_csv(trials_acc_path.replace(".svg", ".csv"), index=False)


# ==============================================================
# Ray Trainable
# ==============================================================
class TrainCnn(tune.Trainable):
    """Class-based Tune trainable: one call to step() == one training epoch.

    NOTE(review): this answers the question in the header — a tune.Trainable
    runs inside a single Ray actor, and Tune by itself does NOT wrap the model
    in DistributedDataParallel or shard the data. For real DDP, use
    ray.train.torch.TorchTrainer with a train-loop that calls
    ray.train.torch.prepare_model() / prepare_data_loader() (optionally driven
    by Tune). The diagnostics in step() will then actually report DDP usage;
    here they will always print "NOT in DDP".
    """

    def setup(self, config):
        self.device = load_device()
        self.use_amp = config.get("use_amp", False)
        models = load_model()
        loss = load_loss()
        optimizer = load_optimizer()
        self.data_dir = config.get("data_dir")
        self.batch_size = config.get("batch_size")
        self.epoch = config.get("epoch")
        self.model = models[config["model"]](
            config.get("l1"),
            config.get("l2"),
        ).to(self.device)
        self.criterion = loss[config["loss"]]
        self.optimizer = optimizer[config["optimizer"]](
            self.model.parameters(), lr=config.get("lr")
        )
        self.train_loader, self.val_loader, self.test_loader = get_data_loaders(config)

    def step(self):
        # Diagnostics requested by the author: confirm whether DDP is in play.
        if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
            print("Model is wrapped with DistributedDataParallel")
        else:
            print("Model is NOT in DDP")
        train_sampler = getattr(self.train_loader, "sampler", None)
        if isinstance(train_sampler, torch.utils.data.distributed.DistributedSampler):
            print("Train loader uses DistributedSampler")
        else:
            print("Train loader does NOT use DistributedSampler")
        print("Model device:", next(self.model.parameters()).device)

        train_metrics = train_epoch(
            model=self.model,
            criterion=self.criterion,
            optimizer=self.optimizer,
            train_loader=self.train_loader,
            device=self.device,
            use_amp=self.use_amp,
        )
        val_metrics = val_epoch(
            model=self.model,
            criterion=self.criterion,
            val_loader=self.val_loader,
            device=self.device,
            use_amp=self.use_amp,
        )
        return {**train_metrics, **val_metrics}

    def save_checkpoint(self, checkpoint_dir):
        checkpoint_path = os.path.join(checkpoint_dir, "checkpoint.pt")
        torch.save(
            {
                "epoch": self.epoch,
                # Unwrap a DDP-wrapped model so the checkpoint is portable.
                "model": self.model.module.state_dict()
                if hasattr(self.model, "module")
                else self.model.state_dict(),
                # BUG FIX: the original stored criterion.__class__.__name__
                # (e.g. "NLLLossV2"), which is not a key of load_loss()
                # ("NLLLoss", ...) and made load_checkpoint raise KeyError.
                # Store the config's registry key instead.
                "loss": self.config["loss"],
                "optimizer": self.optimizer.state_dict(),
            },
            checkpoint_path,
        )

    def load_checkpoint(self, checkpoint_dir):
        checkpoint_path = os.path.join(checkpoint_dir, "checkpoint.pt")
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.epoch = checkpoint["epoch"]
        self.model.load_state_dict(checkpoint["model"])
        loss = load_loss()
        self.criterion = loss[checkpoint["loss"]]
        self.optimizer.load_state_dict(checkpoint["optimizer"])

    def reset_config(self, new_config):
        # Called when reuse_actors=True recycles this actor for a new trial.
        self.epoch = new_config.get("epoch")
        self.data_dir = new_config.get("data_dir")
        self.batch_size = new_config.get("batch_size")
        # CONSISTENCY FIX: the original never refreshed use_amp here, so a
        # recycled actor kept the previous trial's AMP setting.
        self.use_amp = new_config.get("use_amp", False)
        self.train_loader, self.val_loader, self.test_loader = get_data_loaders(
            new_config
        )
        models = load_model()
        loss = load_loss()
        optimizer = load_optimizer()
        self.model = models[new_config["model"]](
            new_config.get("l1"),
            new_config.get("l2"),
        ).to(self.device)
        self.criterion = loss[new_config["loss"]]
        self.optimizer = optimizer[new_config["optimizer"]](
            self.model.parameters(), lr=new_config.get("lr")
        )
        self.config = new_config
        return True


def custom_trial_name_creator(trial):
    """Short, stable trial display name."""
    return f"trial_{trial.trial_id}"


def custom_trial_dirname_creator(trial):
    """Short trial directory name (keeps experiment paths short)."""
    return f"trial_{trial.trial_id}"


def main():
    """Configure and run the Tune experiment, then evaluate the best trial."""
    experience = "cifar10"
    work_dir = "/kaggle/working/laboai/results/"
    data_dir = os.path.join(work_dir, experience, "data")
    result_dir = os.path.join(work_dir, experience, "results")
    plot_dir = os.path.join(work_dir, experience, "plots")
    os.makedirs(result_dir, exist_ok=True)
    os.makedirs(plot_dir, exist_ok=True)
    result_path = os.path.join(result_dir, "results.csv")
    loss_path = os.path.join(plot_dir, "loss.svg")
    accuracy_path = os.path.join(plot_dir, "accuracy.svg")
    trials_acc_path = os.path.join(plot_dir, "trials_accuracy.svg")

    # Learning-rate grid: 1e0..1e-6 plus 3e-1..3e-6.
    lr = [round(10**-i, 10) for i in range(7)] + [
        round(3 * 10**-i, 10) for i in range(1, 7)
    ]
    epochs = 2
    num_samples = 2
    cpus = os.cpu_count()
    # BUG FIX: the original derived "gpus" from the preferred device
    # (getattr(torch, load_device().type).device_count()), which on a
    # CPU/MPS-only host requested a nonexistent Ray "gpu" resource and left
    # every trial pending forever. Ray's "gpu" resource means CUDA GPUs.
    gpus = torch.cuda.device_count()
    cpu_per_trial = cpus
    gpu_per_trial = gpus
    trainable = tune.with_resources(
        TrainCnn, resources={"cpu": float(cpu_per_trial), "gpu": float(gpu_per_trial)}
    )
    models = load_model()
    loss = load_loss()
    optimizer = load_optimizer()
    param_space = {
        "data_dir": data_dir,
        "use_amp": tune.choice([False]),
        "batch_size": tune.sample_from(lambda _: 2 ** np.random.randint(2, 14)),
        "model": tune.choice(list(models.keys())),
        "lr": tune.choice(lr),
        "loss": tune.choice(list(loss.keys())),
        "optimizer": tune.choice(list(optimizer.keys())),
        "l1": tune.sample_from(lambda _: 2 ** np.random.randint(2, 14)),
        "l2": tune.sample_from(lambda _: 2 ** np.random.randint(2, 14)),
    }
    search_alg = OptunaSearch(
        metric="val_accuracy",
        mode="max",
    )
    scheduler = ASHAScheduler(
        time_attr="training_iteration",
        max_t=epochs,
        grace_period=epochs,
        reduction_factor=3,
        brackets=3,
    )
    max_concurrent = int(max(1, cpus // cpu_per_trial))
    tune_config = tune.TuneConfig(
        mode="max",
        metric="val_accuracy",
        search_alg=search_alg,
        scheduler=scheduler,
        num_samples=num_samples,
        max_concurrent_trials=max_concurrent,
        reuse_actors=True,  # requires TrainCnn.reset_config (implemented above)
        trial_name_creator=custom_trial_name_creator,
        trial_dirname_creator=custom_trial_dirname_creator,
    )
    run_config_checkpoint_config = tune.CheckpointConfig(
        checkpoint_at_end=True,
        checkpoint_frequency=4,
        checkpoint_score_order="max",
        checkpoint_score_attribute="val_accuracy",
        num_to_keep=4,
    )
    run_config_progress_reporter = tune.JupyterNotebookReporter(
        metric_columns=[
            "train_loss",
            "train_accuracy",
            "val_loss",
            "val_accuracy",
        ],
        parameter_columns=list(param_space.keys()),
        print_intermediate_tables=True,
        max_report_frequency=30,
    )
    sync_cfg = tune.SyncConfig(
        sync_period=0,
        sync_timeout=1800,
        sync_artifacts=True,
        sync_artifacts_on_checkpoint=True,
    )
    run_config = tune.RunConfig(
        name=experience,
        storage_path=work_dir,
        checkpoint_config=run_config_checkpoint_config,
        sync_config=sync_cfg,
        progress_reporter=run_config_progress_reporter,
        verbose=2,
    )
    tuner = tune.Tuner(
        trainable=trainable,
        param_space=param_space,
        tune_config=tune_config,
        run_config=run_config,
    )
    result_grid: tune.ResultGrid = tuner.fit()
    result_grid.get_dataframe().to_csv(result_path)
    best_result_acc: AirResult = result_grid.get_best_result(
        "val_accuracy", "max", filter_nan_and_inf=False
    )
    test_metrics = test_model(best_result_acc, plot_dir=plot_dir)
    test_metrics = format_test_metrics(
        test_metrics, os.path.join(result_dir, "test_metrics.json")
    )
    metrics = best_result_acc.metrics or {}
    metrics = format_train_metrics(
        metrics, os.path.join(result_dir, "best_metrics.json")
    )
    config = best_result_acc.config or {}
    print("-" * 50)
    print("\nBest config is:")
    print(
        tabulate(
            config.items(),
            headers=["Parameters", "Value"],
            tablefmt="fancy_grid",
        )
    )
    print("\nTraining (Best) metrics is:")
    print(metrics)
    print("\nTest metrics is:")
    print(test_metrics)
    print("-" * 50)
    df = best_result_acc.metrics_dataframe
    if df is not None:
        plot_loss_curves(df, loss_path)
        plot_accuracy_curves(df, accuracy_path)
    else:
        print("Warning: No metrics dataframe available for plotting; skipping plots.")
    plot_all_trials_accuracy(result_grid, trials_acc_path=trials_acc_path)
    return result_grid


if __name__ == "__main__":
    # BUG FIX: the pasted code read `if name == "main"` — the guard must test
    # the dunder __name__ against "__main__" or main() never runs.
    ray.init(ignore_reinit_error=True, log_to_driver=True, configure_logging=True)
    result_grid = main()
    ray.shutdown()