Hi Kai, thank you for your help! I understand that Ray does work better on Linux, and indeed my script runs perfectly on Linux.
I've tried using WSL2 in the past, but there the GPU is not recognised by Ray (even though other libraries like PyTorch do recognise the GPUs), which is a problem for large-scale experiments that exceed 100 trials.
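As a quick sanity check under WSL2, this is the kind of minimal snippet I used to compare what PyTorch and Ray each detect (a standalone sketch, separate from the pipeline below):
import ray
import torch
# PyTorch's view of the GPU
print("torch CUDA available:", torch.cuda.is_available(), "| device count:", torch.cuda.device_count())
# Ray's view: look for a 'GPU' key in the resources dict
ray.init()
print("ray cluster resources:", ray.cluster_resources())
ray.shutdown()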
I've spent a little time producing a “simple” script that doesn't take 4 hours to reach the error message (previous attempts at simplifying resulted in the error message no longer appearing):
main script to run from command line:
import logging
import os
from functools import partial
from importlib import import_module
from pathlib import Path
import hydra
import ray
import ray.tune as tune
from hydra.utils import instantiate
from omegaconf import OmegaConf
from ray.tune import CLIReporter
from srcs.utils import set_seed, open_struct
from srcs.utils.tune import trial_name, get_metric_names
logger = logging.getLogger(__name__)
# on Windows, a scipy bug causes Ray Tune to not save trials properly; see https://stackoverflow.com/questions/15457786/ctrl-c-crashes-python-after-importing-scipy-stats
if os.name == 'nt':
import _thread
import win32api
def handler(dwCtrlType, hook_sigint=_thread.interrupt_main):
if dwCtrlType == 0: # CTRL_C_EVENT
hook_sigint()
return 1 # don't chain to the next handler
return 0 # chain to the next handler
win32api.SetConsoleCtrlHandler(handler, 1)
@hydra.main(config_path='conf/', config_name='scratch', version_base='1.2')
def main(config):
output_dir = Path(hydra.utils.HydraConfig.get().run.dir)
if config.resume is not None:
output_dir = output_dir.parent / config.resume
ray.init(runtime_env={"env_vars": OmegaConf.to_container(hydra.utils.HydraConfig.get().job.env_set)},
)
further_train = False
if not config.train_func.get("keep_pth", True):
logger.warning("Checkpoint pth files are not saved!")
further_train = True
best_trial, best_checkpoint_dir = main_worker(config, output_dir)
print(best_trial)
print("I finished!")
def main_worker(config, output_dir, _post_tune=False):
OmegaConf.resolve(config)
if config.get("seed"):
set_seed(config.seed)
train_met_names = get_metric_names(config.metrics, prefix="train/", include_loss=True)
val_met_names = get_metric_names(config.metrics, prefix="val/", include_loss=True)
reporter = CLIReporter(metric_columns=["training_iteration"] + val_met_names)
assert config.run.config.get("wandb",
False), "This pipeline requires wandb to be configured with at least project name in config"
wandb_cfg = config.run.config.wandb
if wandb_cfg.get("group") is None:
with open_struct(wandb_cfg):
wandb_cfg["group"] = f"{config.name}-{str(output_dir.name).replace('-', '')}"
module_name, func_name = config.train_func._target_.rsplit('.', 1)
train_func = getattr(import_module(module_name), func_name)
train_func_args = OmegaConf.to_container(config.train_func)
train_func_args.pop("_target_")
partial_func = partial(train_func, **train_func_args, arch_cfg=config)
if hasattr(train_func, "__mixins__"):
partial_func.__mixins__ = train_func.__mixins__
analysis = tune.run(
partial_func,
**(instantiate(config.run, _convert_="partial")),
progress_reporter=reporter,
local_dir=output_dir.parent,
name=output_dir.name if config.resume is None else config.resume,
trial_dirname_creator=trial_name,
resume=config.resume is not None
)
monitor = config.monitor
mode = config.mode
best_trial = analysis.get_best_trial(metric=monitor, mode=mode, scope='all')
best_checkpoint_dir = analysis.get_best_checkpoint(best_trial, metric=monitor, mode=mode).local_path
return best_trial, best_checkpoint_dir
if __name__ == '__main__':
main()
train function passed to tune.run (referenced as srcs.trainer.tune_trainer.train_func):
import os
from pathlib import Path
import torch
from omegaconf import OmegaConf
from ray import tune
from ray.tune.utils.util import flatten_dict
from ray.tune.integration.wandb import wandb_mixin
from hydra.utils import instantiate
from srcs.loggers.logger import MetricCollection
from srcs.utils import prepare_devices, set_seed, open_struct
from srcs.utils.files import change_directory
import wandb
@wandb_mixin
def train_func(config, arch_cfg, checkpoint_dir=None, keep_pth=True, epochs=100, checkpoint_name=None):
# cwd is changed to the trial folder
project_dir = os.getenv('TUNE_ORIG_WORKING_DIR')
config = OmegaConf.create(config)
arch_cfg = OmegaConf.merge(arch_cfg, config)
wandb.log_artifact("config.yaml", type="config")
_config = arch_cfg.copy()
with open_struct(_config):
_config.run.pop("config")
_config = OmegaConf.to_container(_config)
_config = flatten_dict(_config, delimiter='-')
wandb.config.update(_config, allow_val_change=True)
# set seed for each run
if arch_cfg.get("seed"):
set_seed(arch_cfg.seed)
if checkpoint_name is None:
checkpoint_name = "model_checkpoint.pth"
with change_directory(project_dir):
# setup code
device, device_ids = prepare_devices(arch_cfg.n_gpu)
# setup dataloaders
data_loader = instantiate(arch_cfg.data_loader)
valid_data_loader = data_loader.split_validation()
output_size = len(data_loader.categories)
channel_n = data_loader.channel_n
# setup model
model = instantiate(arch_cfg.arch, output_size=output_size, channel_n=channel_n)
wandb.watch(model, log="all", log_freq=1)
wandb.define_metric(name=arch_cfg.monitor, summary=arch_cfg.mode)
trainable_params = filter(lambda p: p.requires_grad, model.parameters())
trainable_params = sum([p.numel() for p in trainable_params])
# logger.info(f'Trainable parameters: {sum([p.numel() for p in trainable_params])}')
wandb.summary["trainable_params"]=trainable_params
model = model.to(device)
if len(device_ids) > 1:
model = torch.nn.DataParallel(model, device_ids=device_ids)
criterion = instantiate(arch_cfg.loss)
optimizer = instantiate(arch_cfg.optimizer, model.parameters())
lr_scheduler = None
if arch_cfg.get("lr_scheduler"):
lr_scheduler = instantiate(arch_cfg.lr_scheduler, optimizer)
train_metrics = MetricCollection(arch_cfg.metrics, prefix="train/")
valid_metrics = MetricCollection(arch_cfg.metrics, prefix="val/")
# later changed if checkpoint
start_epoch = 0
if checkpoint_dir:
            checkpoint = torch.load(Path(checkpoint_dir) / checkpoint_name)  # checkpoint_dir passed by tune is a directory, so load the saved file inside it
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint["epoch"] + 1
for epoch in range(start_epoch, epochs): # loop over the dataset multiple times
train_metrics.reset()
train_log = one_epoch(data_loader, criterion, model, device, train_metrics, optimizer)
valid_metrics.reset()
with torch.no_grad():
val_log = one_epoch(valid_data_loader, criterion, model, device, valid_metrics)
if lr_scheduler is not None:
lr_scheduler.step()
state = {
'arch': type(model).__name__,
'epoch': epoch,
'state_dict': model.state_dict(),
"optimizer": optimizer.state_dict(),
"config": arch_cfg
}
# create checkpoint
with tune.checkpoint_dir(f"epoch-{epoch}") as checkpoint_dir:
if keep_pth:
filename = Path(checkpoint_dir) / checkpoint_name
torch.save(state, filename)
# log metrics, log in checkpoint in case actor dies half way
val_log.update(train_log)
wandb.log(val_log, step=epoch)
tune.report(**val_log)
# tune.report(**train_log)
def one_epoch(data_loader, criterion, model, device, metric_tracker: MetricCollection, optimizer=None) -> dict:
for i, data in enumerate(data_loader, 0):
# get the inputs; data is a list of [inputs, labels]
inputs, targets = data
inputs, targets = inputs.to(device), targets.to(device)
if optimizer is not None:
# zero the parameter gradients
optimizer.zero_grad()
# forward + backward
outputs = model(inputs)
loss = criterion(outputs, targets)
if optimizer is not None:
loss.backward()
optimizer.step()
metric_tracker.update(outputs, targets, loss)
return metric_tracker.result()
custom utility functions used in the above code:
import contextlib
import logging
import os
from pathlib import Path
import numpy as np
import torch
from omegaconf import OmegaConf
logger = logging.getLogger(__name__)  # logging and os are used by prepare_devices and change_directory below
@contextlib.contextmanager
def open_struct(config):
OmegaConf.set_struct(config, False)
try:
yield
finally:
OmegaConf.set_struct(config, True)
def set_seed(seed):
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(seed)
def prepare_devices(n_gpu_use):
"""
setup GPU device if available, move model into configured device
"""
n_gpu = torch.cuda.device_count()
if n_gpu_use > 0 and n_gpu == 0:
logger.warning("Warning: There\'s no GPU available on this machine,"
"training will be performed on CPU.")
n_gpu_use = 0
if n_gpu_use > n_gpu:
logger.warning("Warning: The number of GPU\'s configured to use is {}, but only {} are available "
"on this machine.".format(n_gpu_use, n_gpu))
n_gpu_use = n_gpu
device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
list_ids = list(range(n_gpu_use))
return device, list_ids
@contextlib.contextmanager
def change_directory(path):
"""Changes working directory and returns to previous on exit."""
prev_cwd = Path.cwd()
if path is None:
path = prev_cwd
os.chdir(path)
try:
yield
finally:
os.chdir(prev_cwd)
MetricCollection class:
import logging
import pandas as pd
import torch
from hydra.utils import instantiate
logger = logging.getLogger('logger')
class BatchMetrics:
def __init__(self, *keys, postfix=''):
self.postfix = postfix
if postfix:
keys = [k + postfix for k in keys]
self._data = pd.DataFrame(index=keys, columns=['total', 'counts', 'average'])
self.reset()
def reset(self):
for col in self._data.columns:
self._data[col].values[:] = 0
def update(self, key, value, n=1):
if self.postfix:
key = key + self.postfix
self._data.total[key] += value * n
self._data.counts[key] += n
self._data.average[key] = self._data.total[key] / self._data.counts[key]
def avg(self, key):
if self.postfix:
key = key + self.postfix
return self._data.average[key]
def result(self):
return dict(self._data.average)
class MetricCollection:
def __init__(self, metric_dict, prefix='', postfix=''):
self.prefix=prefix
self.postfix=postfix
self.metric_ftns=[instantiate(met, _partial_=True) for met in metric_dict]
self.met_names=[self.prefix+met.func.__name__+self.postfix for met in self.metric_ftns] + [self.prefix+'loss'+self.postfix]
self.metric_tracker = BatchMetrics(*self.met_names)
def update(self, output, target, loss):
self.metric_tracker.update(self.prefix+'loss'+self.postfix, loss.item())
pred = torch.argmax(output, dim=1)
assert pred.shape[0] == len(target)
pred=pred.detach().cpu().numpy()
target = target.detach().cpu().numpy()
with torch.no_grad():
for met in self.metric_ftns:
result = met(target, pred)
self.metric_tracker.update(self.prefix+met.func.__name__+self.postfix, result)
def result(self):
return self.metric_tracker.result()
def reset(self):
self.metric_tracker.reset()
accuracy function:
import numpy as np
def accuracy(target, output):
correct = 0
correct += np.sum(output == target)
return correct / len(target)
Then finally, the YAML file used to run Hydra:
defaults:
- _self_
- optional local_env: env
data_loader:
_target_: srcs.data_loader.dummy_loader.DummyLoader
batch_size: 50
validation_split: 0.15
test_split: 0.1
num_workers: ${num_workers}
loss:
_target_: torch.nn.NLLLoss
optimizer:
_target_: torch.optim.AdamW
lr: 0.001
weight_decay: 0.01
amsgrad: true
name: dummy
arch:
_target_: srcs.models.dummy_model.DummyModel
a: 3
train_func:
_target_: srcs.trainer.tune_trainer.train_func
epochs: 5
keep_pth: false
checkpoint_name: model_checkpoint.pth
run:
metric: ${monitor}
mode: ${mode}
resources_per_trial:
cpu: ${num_workers}
gpu: 0.2
verbose: 2
keep_checkpoints_num: 8
checkpoint_score_attr: ${monitor}
scheduler:
_target_: ray.tune.schedulers.ASHAScheduler
max_t: ${train_func.epochs}
grace_period: 1
reduction_factor: 2
config:
wandb:
project: GRF_hip_outcomes
mode: disabled
optimizer:
lr:
_target_: ray.tune.loguniform
_args_:
- 0.0001
- 0.01
weight_decay:
_target_: ray.tune.loguniform
_args_:
- 0.001
- 0.01
data_loader:
batch_size:
_target_: ray.tune.choice
_args_:
- - 128
- 256
arch:
a:
_target_: ray.tune.randint
_args_:
- 1
- 30
raise_on_failed_trial: false
num_samples: 250
status: tune
n_gpu: 1
n_cpu: 10
num_workers: 4
resume: null
seed: 122
output_root: ./outputs
metrics:
- _target_: srcs.metrics.accuracy
monitor: val/accuracy
mode: max
topk_num: 5
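For clarity, my understanding is that the run: block above, once passed through instantiate(config.run, _convert_="partial"), resolves to roughly the following tune.run keyword arguments (a hand-resolved sketch, with the ${...} interpolations filled in):
from ray import tune
from ray.tune.schedulers import ASHAScheduler
run_kwargs = dict(
    metric="val/accuracy",
    mode="max",
    resources_per_trial={"cpu": 4, "gpu": 0.2},
    verbose=2,
    keep_checkpoints_num=8,
    checkpoint_score_attr="val/accuracy",
    scheduler=ASHAScheduler(max_t=5, grace_period=1, reduction_factor=2),
    config={
        "wandb": {"project": "GRF_hip_outcomes", "mode": "disabled"},
        "optimizer": {"lr": tune.loguniform(0.0001, 0.01),
                      "weight_decay": tune.loguniform(0.001, 0.01)},
        "data_loader": {"batch_size": tune.choice([128, 256])},
        "arch": {"a": tune.randint(1, 30)},
    },
    raise_on_failed_trial=False,
    num_samples=250,
)
# main_worker then adds progress_reporter, local_dir, name, trial_dirname_creator and resume on top of these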
the dummy model used:
from torch import nn
class DummyModel(nn.Module):
def __init__(self, a, output_size, channel_n=9):
super().__init__()
self.output_size=output_size
self.layers = nn.Sequential(nn.Linear(1,a),
nn.ReLU(inplace=True),
nn.Linear(a,output_size))
self.softmax = nn.LogSoftmax(dim=1)
def forward(self, x):
x = x.view(-1, 1)
out = self.layers(x)
return self.softmax(out)
the dummy dataloader used:
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split
class1 = [2,4,6,8,10,12,14,16]
class2 = [3,6,9,12,15,18,21,24,27]
class3 = [1,5,7,11,13,17,19]
class DummyLoader(DataLoader):
def __init__(self, batch_size=50, validation_split=0.1, test_split=0.1, num_workers=4):
"""
dataset to identify whether number is multiple of 2 or 3 or something else
:param batch_size:
:param validation_split:
:param test_split:
:param num_workers:
"""
data_x = torch.tensor(class1+class2+class3, dtype=torch.float)
data_y = torch.tensor(len(class1)*[0]+len(class2)*[1]+len(class3)*[2])
self.dataset = TensorDataset(data_x, data_y)
self.n_samples = len(self.dataset)
self.train_set, self.val_set, self.test_set = self.__sample_split(validation_split, test_split)
self.loader_kwargs = {
'batch_size': batch_size,
'num_workers': num_workers
}
super().__init__(self.train_set, **self.loader_kwargs, shuffle=True)
def __sample_split(self, va_split, te_split):
val_n = int(self.n_samples * va_split)
test_n = int(self.n_samples * te_split)
train_n = self.n_samples - val_n - test_n
return random_split(self.dataset, [train_n, val_n, test_n])
def split_validation(self):
return DataLoader(self.val_set, shuffle=False,**self.loader_kwargs)
def split_test(self):
return DataLoader(self.test_set, shuffle=False, **self.loader_kwargs)
@property
def categories(self):
return [0,1,2]
@property
def channel_n(self):
return 1
When tested with partial GPUs (gpu: 0.2 per trial), it no longer stops at exactly 98 trials but at around 100; the exact number seems to fluctuate. Regardless, it never reaches the total number of trials.
I've run this on a Windows 11 machine with 32 GB of RAM, a GTX 1050 Ti, and a Ryzen 3700X.
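In case it helps narrow this down, this is how I count trial states from the analysis object returned by tune.run in main_worker (a small sketch; it just tallies the status of every trial so you can see how many terminated, errored, or never started):
from collections import Counter
# analysis is the ExperimentAnalysis returned by tune.run
status_counts = Counter(trial.status for trial in analysis.trials)
print("trial statuses:", dict(status_counts))
print("trials created:", len(analysis.trials), "of num_samples =", 250)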