How to combine Ray Tune with PyTorch DDP training

How can I use Ray Tune with PyTorch DDP (plain PyTorch, not PyTorch Lightning)? I haven't found an example. My code is below for context; everything runs, but I can't verify whether it is actually training with DDP. I test it on Kaggle/Colab, and I don't have GPUs.

import socket
import subprocess
import sys

# Install the extra dependencies unless running on the author's local machine.
hostname = socket.gethostname()
print(f"Current hostname: {hostname}")
if hostname != "macbook-air-m2":
    # `!pip` is IPython-only syntax (a SyntaxError in plain Python). Running pip
    # through the current interpreter works in notebooks (Kaggle/Colab) and in
    # scripts, and installs into the same environment as this kernel.
    for pip_args in (
        ["-q", "ray[tune, train]"],
        ["-q", "optuna", "torchmetrics"],
        ["-qU", "tensorboardx"],
    ):
        subprocess.check_call([sys.executable, "-m", "pip", "install", *pip_args])
else:
    print("Running on local machine, skipping package installation.")

import os
from itertools import count
from time import time

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd
import ray
import json
import ray.tune
import tabulate
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from filelock import FileLock
from ray import tune
from ray.air.result import Result as AirResult
from ray.tune.schedulers import ASHAScheduler
from ray.tune.search.optuna import OptunaSearch
from tabulate import tabulate
from torch.utils.data import DataLoader, random_split
from torchmetrics import MetricCollection
from torchmetrics.classification import (AUROC, EER, ROC, Accuracy,
AveragePrecision, CalibrationError,
CohenKappa, ConfusionMatrix,
ExactMatch, F1Score, MatthewsCorrCoef,
NegativePredictiveValue, Precision,
PrecisionRecallCurve, Recall,
Specificity)
from torchvision import datasets
from torchvision.transforms import v2

# Runtime switches for Ray's console output, Ray Train V2 and the MPS CPU fallback.
# NOTE(review): RAY_TRAIN_V2_ENABLED (and possibly PYTORCH_ENABLE_MPS_FALLBACK)
# appear to be read when ray.train / torch are imported; since those imports run
# above, these values may not take effect in this process. Consider moving these
# lines before `import ray` / `import torch`. Ray worker processes started later
# by ray.init() are expected to inherit them.
os.environ["RAY_AIR_NEW_OUTPUT"] = "0"
os.environ["RAY_TRAIN_V2_ENABLED"] = "1"
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# ==============================================================
# Model
# ==============================================================

