How severely does this issue affect your experience of using Ray?
- High: It blocks me from completing my task.
Preamble
My goal is to implement a class, Trainer, which internally initializes a PyTorch model based on the input I provide through the Ray config file. For example, given the hypothetical config entry {"model": "SingleLinearProbe"}, Trainer should initialize the appropriate model and proceed to train it, fine-tuning the specified parameters.
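For concreteness, the full config I have in mind looks roughly like this (the values are just placeholders, and the nesting of the per-model and per-optimizer parameters under keys named after them is my own convention, not something Ray prescribes):

config = {
    "model": "SingleLinearProbe",
    "SingleLinearProbe": {"dim_input": 768, "dim_output": 1},
    "epochs": 10,
    "batch_size": 32,
    "train_data": "train.csv",
    "language_model": "bert-base-uncased",
    "criterion": "MSELoss",
    "optimizer": "SGD",
    "SGD": {"lr": 0.01, "momentum": 0.9},
}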
Issue
I have a first draft of the class, which inherits from ray.tune.Trainable and overrides setup. The problem is that I don't understand how to override the step method. Specifically, my _train_and_validate method contains two nested loops, iterating over the total number of epochs and over the batches in the data loaders, respectively, and I'm not sure how to adapt this code to the Ray API. Looking at the documentation and at the PyTorch examples I found, the step method doesn't seem to require any looping. If that is the case, how can I re-implement the logic of my original method so that it plays well with step?
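From the docs and the examples, my current understanding is that step is supposed to do a fixed amount of work per call (e.g. one epoch) and return a dict of metrics, with Tune driving the outer loop through a stopping criterion. Something like the sketch below, where _run_one_epoch is a hypothetical helper that would contain the body of my epoch loop:

def step(self):
    # one epoch of training + validation per call; Tune calls step()
    # repeatedly, so the outer epoch loop disappears from my code
    avg_train_loss, avg_val_loss = self._run_one_epoch()  # hypothetical helper
    return {"train_loss": avg_train_loss, "val_loss": avg_val_loss}

Is this the right way to think about it?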
My code
import torch
import torch.nn as nn
from data import TrainDataset
from pathlib import Path
from probes import SingleLinearProbe, DoubleLinearProbe
from ray.tune import Trainable
from typing import Any, Callable, Dict, Iterator, Tuple
from torch.nn import MSELoss
from torch.optim import Optimizer, SGD
from torch.utils.data import DataLoader, random_split
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Trainer(Trainable):
    def setup(self, config: Dict[str, Any]) -> None:
        self._model = self._initialize_model(config).to(DEVICE)
        self._epochs = config["epochs"]
        self._train_loader, self._val_loader = self._initialize_dataloaders(config)
        self._criterion = self._get_criterion(config)
        self._optimizer = self._get_optimizer(config)
        self._batch_size = config["batch_size"]

    def _initialize_model(self, config: Dict[str, Any]) -> nn.Module:
        # config["model"] is a string, so the per-model hyperparameters
        # live under a separate key named after the model
        if config["model"] == "SingleLinearProbe":
            params = config["SingleLinearProbe"]
            model = SingleLinearProbe(params["dim_input"], params["dim_output"])
        elif config["model"] == "DoubleLinearProbe":
            params = config["DoubleLinearProbe"]
            model = DoubleLinearProbe(
                params["dim_input"], params["dim_hidden"], params["dim_output"]
            )
        else:
            raise KeyError(f"Unknown model: {config['model']}")
        return model

    def _initialize_dataloaders(self, config: Dict[str, Any]) -> Tuple[DataLoader, DataLoader]:
        csv_data = Path(__file__).parent / "datasets" / config["train_data"]
        dataset = TrainDataset(csv_data, config["language_model"])
        train_data, val_data = random_split(
            dataset, [0.8, 0.2], generator=torch.Generator().manual_seed(42)
        )
        train_dataloader = DataLoader(
            train_data, batch_size=config["batch_size"], shuffle=True, num_workers=0
        )
        # no need to shuffle the validation set
        val_dataloader = DataLoader(
            val_data, batch_size=config["batch_size"], shuffle=False, num_workers=0
        )
        return train_dataloader, val_dataloader

    def _get_criterion(self, config: Dict[str, Any]) -> Callable:
        if config["criterion"] == "MSELoss":
            criterion = MSELoss()  # instantiate the loss, not the class
        else:
            raise KeyError(f"Unknown criterion: {config['criterion']}")
        return criterion

    def _get_optimizer(self, config: Dict[str, Any]) -> Optimizer:
        # as with the model, the optimizer hyperparameters live under
        # a key named after the optimizer
        if config["optimizer"] == "SGD":
            optimizer = SGD(
                self._model.parameters(),
                lr=config["SGD"]["lr"],
                momentum=config["SGD"]["momentum"],
            )
        else:
            raise KeyError(f"Unknown optimizer: {config['optimizer']}")
        return optimizer

    def step(self):
        ...  # <== how do I adapt _train_and_validate to fit here?

    def _train_and_validate(self) -> Iterator[Tuple[float, float]]:
        for epoch in range(self._epochs):
            self._model.train()
            # sum of all per-batch losses for a given epoch
            train_loss = 0.0
            for data, target in self._train_loader:
                data = data.to(DEVICE)
                target = target.to(DEVICE)
                out = self._model(data)
                batch_loss = self._criterion(out, target)
                train_loss += batch_loss.item() * data.size(0)
                self._optimizer.zero_grad()
                batch_loss.backward()
                self._optimizer.step()
            # average per-sample loss for a given epoch
            # https://discuss.pytorch.org/t/on-running-loss-and-average-loss/107890
            # https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html#training-the-model
            avg_train_loss = train_loss / len(self._train_loader.sampler)
            self._model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for data, target in self._val_loader:
                    data = data.to(DEVICE)
                    target = target.to(DEVICE)
                    out = self._model(data)
                    batch_loss = self._criterion(out, target)
                    val_loss += batch_loss.item() * data.size(0)
            # average per-sample loss for a given epoch
            avg_val_loss = val_loss / len(self._val_loader.sampler)
            yield avg_train_loss, avg_val_loss  # <== ???
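For completeness, this is roughly how I intend to launch the training, assuming one epoch per step call (config is the dict sketched above):

from ray import tune

analysis = tune.run(
    Trainer,
    config=config,
    stop={"training_iteration": config["epochs"]},  # one epoch per step()?
)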