Checkpointing with distributed training

TL;DR: adding keep_checkpoint_num to tune.run() in the distributed DDP MNIST example breaks checkpoint cleanup with an OSError.

We are trying to run the example described here: https://docs.ray.io/en/master/tune/examples/ddp_mnist_torch.html.
Everything worked fine until we added keep_checkpoint_num in order to save disk space:

analysis = tune.run(
        trainable_cls,
        num_samples=4,
        stop={"training_iteration": 10},
        metric="mean_accuracy",
        mode="max",
        keep_checkpoint_num=3)

This is the error we encountered:

OSError: [Errno 22] Invalid argument: '/home/my_user/ray_results/WrappedDistributedTorchTrainable_2021-03-29_09-29-22/WrappedDistributedTorchTrainable_3ef8a_00000_0_2021-03-29_09-29-22/checkpoint_000007/./'

keep_checkpoint_num is the only parameter we changed relative to the example.

Could you please help us solve this?

@kai @amogkam this looks like it could be a bug. Not sure, could you please take a look?

The argument in tune.run() should be “keep_checkpoints_num”, not “keep_checkpoint_num”.
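
For reference, with the corrected keyword your call becomes:

analysis = tune.run(
        trainable_cls,
        num_samples=4,
        stop={"training_iteration": 10},
        metric="mean_accuracy",
        mode="max",
        keep_checkpoints_num=3)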

Hi,
apologies, that was just a typo in my post; I've checked and the script uses the correct argument name.
Unfortunately, this does not solve the problem. Thanks anyway.

Hi, is there any news on this? I tried to debug it myself, but without success.

Hi @riccardo_hum, sorry for the long delay.

It would be great if you could tell us which Ray version you are currently using. Also, if you have a longer stacktrace that would be helpful to determine what is happening here.
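
One quick way to print the installed version:

python -c "import ray; print(ray.__version__)"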

We had problems with checkpoint bookkeeping in the past, but most should have been resolved by now. I’m happy to look more into this if the problem persists with the latest master.

By the way, if you have a full training script that I could just copy, paste, and run on a cluster to reproduce the error, that would be most helpful. Thanks!

Hi @kai,

Thank you for your reply. I am using Ray version ‘2.0.0.dev0’ (installed with pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-2.0.0.dev0-cp38-cp38-manylinux2014_x86_64.whl), and here is the full stacktrace:

2021-04-01 15:00:31,528	ERROR trial_runner.py:894 -- Trial WrappedDistributedTorchTrainable_fa3b2_00002: Error handling checkpoint /home/my_user/ray_results/WrappedDistributedTorchTrainable_2021-04-01_15-00-20/WrappedDistributedTorchTrainable_fa3b2_00002_2_2021-04-01_15-00-20/checkpoint_000010/./
Traceback (most recent call last):
  File "/home/my_user/miniconda3/envs/mdt_venv/lib/python3.8/site-packages/ray/tune/trial_runner.py", line 891, in _process_trial_save
    trial.on_checkpoint(trial.saving_to)
  File "/home/my_user/miniconda3/envs/mdt_venv/lib/python3.8/site-packages/ray/tune/trial.py", line 525, in on_checkpoint
    self.checkpoint_manager.on_checkpoint(checkpoint)
  File "/home/my_user/miniconda3/envs/mdt_venv/lib/python3.8/site-packages/ray/tune/checkpoint_manager.py", line 160, in on_checkpoint
    self.delete(worst)
  File "/home/my_user/miniconda3/envs/mdt_venv/lib/python3.8/site-packages/ray/tune/trial.py", line 106, in delete
    shutil.rmtree(checkpoint_dir)
  File "/home/my_user/miniconda3/envs/mdt_venv/lib/python3.8/shutil.py", line 722, in rmtree
    onerror(os.rmdir, path, sys.exc_info())
  File "/home/my_user/miniconda3/envs/mdt_venv/lib/python3.8/shutil.py", line 720, in rmtree
    os.rmdir(path)
OSError: [Errno 22] Invalid argument: '/home/my_user/ray_results/WrappedDistributedTorchTrainable_2021-04-01_15-00-20/WrappedDistributedTorchTrainable_fa3b2_00002_2_2021-04-01_15-00-20/checkpoint_000001/./'

We get this error after simply adding the ‘keep_checkpoints_num’ parameter to the tune.run() call in your example script (ddp_mnist_torch — Ray v2.0.0.dev0).
For completeness, below is the full script we used:

import argparse
import logging
import os
import torch
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel

import ray
from ray import tune
from ray.tune.examples.mnist_pytorch import (train, test, get_data_loaders,
                                              ConvNet)
from ray.tune.integration.torch import (DistributedTrainableCreator,
                                        distributed_checkpoint_dir)

logger = logging.getLogger(__name__)


def train_mnist(config, checkpoint_dir=False):
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    train_loader, test_loader = get_data_loaders()
    model = ConvNet().to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.1)

    if checkpoint_dir:
        with open(os.path.join(checkpoint_dir, "checkpoint")) as f:
            model_state, optimizer_state = torch.load(f)

        model.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    model = DistributedDataParallel(model)

    for epoch in range(40):
        train(model, optimizer, train_loader, device)
        acc = test(model, test_loader, device)

        if epoch % 3 == 0:
            with distributed_checkpoint_dir(step=epoch) as checkpoint_dir:
                path = os.path.join(checkpoint_dir, "checkpoint")
                torch.save((model.state_dict(), optimizer.state_dict()), path)
        tune.report(mean_accuracy=acc)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num-workers",
        "-n",
        type=int,
        default=2,
        help="Sets number of workers for training.")
    parser.add_argument(
        "--num-gpus-per-worker",
        action="store_true",
        default=False,
        help="enables CUDA training")
    parser.add_argument(
        "--cluster",
        action="store_true",
        default=False,
        help="enables multi-node tuning")
    parser.add_argument(
        "--workers-per-node",
        type=int,
        help="Forces workers to be colocated on machines if set.")

    args = parser.parse_args()

    if args.cluster:
        options = dict(address="auto")
    else:
        options = dict(num_cpus=2)
    ray.init(**options)
    trainable_cls = DistributedTrainableCreator(
        train_mnist,
        num_workers=1,
        num_cpus_per_worker=2,
        num_gpus_per_worker=args.num_gpus_per_worker,
        num_workers_per_host=1)

    analysis = tune.run(
        trainable_cls,
        num_samples=4,
        stop={"training_iteration": 10},
        metric="mean_accuracy",
        mode="max",
        keep_checkpoints_num=3
        )

    print("Best hyperparameters found were: ", analysis.best_config)

We ran this script locally on a machine with 4 CPUs and no GPUs (we only adjusted the available resources in the script accordingly).

Hi @kai, have you seen my reply? Is there any other information you need to look into this behavior?

Thank you

Hi Riccardo, this should be enough for me to look into it. I'll try to get to it sometime this week - thanks for your patience!

Hi @kai, did you have time to look into this? Sorry to insist, but it is really blocking my work.

Thanks

Hi Riccardo,
I haven't had the chance yet, but I've put it on my list for tomorrow. Sorry about the delay.

Hi Riccardo,

I could reproduce this issue locally and filed a fix here: [tune] Return normalized checkpoint path by krfricke · Pull Request #15296 · ray-project/ray · GitHub

We run this example on CI, so I’m not sure why this wasn’t caught before. Anyway, let me know if the fix works for you (it does for me). It will be merged to master sometime this week.
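
To illustrate what the fix does (just a sketch with a made-up path, not the actual patch): the checkpoint directory path ends in "/./", and on Linux os.rmdir() rejects a path whose last component is "." with EINVAL, which is why shutil.rmtree fails above. Normalizing the path first avoids this:

import os

# Hypothetical path with the trailing "/./" seen in the error above
path = "/home/my_user/ray_results/my_trial/checkpoint_000001/./"

# os.path.normpath strips the trailing "/.", so rmtree gets a clean directory path
print(os.path.normpath(path))  # /home/my_user/ray_results/my_trial/checkpoint_000001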

Best wishes

Thanks for your help, @kai,

Actually, this only partially solves the issue. The error I mentioned before (the trailing ./ path) is gone, but:

  • in the experiment folder, the checkpoint folders are kept correctly (only the last 3 best ones), but in the worker folders all checkpoints are kept, which will lead to the same disk-space issue that keep_checkpoints_num was meant to solve;

  • if I kill the experiment with Ctrl+C and try to resume it by running again with resume=True, I get the following:

2021-04-15 10:01:48,297	ERROR trial_runner.py:918 -- Trial WrappedDistributedTorchTrainable_6d815_00000: Error processing restore.
Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/envs/ray/lib/python3.8/site-packages/ray/tune/trial_runner.py", line 912, in _process_trial_restore
    self.trial_executor.fetch_result(trial)
  File "/home/ubuntu/anaconda3/envs/ray/lib/python3.8/site-packages/ray/tune/ray_trial_executor.py", line 678, in fetch_result
    result = ray.get(trial_future[0], timeout=DEFAULT_GET_TIMEOUT)
  File "/home/ubuntu/anaconda3/envs/ray/lib/python3.8/site-packages/ray/_private/client_mode_hook.py", line 47, in wrapper
    return func(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/ray/lib/python3.8/site-packages/ray/worker.py", line 1440, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(ValueError): ray::WrappedDistributedTorchTrainable.restore_from_object() (pid=9351, ip=10.132.0.11)
  File "python/ray/_raylet.pyx", line 501, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 444, in ray._raylet.execute_task.function_executor
  File "/home/ubuntu/anaconda3/envs/ray/lib/python3.8/site-packages/ray/_private/function_manager.py", line 556, in actor_method_executor
    return method(__ray_actor, *args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/ray/lib/python3.8/site-packages/ray/tune/trainable.py", line 394, in restore_from_object
    self.restore(checkpoint_path)
  File "/home/ubuntu/anaconda3/envs/ray/lib/python3.8/site-packages/ray/tune/trainable.py", line 372, in restore
    self.load_checkpoint(checkpoint_path)
  File "/home/ubuntu/anaconda3/envs/ray/lib/python3.8/site-packages/ray/tune/integration/torch.py", line 132, in load_checkpoint
    return ray.get(
  File "/home/ubuntu/anaconda3/envs/ray/lib/python3.8/site-packages/ray/_private/client_mode_hook.py", line 47, in wrapper
    return func(*args, **kwargs)
ValueError: 'object_refs' must either be an object ref or a list of object refs.

I have also added name=‘my_exp’ to tune.run() to make sure I always resume the same experiment.
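
For reference, the resume run looks roughly like this (same arguments as before, trimmed to the relevant parts):

analysis = tune.run(
        trainable_cls,
        name="my_exp",
        resume=True,
        num_samples=4,
        stop={"training_iteration": 10},
        metric="mean_accuracy",
        mode="max",
        keep_checkpoints_num=3)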

Hi @kai,

have you seen this?

Thank you,

Riccardo

Hi Riccardo,
sorry, I've been swamped with the upcoming release over the past few days.
Can you open an issue on GitHub for this? For actual bugs (and this seems like one), a GitHub issue makes it easier for us to keep track, prioritize between issues, and assign them to the relevant code owners.
By worker folders, do you mean the checkpoints are kept on remote nodes and are only deleted on the driver/head node? If so, that definitely sounds like a bug.