class Net(nn.Module):
    """Small LeNet-style CNN producing 10 logits.

    The flattened size 16 * 5 * 5 corresponds to 3x32x32 inputs (e.g. CIFAR-10):
    32 -conv5-> 28 -pool2-> 14 -conv5-> 10 -pool2-> 5.

    Args:
        l1: width of the first fully connected hidden layer.
        l2: width of the second fully connected hidden layer.
    """

    def __init__(self, l1, l2):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, l1)
        self.fc2 = nn.Linear(l1, l2)
        self.fc3 = nn.Linear(l2, 10)

    def forward(self, x):
        """Map an (N, 3, 32, 32) batch to (N, 10) unnormalised logits."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # flatten(x, 1) keeps the batch dimension. view(-1, 400) would silently
        # regroup rows across samples if the input size were wrong.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

# ==============================================================
# Custom Loss Functions
# ==============================================================

class CustomLoss(nn.Module):
    """Base class for losses that take raw logits and integer class targets."""

    def forward(self, outputs, targets):
        raise NotImplementedError


class NLLLossV2(CustomLoss):
    """Negative log-likelihood on log-softmax of the logits (equals cross-entropy)."""

    def __init__(self):
        super().__init__()
        self.criterion = nn.NLLLoss()
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, outputs, targets):
        return self.criterion(self.log_softmax(outputs), targets)


class KLDivLossV2(CustomLoss):
    """KL divergence between one-hot targets and softmax(logits), batch-averaged."""

    def __init__(self):
        super().__init__()
        self.criterion = nn.KLDivLoss(reduction="batchmean")
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, outputs, targets):
        log_probs = self.log_softmax(outputs)
        # One-hot rows already sum to 1, so they are valid target distributions.
        target_dist = F.one_hot(targets, num_classes=log_probs.shape[1]).float()
        return self.criterion(log_probs, target_dist)


class MSELossV2(CustomLoss):
    """Mean squared error between sigmoid(logits) and one-hot targets."""

    def __init__(self):
        super().__init__()
        self.criterion = nn.MSELoss()

    def forward(self, outputs, targets):
        probs = torch.sigmoid(outputs)
        # Derive the class count from the logits (as KLDivLossV2 does) instead
        # of hard-coding 10, so the loss works for any number of classes.
        one_hot = F.one_hot(targets, num_classes=outputs.shape[1]).float()
        return self.criterion(probs, one_hot)


class MultiMarginLossV2(CustomLoss):
    """Multi-class hinge loss applied directly to the logits."""

    def __init__(self):
        super().__init__()
        self.criterion = nn.MultiMarginLoss()

    def forward(self, outputs, targets):
        return self.criterion(outputs, targets)


class FocalLossV2(CustomLoss):
    """Focal loss: alpha * (1 - p_t) ** gamma * CE, averaged over the batch.

    Args:
        alpha: constant weighting factor.
        gamma: focusing parameter; gamma=0 reduces to alpha * cross-entropy.
    """

    def __init__(self, alpha=1, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction="none")
        pt = torch.exp(-ce_loss)  # probability assigned to the true class
        focal = self.alpha * (1 - pt) ** self.gamma * ce_loss
        return focal.mean()

# ==============================================================
# Loaders
# ==============================================================

def load_device() -> torch.device:
    """Return the preferred torch device: Apple MPS first, then CUDA, else CPU."""
    for backend_name, is_available in (
        ("mps", torch.backends.mps.is_available),
        ("cuda", torch.cuda.is_available),
    ):
        if is_available():
            return torch.device(backend_name)
    return torch.device("cpu")

def load_model() -> dict:
    """Return the model-class registry, keyed by the name used in configs."""
    return {"Net": Net}

def load_loss() -> dict:
    """Build fresh loss-module instances keyed by their config name.

    Every call creates new instances. Key order is CrossEntropyLoss, NLLLoss,
    KLDivLoss, MSELoss, MultiMarginLoss, FocalLoss.
    """
    return dict(
        CrossEntropyLoss=nn.CrossEntropyLoss(),
        NLLLoss=NLLLossV2(),
        KLDivLoss=KLDivLossV2(),
        MSELoss=MSELossV2(),
        MultiMarginLoss=MultiMarginLossV2(),
        FocalLoss=FocalLossV2(),
    )

def load_optimizer() -> dict:
    """Map optimizer names to their (uninstantiated) ``torch.optim`` classes."""
    names = ("Adam", "AdamW", "SGD", "Adadelta", "Adagrad", "Adamax", "RMSprop", "Rprop")
    return {name: getattr(optim, name) for name in names}

def load_scheduler() -> dict:
    """Map LR-scheduler names to their ``torch.optim.lr_scheduler`` classes."""
    names = (
        "StepLR",
        "MultiStepLR",
        "ExponentialLR",
        "CosineAnnealingLR",
        "ReduceLROnPlateau",
    )
    return {name: getattr(optim.lr_scheduler, name) for name in names}

def metric_collection(
    task: str, num_classes: int
) -> MetricCollection:
    """Build the torchmetrics collection used to evaluate the test set.

    Args:
        task: torchmetrics task name, "binary" or "multiclass".
        num_classes: number of classes (ignored by torchmetrics for "binary").

    Returns:
        A MetricCollection with scalar metrics plus the curve-style metrics
        (ConfusionMatrix, ROC, PrecisionRecallCurve) that are plotted later.
    """
    kw = {"task": task, "num_classes": num_classes}
    metrics = {
        "Accuracy": Accuracy(**kw),
        "AUROC": AUROC(**kw),
        "AveragePrecision": AveragePrecision(**kw),
        "CalibrationError": CalibrationError(**kw),
        "CohenKappa": CohenKappa(**kw),
        "ConfusionMatrix": ConfusionMatrix(**kw),
        "EER": EER(**kw),
    }
    # ExactMatch only supports "multiclass"/"multilabel"; constructing it with
    # task="binary" raises, which would make the whole collection unusable.
    if task != "binary":
        metrics["ExactMatch"] = ExactMatch(**kw)
    metrics.update(
        {
            "Precision_macro": Precision(average="macro", **kw),
            "Recall_macro": Recall(average="macro", **kw),
            "F1Score_macro": F1Score(average="macro", **kw),
            "NegativePredictiveValue": NegativePredictiveValue(**kw),
            "MatthewsCorrCoef": MatthewsCorrCoef(**kw),
            "ROC": ROC(**kw),
            "Specificity": Specificity(**kw),
            "PrecisionRecallCurve": PrecisionRecallCurve(**kw),
        }
    )
    return MetricCollection(metrics)

# ==============================================================
# Data Loaders
# ==============================================================

def get_data_loaders(config):
    """Build CIFAR-10 train / validation / test DataLoaders.

    The CIFAR-10 test set is split 50/50 into validation and test subsets.

    Args:
        config: dict with the following keys:
            - "data_dir": download/cache directory.
            - "batch_size": global batch size, shared across DDP processes.
            - "split_seed" (optional, default 42): fixes the val/test split.

    Returns:
        (train_loader, val_loader, test_loader).
    """
    data_dir = config.get("data_dir")
    batch_size = config.get("batch_size")
    # Per-process batch = global batch / number of DDP processes. Outside an
    # initialised torch.distributed process group (a plain tune.Trainable, or
    # the driver running test_model) there is exactly one process. CPU count is
    # not the world size.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        world_size = torch.distributed.get_world_size()
    else:
        world_size = 1
    worker_batch_size = max(1, batch_size // world_size)

    os.makedirs(data_dir, exist_ok=True)

    transform = v2.Compose(
        [
            v2.ToImage(),
            v2.ToDtype(dtype=torch.float32, scale=True),
            v2.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    # Serialise downloads when several trials/workers share data_dir.
    with FileLock(os.path.join(data_dir, ".data.lock")):
        train_dataset = datasets.CIFAR10(
            root=data_dir, train=True, download=True, transform=transform
        )
        test_val_dataset = datasets.CIFAR10(
            root=data_dir, train=False, download=True, transform=transform
        )

    # Seeded so every call (each trial and the later test_model call) sees the
    # same partition. An unseeded split lets the "test" subset overlap the
    # samples used for validation / model selection.
    split_generator = torch.Generator().manual_seed(config.get("split_seed", 42))
    val_dataset, test_dataset = random_split(
        test_val_dataset, [0.5, 0.5], generator=split_generator
    )

    # Under Ray Train DDP, wrap these with ray.train.torch.prepare_data_loader
    # to add a DistributedSampler and device placement.
    train_loader = DataLoader(train_dataset, batch_size=worker_batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=worker_batch_size)
    test_loader = DataLoader(test_dataset, batch_size=worker_batch_size)

    return train_loader, val_loader, test_loader

# ==============================================================
# Training / Validation / Testing functions
# ==============================================================

def train_epoch(model, criterion, optimizer, train_loader, device, use_amp):
    """Run one training epoch over ``train_loader``.

    Args:
        model: module to train, already on ``device``.
        criterion: loss taking (logits, integer targets).
        optimizer: optimizer over ``model``'s parameters.
        train_loader: iterable of (data, target) batches.
        device: device the batches are moved to.
        use_amp: run forward/loss under autocast with a GradScaler.

    Returns:
        dict with two entries:
            - "train_loss": mean of the per-batch losses.
            - "train_accuracy": percentage of correctly classified samples
              among those this process actually iterated.
    """
    running_loss, correct, seen, num_batches = 0.0, 0, 0, 0
    scaler = torch.amp.GradScaler(device=device.type) if use_amp else None
    model.train()
    for data, target in train_loader:
        data = data.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        optimizer.zero_grad()
        if use_amp:
            with torch.amp.autocast(device_type=device.type):
                output = model(data)
                loss = criterion(output, target)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

        running_loss += loss.item()
        correct += output.argmax(dim=1).eq(target).sum().item()
        seen += target.size(0)
        num_batches += 1

    # Normalise by what this process saw. With a DistributedSampler each rank
    # only iterates a shard, so dividing by len(loader.dataset) would understate
    # accuracy by roughly the world size. max(..., 1) guards an empty loader.
    return {
        "train_loss": running_loss / max(num_batches, 1),
        "train_accuracy": 100.0 * correct / max(seen, 1),
    }

@torch.inference_mode()
def val_epoch(model, criterion, val_loader, device, use_amp):
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    Args:
        model: module to evaluate, already on ``device``.
        criterion: loss taking (logits, integer targets).
        val_loader: iterable of (data, target) batches.
        device: device the batches are moved to.
        use_amp: run forward/loss under autocast.

    Returns:
        dict with two entries:
            - "val_loss": mean of the per-batch losses.
            - "val_accuracy": percentage of correct predictions among the
              samples this process iterated.
    """
    running_loss, correct, seen, num_batches = 0.0, 0, 0, 0
    model.eval()
    for data, target in val_loader:
        data = data.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        if use_amp:
            with torch.amp.autocast(device_type=device.type):
                output = model(data)
                loss = criterion(output, target)
        else:
            output = model(data)
            loss = criterion(output, target)

        running_loss += loss.item()
        correct += output.argmax(dim=1).eq(target).sum().item()
        seen += target.size(0)
        num_batches += 1

    # Divide by samples actually seen (correct under a DistributedSampler too).
    return {
        "val_loss": running_loss / max(num_batches, 1),
        "val_accuracy": 100.0 * correct / max(seen, 1),
    }

