How do I override `step` for a class that inherits from `ray.tune.Trainable`?

How severely does this issue affect your experience of using Ray?

  • High: It blocks me from completing my task.

Preamble
My goal is to implement a class, Trainer, which internally initializes a PyTorch model based on the input I provide through the Ray Tune config. For example, given the hypothetical config entry {"model": "SingleLinearProbe"}, Trainer should initialize the corresponding model and then train it, fine-tuning the specified parameters.

Issue
I have a first draft of the class, which inherits from ray.tune.Trainable and overrides setup; the problem is that I don't understand how to override the step method.

Specifically, my _train_and_validate method contains two loops, iterating over the total number of epochs and over the batches in the data loaders, respectively. I'm not sure how to adapt this code to the Ray Tune API. From the documentation and the PyTorch examples I found, the step method doesn't seem to require any looping. If that is the case, how can I re-implement the logic of my original method so that it plays well with step?

My code

import torch
import torch.nn as nn

from data import TrainDataset
from pathlib import Path
from probes import SingleLinearProbe, DoubleLinearProbe
from ray.tune import Trainable
from typing import Any, Callable, Dict, Tuple
from torch.nn import MSELoss
from torch.optim import Optimizer, SGD
from torch.utils.data import DataLoader, random_split

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Trainer(Trainable):

    def setup(self, config: Dict[str, Any]) -> None:
        # move the model to the same device the batches are sent to
        self._model = self._initialize_model(config).to(DEVICE)
        self._epochs = config["epochs"]
        self._train_loader, self._val_loader = self._initialize_dataloaders(config)
        self._criterion = self._get_criterion(config)
        self._optimizer = self._get_optimizer(config)
        self._batch_size = config["batch_size"]


    def _initialize_model(self, config: Dict[str, Any]) -> nn.Module:
        # config["model"] holds the model name; its hyperparameters are read
        # from a config section keyed by that same name
        if config["model"] == "SingleLinearProbe":
            model = SingleLinearProbe(
                config["SingleLinearProbe"]["dim_input"],
                config["SingleLinearProbe"]["dim_output"]
            )
        elif config["model"] == "DoubleLinearProbe":
            model = DoubleLinearProbe(
                config["DoubleLinearProbe"]["dim_input"],
                config["DoubleLinearProbe"]["dim_hidden"],
                config["DoubleLinearProbe"]["dim_output"]
            )
        else:
            raise KeyError(f"Unknown model: {config['model']}")

        return model


    def _initialize_dataloaders(self, config: Dict[str, Any]) -> Tuple[DataLoader, DataLoader]:
        csv_data = Path(__file__).parent / "datasets" / config["train_data"]
        dataset = TrainDataset(csv_data, config["language_model"])
        train_data, val_data = random_split(
            dataset, [0.8, 0.2], generator=torch.Generator().manual_seed(42)
        )
        train_dataloader = DataLoader(
            train_data, batch_size=config["batch_size"], shuffle=True, num_workers=0
        )
        val_dataloader = DataLoader(
            val_data, batch_size=config["batch_size"], shuffle=True, num_workers=0
        )

        return train_dataloader, val_dataloader


    def _get_criterion(self, config: Dict[str, Any]) -> Callable:
        if config["criterion"] == "MSELoss":
            # instantiate the loss so calling self._criterion(...) computes a value
            # instead of constructing a new module
            criterion = MSELoss()
        else:
            raise KeyError(f"Unknown criterion: {config['criterion']}")

        return criterion


    def _get_optimizer(self, config: Dict[str, Any]) -> Optimizer:
        # config["optimizer"] holds the optimizer name; its hyperparameters are
        # read from a config section keyed by that same name
        if config["optimizer"] == "SGD":
            optimizer = SGD(
                self._model.parameters(),
                lr=config["SGD"]["lr"],
                momentum=config["SGD"]["momentum"])
        else:
            raise KeyError(f"Unknown optimizer: {config['optimizer']}")

        return optimizer


    def step(self):
        ????


    def _train_and_validate(self) -> None:
        for epoch in range(self._epochs):
            self._model.train()
            # sum of all batch losses for a given epoch
            train_loss = 0
            val_loss = 0

            for data, target in self._train_loader:
                data = data.to(DEVICE)
                target = target.to(DEVICE)
                
                out = self._model(data)
                batch_loss = self._criterion(target, out)
                train_loss += batch_loss.item() * data.size(0)

                self._optimizer.zero_grad()
                batch_loss.backward()

                self._optimizer.step()
            
            # average per-sample loss for a given epoch
            # https://discuss.pytorch.org/t/on-running-loss-and-average-loss/107890
            # https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html#training-the-model
            avg_train_loss = train_loss / len(list(self._train_loader.sampler))

            with torch.no_grad():
                self._model.eval()
                val_loss = 0
                for data, target in self._val_loader:
                    data = data.to(DEVICE)
                    target = target.to(DEVICE)

                    out = self._model(data)
                    batch_loss = self._criterion(target, out)
                    val_loss += batch_loss.item() * data.size(0)

                # compute average per-sample loss for a given epoch
                avg_val_loss = val_loss / len(list(self._val_loader.sampler))

            yield avg_train_loss, avg_val_loss  # <== ???

Hi @mtt,

Have you considered using Tune’s Function API instead? That way, you don’t need to change your existing loop too much. See here: Training in Tune (tune.Trainable, session.report) — Ray 2.3.0
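
For reference, a minimal Function API version could look roughly like this (a sketch, not your code; train_probe and the build_* / train_one_epoch / validate helpers are hypothetical stand-ins for the logic you already have in setup and _train_and_validate):

from ray import tune
from ray.air import session


def train_probe(config):
    # stand-ins for the logic currently in setup()
    model, criterion, optimizer = build_model(config)       # hypothetical helper
    train_loader, val_loader = build_dataloaders(config)    # hypothetical helper

    for epoch in range(config["epochs"]):
        avg_train_loss = train_one_epoch(model, train_loader, criterion, optimizer)  # hypothetical
        avg_val_loss = validate(model, val_loader, criterion)                         # hypothetical
        # each session.report() call counts as one training_iteration for Tune
        session.report({"train_loss": avg_train_loss, "val_loss": avg_val_loss})

The epoch and batch loops stay exactly as you wrote them; the only real change is replacing the yield with session.report.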

A step can be at whatever granularity you define. The only thing it affects in Tune is how often reporting and checkpointing happen. For example, if you implement step as one epoch of training, then metrics are reported and checkpoints are taken only once per epoch. If step is one gradient step (one batch), you'll have many more reports, as in the sketch below.
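
To make the granularity point concrete, a batch-level step could look roughly like this (a sketch that reuses the attributes created in your setup; BatchStepTrainer is a hypothetical name, not code from your project):

class BatchStepTrainer(Trainable):

    def setup(self, config: Dict[str, Any]) -> None:
        ...  # same as your setup()
        self._train_iter = iter(self._train_loader)

    def step(self):
        # one gradient step per Tune iteration; restart the iterator
        # once the current epoch is exhausted
        try:
            data, target = next(self._train_iter)
        except StopIteration:
            self._train_iter = iter(self._train_loader)
            data, target = next(self._train_iter)

        self._model.train()
        data, target = data.to(DEVICE), target.to(DEVICE)
        out = self._model(data)
        batch_loss = self._criterion(out, target)

        self._optimizer.zero_grad()
        batch_loss.backward()
        self._optimizer.step()

        return {"train_loss": batch_loss.item()}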

You can implement the "stop after self._epochs" logic with a Tune stopper: Stopping and Resuming a Tune Run — Ray 2.3.0
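
A sketch of wiring that up (assuming Trainer is your class above and that the rest of the config is filled in elsewhere):

from ray import air, tune
from ray.tune.stopper import MaximumIterationStopper

param_space = {
    "epochs": 10,
    # ... the rest of the config consumed by Trainer.setup()
}

tuner = tune.Tuner(
    Trainer,
    param_space=param_space,
    # stop each trial after 10 calls to step()
    run_config=air.RunConfig(stop=MaximumIterationStopper(max_iter=10)),
)
results = tuner.fit()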

Hi @justinvyu, thank you for taking the time to look into this.

In the end I managed to put together a working solution using Tune's Trainable Class API. The one below is just a rough sketch of what it looks like. I still have one doubt, though. If I want the step method to be called n times, where n is the number of epochs, how should I go about it? Right now I'm relying on a solution that resembles the snippet below. Am I right to assume that air.RunConfig(stop={"training_iteration": 10}) means that step is going to be executed 10 times?

tune.Tuner(
    my_trainable,
    run_config=air.RunConfig(stop={"training_iteration": 10})
).fit()

A sketch of my current solution

class Trainer(Trainable):
    def setup(self, config: Dict[str, Any]):
         ...

    def _get_model(self, config: Dict[str, Any]) -> nn.Module:
         ...

    def _get_dataloader(self, config: Dict[str, Any]) -> Tuple[DataLoader, DataLoader]:
         ...

    def _get_criterion(self, config) -> _Loss:
         ...

    def _get_optimizer(self, config) -> Optimizer:
         ...

    def step(self):
        avg_val_loss = self._validate()
        avg_train_loss = self._train()

        return {"val_loss": avg_val_loss, "train_loss": avg_train_loss}


    def _train(self):
        self._model.train()

        # sum of all batch losses for a given epoch
        train_loss = 0

        for data, target in self._train_loader:
            data = data.to(DEVICE)
            target = target.to(DEVICE)
            
            out = self._model(data)
            batch_loss = self._criterion(target, out)
            train_loss += batch_loss.item() * data.size(0)

            self._optimizer.zero_grad()
            batch_loss.backward()

            self._optimizer.step()
        
        # compute average sample loss for a given epoch
        # https://discuss.pytorch.org/t/on-running-loss-and-average-loss/107890
        # https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html#training-the-model
        avg_train_loss = train_loss / len(list(self._train_loader.sampler))

        return avg_train_loss


    @torch.no_grad()
    def _validate(self):
        self._model.eval()

        # sum of all batch losses for a given epoch
        val_loss = 0

        for data, target in self._val_loader:
            data = data.to(DEVICE)
            target = target.to(DEVICE)

            out = self._model(data)
            batch_loss = self._criterion(target, out)
            val_loss += batch_loss.item() * data.size(0)

        # compute average sample loss for a given epoch
        avg_val_loss = val_loss / len(list(self._val_loader.sampler))

        return avg_val_loss


    def save_checkpoint(self, checkpoint_dir: str) -> Optional[Union[str, Dict]]:
         ...

Yes, training_iteration is the number of times Trainable.step gets called.
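
So with a one-epoch step, something along these lines runs step (and therefore one epoch) exactly num_epochs times; this is just a sketch, with num_epochs as a local variable to keep the stop criterion and the config in sync:

num_epochs = 10

tune.Tuner(
    Trainer,
    param_space={"epochs": num_epochs},  # plus the rest of your config
    run_config=air.RunConfig(stop={"training_iteration": num_epochs}),
).fit()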