Hi @kai,
Thank you for your reply. I am using Ray version `2.0.0.dev0`, installed with `pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-2.0.0.dev0-cp38-cp38-manylinux2014_x86_64.whl`. Here is the full stack trace:
```
2021-04-01 15:00:31,528 ERROR trial_runner.py:894 -- Trial WrappedDistributedTorchTrainable_fa3b2_00002: Error handling checkpoint /home/my_user/ray_results/WrappedDistributedTorchTrainable_2021-04-01_15-00-20/WrappedDistributedTorchTrainable_fa3b2_00002_2_2021-04-01_15-00-20/checkpoint_000010/./
Traceback (most recent call last):
  File "/home/my_user/miniconda3/envs/mdt_venv/lib/python3.8/site-packages/ray/tune/trial_runner.py", line 891, in _process_trial_save
    trial.on_checkpoint(trial.saving_to)
  File "/home/my_user/miniconda3/envs/mdt_venv/lib/python3.8/site-packages/ray/tune/trial.py", line 525, in on_checkpoint
    self.checkpoint_manager.on_checkpoint(checkpoint)
  File "/home/my_user/miniconda3/envs/mdt_venv/lib/python3.8/site-packages/ray/tune/checkpoint_manager.py", line 160, in on_checkpoint
    self.delete(worst)
  File "/home/my_user/miniconda3/envs/mdt_venv/lib/python3.8/site-packages/ray/tune/trial.py", line 106, in delete
    shutil.rmtree(checkpoint_dir)
  File "/home/my_user/miniconda3/envs/mdt_venv/lib/python3.8/shutil.py", line 722, in rmtree
    onerror(os.rmdir, path, sys.exc_info())
  File "/home/my_user/miniconda3/envs/mdt_venv/lib/python3.8/shutil.py", line 720, in rmtree
    os.rmdir(path)
OSError: [Errno 22] Invalid argument: '/home/my_user/ray_results/WrappedDistributedTorchTrainable_2021-04-01_15-00-20/WrappedDistributedTorchTrainable_fa3b2_00002_2_2021-04-01_15-00-20/checkpoint_000001/./'
```
We get this error after simply adding the `keep_checkpoints_num` parameter to the `tune.run()` call in your ddp_mnist_torch example script from the Ray v2.0.0.dev0 documentation.
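Concretely, that call now looks like this (copied from the full script further down; `keep_checkpoints_num=3` is the only argument we added to it relative to the example):

```python
# tune.run() call from the script below; keep_checkpoints_num=3 is the only
# argument we added compared to the original ddp_mnist_torch example.
analysis = tune.run(
    trainable_cls,
    num_samples=4,
    stop={"training_iteration": 10},
    metric="mean_accuracy",
    mode="max",
    keep_checkpoints_num=3)
```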
For completeness, here is the full script that we used:
```python
import argparse
import logging
import os

import torch
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel

import ray
from ray import tune
from ray.tune.examples.mnist_pytorch import (train, test, get_data_loaders,
                                             ConvNet)
from ray.tune.integration.torch import (DistributedTrainableCreator,
                                        distributed_checkpoint_dir)

logger = logging.getLogger(__name__)


def train_mnist(config, checkpoint_dir=False):
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    train_loader, test_loader = get_data_loaders()
    model = ConvNet().to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.1)

    if checkpoint_dir:
        with open(os.path.join(checkpoint_dir, "checkpoint")) as f:
            model_state, optimizer_state = torch.load(f)

        model.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    model = DistributedDataParallel(model)

    for epoch in range(40):
        train(model, optimizer, train_loader, device)
        acc = test(model, test_loader, device)

        if epoch % 3 == 0:
            with distributed_checkpoint_dir(step=epoch) as checkpoint_dir:
                path = os.path.join(checkpoint_dir, "checkpoint")
                torch.save((model.state_dict(), optimizer.state_dict()), path)
        tune.report(mean_accuracy=acc)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num-workers",
        "-n",
        type=int,
        default=2,
        help="Sets number of workers for training.")
    parser.add_argument(
        "--num-gpus-per-worker",
        action="store_true",
        default=False,
        help="enables CUDA training")
    parser.add_argument(
        "--cluster",
        action="store_true",
        default=False,
        help="enables multi-node tuning")
    parser.add_argument(
        "--workers-per-node",
        type=int,
        help="Forces workers to be colocated on machines if set.")
    args = parser.parse_args()

    if args.cluster:
        options = dict(address="auto")
    else:
        options = dict(num_cpus=2)

    ray.init(**options)

    trainable_cls = DistributedTrainableCreator(
        train_mnist,
        num_workers=1,
        num_cpus_per_worker=2,
        num_gpus_per_worker=args.num_gpus_per_worker,
        num_workers_per_host=1)

    analysis = tune.run(
        trainable_cls,
        num_samples=4,
        stop={"training_iteration": 10},
        metric="mean_accuracy",
        mode="max",
        keep_checkpoints_num=3
    )

    print("Best hyperparameters found were: ", analysis.best_config)
```
We have run this script locally on a machine with 4 CPUs and no GPUs, so we adjusted the resource settings in the script accordingly.
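In case it helps narrow things down: the final `os.rmdir()` failure in the traceback looks reproducible outside of Ray whenever the target path ends in `./`, which matches the trailing `./` on the checkpoint paths above. This is only our guess at the mechanism, demonstrated on a throwaway temporary directory:

```python
import shutil
import tempfile

# Standalone reproduction attempt (our assumption about the mechanism, not
# verified against Ray's checkpoint code): shutil.rmtree() finishes by calling
# os.rmdir() on the path it was given, and on our Linux machine that call
# rejects a path whose last component is '.', raising Errno 22.
tmp_dir = tempfile.mkdtemp()
shutil.rmtree(tmp_dir + "/./")  # raises OSError: [Errno 22] Invalid argument for us
```

Normalising the path first (e.g. with `os.path.normpath`) makes the same snippet succeed on our machine, but we are not sure where the trailing `./` gets introduced on the Ray side.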