@torch.inference_mode()
def test_model(best_result: AirResult, plot_dir: str):
    """Evaluate the best trial's checkpoint on the held-out test split.

    Steps:
        1. Load ``checkpoint.pt`` from ``best_result.checkpoint``.
        2. Rebuild the model from ``best_result.config``.
        3. Compute loss, accuracy and the torchmetrics collection over the
           test loader.
        4. Save one SVG per curve-style metric (ConfusionMatrix, ROC,
           PrecisionRecallCurve) into ``plot_dir``.

    Returns:
        dict of test metrics: "Loss", "Accuracy (%)", "Time (s)" followed by
        the scalar torchmetrics results (EER as a list).

    Raises:
        ValueError: if ``best_result`` carries no checkpoint.
    """
    if best_result.checkpoint is None:
        raise ValueError(
            "best_result has no checkpoint to evaluate; make sure the trainable "
            "reports one (e.g. ray.train.report(metrics, checkpoint=...))."
        )

    device = load_device()
    best_config = best_result.config

    # as_directory() also works when the checkpoint lives on remote storage.
    with best_result.checkpoint.as_directory() as checkpoint_dir:
        checkpoint = torch.load(
            os.path.join(checkpoint_dir, "checkpoint.pt"),
            map_location=device,
            weights_only=False,
        )

    model = load_model()[best_config["model"]](
        best_config.get("l1"),
        best_config.get("l2"),
    ).to(device)
    model.load_state_dict(checkpoint["model"])
    print(f"Model device: {next(model.parameters()).device}")

    _, _, test_loader = get_data_loaders(config=best_config)

    labels = test_loader.dataset.dataset.classes  # random_split Subset -> CIFAR10
    num_classes = len(labels)
    task = "multiclass" if num_classes > 2 else "binary"
    metrics = metric_collection(task=task, num_classes=num_classes).to(device)
    criterion = load_loss()[best_config["loss"]]
    use_amp = best_config.get("use_amp", False)

    test_loss, correct, seen = 0.0, 0, 0
    metrics.reset()
    model.eval()
    start_time = time()
    for data, target in test_loader:
        data = data.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        if use_amp:
            with torch.amp.autocast(device_type=device.type):
                output = model(data)
                loss = criterion(output, target)
        else:
            output = model(data)
            loss = criterion(output, target)

        test_loss += loss.item()
        correct += output.argmax(dim=1).eq(target).sum().item()
        seen += target.size(0)
        metrics.update(output, target)

    elapsed = time() - start_time
    test_loss /= max(len(test_loader), 1)
    test_accuracy = 100.0 * correct / max(seen, 1)

    keys_plot = ["ConfusionMatrix", "ROC", "PrecisionRecallCurve"]
    test_metrics = {
        "Loss": round(test_loss, 6),
        "Accuracy (%)": round(test_accuracy, 2),
        "Time (s)": round(elapsed, 2),
    }

    for name, value in metrics.compute().items():
        if name in keys_plot:
            continue
        if name == "EER":
            # reshape(-1) so a scalar EER (e.g. binary task) still yields a list.
            eer = value.detach().cpu().reshape(-1).tolist()
            test_metrics[name] = [round(v, 6) for v in eer]
        else:
            test_metrics[name] = round(value.item(), 6)

    for name in keys_plot:
        metric = metrics[name]
        fig, ax = plt.subplots(figsize=(8, 8))
        if name == "ROC":
            metric.plot(score=True, labels=labels, ax=ax)
        elif name == "PrecisionRecallCurve":
            metric.plot(score=True, ax=ax)
        else:
            metric.plot(labels=labels, ax=ax)

        ax.set_aspect("equal", adjustable="box")
        ax.set_title(name, fontsize=10)
        handles, _ = ax.get_legend_handles_labels()
        if handles:
            ax.legend(loc="center left", bbox_to_anchor=(1.05, 0.5), fontsize=8)
        fig.subplots_adjust(right=0.75)
        fig.savefig(os.path.join(plot_dir, f"{name}.svg"), bbox_inches="tight")
        plt.show()
        # Release the figure; otherwise each call leaks three open figures.
        plt.close(fig)

    return test_metrics

# =============================================================
# Utils
# =============================================================

def format_train_metrics(metrics: dict, path: str) -> str:
    """Format a trial's training metrics, dump them to JSON and render a table.

    Output order:
        1. The known train/val keys, under human-readable names. Losses are
           rounded to 6 decimals and accuracies to 2.
        2. "Average time per epoch (s)", when both ``time_total_s`` and
           ``training_iteration`` are present.
        3. Every remaining metric except ``config`` and ``experiment_tag``,
           with floats rounded to 2 decimals.

    Args:
        metrics: flat metrics dict, e.g. ``Result.metrics``.
        path: JSON file the formatted values are written to.

    Returns:
        The metrics rendered as a ``fancy_grid`` table string.
    """
    pretty_names = {
        "train_loss": "Train Loss",
        "train_accuracy": "Train Accuracy (%)",
        "val_loss": "Validation Loss",
        "val_accuracy": "Validation Accuracy (%)",
    }
    skipped = ("config", "experiment_tag")
    formatted = {}
    order = []

    def put(name, value):
        formatted[name] = value
        order.append(name)

    for raw_key, name in pretty_names.items():
        if raw_key not in metrics:
            continue
        value = metrics[raw_key]
        if isinstance(value, float):
            value = round(value, 6 if "loss" in raw_key else 2)
        put(name, value)

    if "time_total_s" in metrics and "training_iteration" in metrics:
        put(
            "Average time per epoch (s)",
            round(metrics["time_total_s"] / metrics["training_iteration"], 2),
        )

    for key, value in metrics.items():
        if key in skipped or key in pretty_names:
            continue
        put(key, round(value, 2) if isinstance(value, float) else value)

    with open(path, "w") as fh:
        json.dump(formatted, fh, indent=4)

    return tabulate(
        [(name, formatted[name]) for name in order],
        headers=["Metric", "Value"],
        tablefmt="fancy_grid",
    )

def format_test_metrics(metrics: dict, save_path: str) -> str:
    """Format test metrics, dump them to JSON and render a table.

    Keys whose lowercase form is "accuracy" (the fractional duplicate of
    "Accuracy (%)") are dropped. "Accuracy (%)" is rounded to 2 decimals and
    other floats to 6.

    Args:
        metrics: dict returned by ``test_model``.
        save_path: JSON file the formatted values are written to.

    Returns:
        The metrics rendered as a ``fancy_grid`` table string.
    """
    def fmt(key, val):
        if key == "Accuracy (%)":
            return round(val, 2)
        return round(val, 6) if isinstance(val, float) else val

    formatted = {
        key: fmt(key, val)
        for key, val in metrics.items()
        if key.lower() != "accuracy"
    }

    with open(save_path, "w") as fh:
        json.dump(formatted, fh, indent=4)

    return tabulate(
        list(formatted.items()),
        headers=["Metric", "Value"],
        tablefmt="fancy_grid",
    )

# ==============================================================
# Plotting
# ==============================================================

def plot_loss_curves(df: pd.DataFrame, loss_path: str, title: str = "Loss curves"):
    """Plot train/validation loss per training iteration and save it to ``loss_path``.

    Args:
        df: per-iteration metrics with "training_iteration", "train_loss" and
            "val_loss" columns (e.g. ``Result.metrics_dataframe``).
        loss_path: output image path.
        title: figure title.

    Returns:
        The matplotlib Figure, so the caller can show or close it.
    """
    col_loss = ["train_loss", "val_loss"]

    fig, ax = plt.subplots(figsize=(16, 10))
    df.plot(
        x="training_iteration",
        y=col_loss,
        marker="v",
        markersize=6,
        linestyle="-",
        linewidth=1,
        ax=ax,
        title=title,
        ylabel="Loss",
        xlabel="Training Iteration",
    )

    xmin = int(df["training_iteration"].min())
    xmax = int(df["training_iteration"].max())
    ax.set_xlim(left=max(1, xmin), right=xmax)

    ax.xaxis.set_major_locator(ticker.MaxNLocator(nbins=12, integer=True))
    # Keep ticks within [1, xmax] and always label the first and last iteration.
    # set_xticks() widens the view to include any out-of-range tick.
    inner = [t for t in ax.get_xticks() if 1 < t < xmax]
    ax.set_xticks(sorted({1, *inner, xmax}))

    # Diverged trials (e.g. lr=1.0) can log NaN/inf losses, and set_ylim raises
    # on non-finite limits, so take the largest finite loss instead.
    values = df[col_loss].to_numpy(dtype=float)
    finite = values[np.isfinite(values)]
    ymax = float(finite.max()) if finite.size and finite.max() > 0 else 1.0
    ymin = 0
    ax.set_ylim(bottom=ymin, top=ymax)
    # linspace includes ymax as its last point.
    ax.set_yticks(np.linspace(ymin, ymax, num=10).tolist())

    ax.yaxis.set_major_formatter(
        ticker.FuncFormatter(lambda x, _: f"{int(x)}" if x.is_integer() else f"{x:.2f}")
    )

    ax.legend(col_loss)

    fig.tight_layout()
    fig.savefig(loss_path)
    return fig

def plot_accuracy_curves(
    df: pd.DataFrame, acc_path: str, title: str = "Accuracy curves"
):
    """Plot train/validation accuracy per training iteration and save it to ``acc_path``.

    The y-axis spans from the lowest observed accuracy to 100 %, with ticks
    every 5 points.

    Args:
        df: per-iteration metrics with "training_iteration", "train_accuracy"
            and "val_accuracy" columns.
        acc_path: output image path.
        title: figure title.

    Returns:
        The matplotlib Figure, so the caller can show or close it.
    """
    col_acc = ["train_accuracy", "val_accuracy"]

    fig, ax = plt.subplots(figsize=(16, 10))
    df.plot(
        x="training_iteration",
        y=col_acc,
        marker="^",
        markersize=6,
        linestyle="-",
        linewidth=1,
        ax=ax,
        title=title,
        ylabel="Accuracy (%)",
        xlabel="Training Iteration",
    )

    xmin = int(df["training_iteration"].min())
    xmax = int(df["training_iteration"].max())
    ax.set_xlim(left=max(1, xmin), right=xmax)

    ax.xaxis.set_major_locator(ticker.MaxNLocator(nbins=12, integer=True))
    # Previously ticks above xmax survived whenever xmax was already a tick,
    # and set_xticks() then widened the axis. Clamp to [1, xmax] instead.
    inner = [t for t in ax.get_xticks() if 1 < t < xmax]
    ax.set_xticks(sorted({1, *inner, xmax}))

    ymin = df[col_acc].min().min()
    ymax = 100
    ax.set_ylim(ymin, ymax)

    yticks = list(range(int(np.floor(ymin)), ymax + 1, 5))
    if ymax not in yticks:
        yticks.append(ymax)
    ax.set_yticks(yticks)

    ax.legend(col_acc)

    fig.tight_layout()
    fig.savefig(acc_path)
    return fig

def plot_all_trials_accuracy(result_grid: ray.tune.ResultGrid, trials_acc_path: str):
    """Overlay every trial's validation-accuracy curve and save the results.

    Writes three files next to ``trials_acc_path``:
        - the plot itself;
        - ``*_legend.svg``, with each trial's config as a legend entry;
        - ``.csv``, one row per plotted trial's config.

    Trials without a metrics dataframe or without a "val_accuracy" column are
    skipped.

    Returns:
        The matplotlib Figure of the main plot.
    """
    # Keep the enumerate index over *all* results so "Trial N" matches the grid.
    plotted = [
        (idx, result, result.metrics_dataframe)
        for idx, result in enumerate(result_grid)
        if result.metrics_dataframe is not None
        and "val_accuracy" in result.metrics_dataframe
    ]

    fig, ax = plt.subplots(figsize=(16, 10))
    configs = []
    for idx, result, df in plotted:
        label = "".join(f"{k}={v}, \n" for k, v in result.config.items())
        df.plot(
            x="training_iteration",
            y="val_accuracy",
            ax=ax,
            label=label,
            marker="o",
            markersize=3,
            linestyle="-",
            linewidth=1,
        )
        configs.append({"trial": f"Trial {idx + 1}", **result.config})

    # x range: 1 .. longest plotted trial. np.unique removes the duplicates
    # that linspace(dtype=int) produces when there are fewer than 12 iterations.
    xmax = int(max((df["training_iteration"].max() for _, _, df in plotted), default=1))
    xticks = np.unique(
        np.concatenate(([1], np.linspace(1, xmax, num=12, dtype=int), [xmax]))
    )

    # y range: lowest finite validation accuracy (0 if none) .. 100, step 5.
    # Previously np.concatenate([]) raised when no trial had val_accuracy.
    all_vals = (
        np.concatenate([df["val_accuracy"].to_numpy(dtype=float) for _, _, df in plotted])
        if plotted
        else np.array([])
    )
    finite_vals = all_vals[np.isfinite(all_vals)]
    ymin = int(finite_vals.min()) if finite_vals.size else 0
    ymax = 100
    yticks = list(range(ymin, ymax + 1, 5))
    if ymax not in yticks:
        yticks.append(ymax)

    ax.set_xlim((1, xmax))
    ax.set_ylim((ymin, ymax))
    ax.set_xticks(xticks)
    ax.set_yticks(yticks)
    # The curves are val_accuracy, not test accuracy.
    ax.set_ylabel("Validation Accuracy (%)")
    ax.set_xlabel("Training Iteration")
    ax.set_title("Accuracy for All Trials")

    # Config labels are long, so the legend goes into its own figure.
    handles, labels = ax.get_legend_handles_labels()
    if handles:
        legend_fig, legend_ax = plt.subplots(figsize=(4, len(labels) * 0.3))
        legend_ax.axis("off")
        legend_ax.legend(handles, labels, loc="center", fontsize=7, ncol=1)
        legend_fig.savefig(
            trials_acc_path.replace(".svg", "_legend.svg"),
            bbox_inches="tight",
        )
        plt.close(legend_fig)

    leg = ax.get_legend()
    if leg:
        leg.remove()
    fig.savefig(trials_acc_path, bbox_inches="tight")
    pd.DataFrame(configs).to_csv(trials_acc_path.replace(".svg", ".csv"), index=False)
    return fig

# ==============================================================
# Ray Trainable
# ==============================================================

class TrainCnn(tune.Trainable):
    """Ray Tune class-API trainable that trains a registry model on CIFAR-10.

    Each ``step()`` call is one epoch (train + validation).

    Why this is never DDP: nothing in this class creates a
    ``torch.distributed`` process group or wraps the model in
    ``DistributedDataParallel``, so the checks in ``step()`` always print
    "NOT". Multi-process DDP needs Ray Train's ``TorchTrainer``, with
    ``ray.train.torch.prepare_model`` / ``prepare_data_loader`` inside the
    per-worker loop.
    """

    def setup(self, config):
        """Select the device and build per-trial state for a new trial."""
        self.device = load_device()
        self._build_from_config(config)

    def _build_from_config(self, config):
        """(Re)create model, loss, optimizer and data loaders from ``config``."""
        self.use_amp = config.get("use_amp", False)
        self.data_dir = config.get("data_dir")
        self.batch_size = config.get("batch_size")
        self.epoch = config.get("epoch")
        # Store the registry key (e.g. "NLLLoss"). The criterion's class name
        # ("NLLLossV2") is not a load_loss() key, so it cannot be restored.
        self.loss_name = config["loss"]

        self.model = load_model()[config["model"]](
            config.get("l1"),
            config.get("l2"),
        ).to(self.device)
        self.criterion = load_loss()[self.loss_name]
        self.optimizer = load_optimizer()[config["optimizer"]](
            self.model.parameters(), lr=config.get("lr")
        )
        self.train_loader, self.val_loader, self.test_loader = get_data_loaders(config)

    def step(self):
        """Train for one epoch, validate, and return the merged metrics dict."""
        # Diagnostics: is the model DDP-wrapped and the loader sharded?
        if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
            print("Model is wrapped with DistributedDataParallel")
        else:
            print("Model is NOT in DDP")

        train_sampler = getattr(self.train_loader, "sampler", None)
        if isinstance(train_sampler, torch.utils.data.distributed.DistributedSampler):
            print("Train loader uses DistributedSampler")
        else:
            print("Train loader does NOT use DistributedSampler")

        print("Model device:", next(self.model.parameters()).device)

        train_metrics = train_epoch(
            model=self.model,
            criterion=self.criterion,
            optimizer=self.optimizer,
            train_loader=self.train_loader,
            device=self.device,
            use_amp=self.use_amp,
        )
        val_metrics = val_epoch(
            model=self.model,
            criterion=self.criterion,
            val_loader=self.val_loader,
            device=self.device,
            use_amp=self.use_amp,
        )
        return {**train_metrics, **val_metrics}

    def save_checkpoint(self, checkpoint_dir):
        """Write model/optimizer state and the loss registry key to checkpoint.pt."""
        # Save the unwrapped module so the state dict has no "module." prefix.
        model = self.model.module if hasattr(self.model, "module") else self.model
        torch.save(
            {
                "epoch": self.epoch,
                "model": model.state_dict(),
                "loss": self.loss_name,
                "optimizer": self.optimizer.state_dict(),
            },
            os.path.join(checkpoint_dir, "checkpoint.pt"),
        )

    def load_checkpoint(self, checkpoint_dir):
        """Restore the state written by ``save_checkpoint``."""
        checkpoint = torch.load(
            os.path.join(checkpoint_dir, "checkpoint.pt"), map_location=self.device
        )
        self.epoch = checkpoint["epoch"]
        model = self.model.module if hasattr(self.model, "module") else self.model
        model.load_state_dict(checkpoint["model"])

        losses = load_loss()
        loss_name = checkpoint["loss"]
        if loss_name not in losses:
            # Older checkpoints stored the class name (e.g. "NLLLossV2");
            # map it back to its registry key.
            loss_name = next(
                (key for key, fn in losses.items() if type(fn).__name__ == loss_name),
                loss_name,
            )
        self.loss_name = loss_name
        self.criterion = losses[loss_name]
        self.optimizer.load_state_dict(checkpoint["optimizer"])

    def reset_config(self, new_config):
        """Reuse this actor for a new trial (reuse_actors=True).

        Rebuilds all per-trial state, including ``use_amp``, which was
        previously left at the old trial's value.
        """
        self._build_from_config(new_config)
        self.config = new_config
        return True

def custom_trial_name_creator(trial):
    """Name a Tune trial after its id, as ``trial_<trial_id>``."""
    return "trial_{}".format(trial.trial_id)

def custom_trial_dirname_creator(trial):
    """Use ``trial_<trial_id>`` as the trial's directory name."""
    return "trial_{tid}".format(tid=trial.trial_id)

def main():
    """Tune the CIFAR-10 CNN with Ray Tune, then evaluate, report and plot the best trial.

    Data, CSV/JSON results and SVG plots are written under
    ``/kaggle/working/laboai/results/cifar10``.

    Returns:
        The ``tune.ResultGrid`` produced by ``Tuner.fit()``.
    """
    experience = "cifar10"
    work_dir = "/kaggle/working/laboai/results/"
    data_dir = os.path.join(work_dir, experience, "data")
    result_dir = os.path.join(work_dir, experience, "results")
    plot_dir = os.path.join(work_dir, experience, "plots")
    os.makedirs(result_dir, exist_ok=True)
    os.makedirs(plot_dir, exist_ok=True)
    result_path = os.path.join(result_dir, "results.csv")
    loss_path = os.path.join(plot_dir, "loss.svg")
    accuracy_path = os.path.join(plot_dir, "accuracy.svg")
    trials_acc_path = os.path.join(plot_dir, "trials_accuracy.svg")

    # 1, 0.1, ..., 1e-6 and 0.3, 0.03, ..., 3e-6
    lr = [round(10**-i, 10) for i in range(7)] + [
        round(3 * 10**-i, 10) for i in range(1, 7)
    ]
    # 4 .. 8192: the same values `2 ** np.random.randint(2, 14)` drew, expressed
    # as a categorical domain. OptunaSearch cannot convert tune.sample_from
    # functions into an Optuna search space.
    powers_of_two = [2**i for i in range(2, 14)]

    epochs = 2
    num_samples = 2

    cpus = os.cpu_count() or 1
    # Only request GPUs when CUDA devices exist. Ray does not expose an Apple
    # MPS device or a CPU as a "GPU" resource, so asking for one there
    # (torch.mps/torch.cpu.device_count() is >= 1) leaves trials unschedulable.
    gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0

    cpu_per_trial = cpus
    gpu_per_trial = gpus

    trainable = tune.with_resources(
        TrainCnn, resources={"cpu": float(cpu_per_trial), "gpu": float(gpu_per_trial)}
    )

    models = load_model()
    loss = load_loss()
    optimizer = load_optimizer()
    param_space = {
        "data_dir": data_dir,
        "use_amp": tune.choice([False]),
        "batch_size": tune.choice(powers_of_two),
        "model": tune.choice(list(models.keys())),
        "lr": tune.choice(lr),
        "loss": tune.choice(list(loss.keys())),
        "optimizer": tune.choice(list(optimizer.keys())),
        "l1": tune.choice(powers_of_two),
        "l2": tune.choice(powers_of_two),
    }

    search_alg = OptunaSearch(
        metric="val_accuracy",
        mode="max",
    )

    scheduler = ASHAScheduler(
        time_attr="training_iteration",
        max_t=epochs,
        grace_period=epochs,
        reduction_factor=3,
        brackets=3,
    )

    max_concurrent = int(max(1, cpus // cpu_per_trial))
    tune_config = tune.TuneConfig(
        mode="max",
        metric="val_accuracy",
        search_alg=search_alg,
        scheduler=scheduler,
        num_samples=num_samples,
        max_concurrent_trials=max_concurrent,
        reuse_actors=True,
        trial_name_creator=custom_trial_name_creator,
        trial_dirname_creator=custom_trial_dirname_creator,
    )

    run_config_checkpoint_config = tune.CheckpointConfig(
        checkpoint_at_end=True,
        checkpoint_frequency=4,
        checkpoint_score_order="max",
        checkpoint_score_attribute="val_accuracy",
        num_to_keep=4,
    )

    run_config_progress_reporter = tune.JupyterNotebookReporter(
        metric_columns=[
            "train_loss",
            "train_accuracy",
            "val_loss",
            "val_accuracy",
        ],
        parameter_columns=list(param_space.keys()),
        print_intermediate_tables=True,
        max_report_frequency=30,
    )

    sync_cfg = tune.SyncConfig(
        sync_period=0,
        sync_timeout=1800,
        sync_artifacts=True,
        sync_artifacts_on_checkpoint=True,
    )

    run_config = tune.RunConfig(
        name=experience,
        storage_path=work_dir,
        checkpoint_config=run_config_checkpoint_config,
        sync_config=sync_cfg,
        progress_reporter=run_config_progress_reporter,
        verbose=2,
    )

    tuner = tune.Tuner(
        trainable=trainable,
        param_space=param_space,
        tune_config=tune_config,
        run_config=run_config,
    )

    result_grid: tune.ResultGrid = tuner.fit()
    result_grid.get_dataframe().to_csv(result_path)
    best_result_acc: AirResult = result_grid.get_best_result(
        "val_accuracy", "max", filter_nan_and_inf=False
    )

    test_metrics = test_model(best_result_acc, plot_dir=plot_dir)
    test_table = format_test_metrics(
        test_metrics, os.path.join(result_dir, "test_metrics.json")
    )

    train_table = format_train_metrics(
        best_result_acc.metrics or {}, os.path.join(result_dir, "best_metrics.json")
    )
    config = best_result_acc.config or {}

    print("-" * 50)
    print("\nBest config is:")
    print(
        tabulate(
            config.items(),
            headers=["Parameters", "Value"],
            tablefmt="fancy_grid",
        )
    )

    print("\nTraining (Best) metrics is:")
    print(train_table)

    print("\nTest metrics is:")
    print(test_table)

    print("-" * 50)

    df = best_result_acc.metrics_dataframe
    if df is not None:
        plot_loss_curves(df, loss_path)
        plot_accuracy_curves(df, accuracy_path)
    else:
        print("Warning: No metrics dataframe available for plotting; skipping plots.")

    plot_all_trials_accuracy(result_grid, trials_acc_path=trials_acc_path)

    return result_grid

# Script entry point: the pasted `name == "main"` never matched, because the
# dunder names were lost. try/finally shuts Ray down even if tuning fails.
if __name__ == "__main__":
    ray.init(ignore_reinit_error=True, log_to_driver=True, configure_logging=True)
    try:
        result_grid = main()
    finally:
        ray.shutdown()

Hello! Have you looked into using Ray Train's TorchTrainer and its distributed utilities, which set up DDP and handle device placement automatically? For DDP with Ray you'll need ray.train.torch.prepare_model() and ray.train.torch.prepare_data_loader().

A guide that might help: the Ray Train "Get Started with PyTorch" guide (ray/doc/source/train/getting-started-pytorch.rst, branch releases/2.47.1 of ray-project/ray on GitHub).

I have already set up DDP, but when I save the model as shown below, I never find my checkpoint.

import tempfile  # not imported by the script above

# Inside train_loop_per_worker, once per epoch: rank 0 writes the checkpoint,
# and every rank calls report().
with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
    checkpoint = None
    if ray.train.get_context().get_world_rank() == 0:
        ckpt_path = os.path.join(temp_checkpoint_dir, "checkpoint.pt")
        # Pass ckpt_path, not the bare name "checkpoint.pt". The bare name writes
        # into the current working directory, leaving temp_checkpoint_dir empty.
        save_checkpoint(model, optimizer, epoch, criterion, ckpt_path)
        checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
    # report() must run inside the `with`: the temporary directory is deleted
    # on exit, and Ray copies the checkpoint to storage during this call.
    ray.train.report(metrics=metrics, checkpoint=checkpoint)

# NOTE(review): when TorchTrainer.fit() is launched from a Tune function
# trainable (a trial named "train_driver_fn_..."), checkpoints reported here
# belong to the inner Train run. Confirm whether they surface on the Tune trial
# Result; otherwise look under the Train run's storage_path, or report
# trainer.fit().checkpoint from the driver function. TODO confirm against the
# Ray Train V2 + Tune docs.

Here is my best_result from result_grid:

Result(metrics={'epoch': 1, 'train_loss': 2.308340251865253, 'train_acc': 4.952, 'val_loss': 2.3058702626805396, 'val_acc': 5.36}, path='/results/cifar10_ddp/train_driver_fn_90f8f685_1_batch_size=32,data_dir=results_cifar10_ddp_data,epochs=1,l1=fn_ph_b4e22907,l2=fn__2025-09-23_15-04-42', filesystem='local', checkpoint=None)