Access ray train checkpoint after training

How severely does this issue affect your experience of using Ray?

  • Medium: It contributes to significant difficulty in completing my task, but I can work around it.

I am using TorchTrainer to train models from the transformers library. I am able to checkpoint the training progress and record the path of the final ray.train.Result object by logging result.path at the end. Is there a way to obtain result.path from just the experiment name and storage path?

For example, I am using the below script for training:

import os
import tempfile

import torch
from transformers import RobertaForMaskedLM, RobertaConfig
from transformers.models.roberta.tokenization_roberta_fast import RobertaTokenizerFast
from transformers.data.data_collator import DataCollatorForLanguageModeling
import ray
from ray import train
import ray.train.torch
from ray.train.torch import TorchTrainer
from ray.train import Checkpoint, CheckpointConfig, RunConfig, ScalingConfig


def train_loop_per_worker(config):
    device = ray.train.torch.get_device()
    world_rank = train.get_context().get_world_rank()
    tokenizer = RobertaTokenizerFast.from_pretrained("FacebookAI/roberta-base")

    model = RobertaForMaskedLM(RobertaConfig(vocab_size=tokenizer.vocab_size))
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer)
    batch_size = 16
    model = ray.train.torch.prepare_model(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
                                   
    train_data_shard = train.get_dataset_shard("train")
    train_dataloader = train_data_shard.iter_batches(batch_size=batch_size)

    start_epoch = 0
    for epoch in range(start_epoch, config["num_epochs"]):
        for i, batch in enumerate(train_dataloader):
            inputs = batch['text']
            tokens = tokenizer(inputs.tolist(),
                               padding=True,
                               return_tensors="pt")
            inputs, labels = data_collator.torch_mask_tokens(
                tokens['input_ids'])
            inputs = {
                'input_ids': inputs.to(device),
                'labels': labels.to(device),
                'attention_mask': tokens['attention_mask'].to(device),
            }
            outputs = model(**inputs)
            loss = outputs.get("loss")
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss = loss.detach().cpu().item()

            if i % 10 == 0 and world_rank == 0:
                metrics = {"loss": loss, "epoch": epoch, "iteration": i}
                with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
                    data = {
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'epoch': epoch,
                        'loss': loss
                    }
                    torch.save(data, os.path.join(temp_checkpoint_dir, "ckpt.pt"))
                    ray.train.report(metrics,
                                     checkpoint=Checkpoint.from_directory(
                                         temp_checkpoint_dir))

                print(f"in worker {world_rank} at iteration {i}")

if __name__ == '__main__':
    use_gpu = True
    num_workers = 1 

    train_dataset = ray.data.from_items([{'text': 'this is a long line' * 10}] * 320)
    train_dataset.materialize()

    train_loop_config = {"num_epochs": 20}
    ckpt_config = CheckpointConfig(checkpoint_score_attribute='loss',
                                   checkpoint_score_order='min')
    run_config = RunConfig(
        checkpoint_config=ckpt_config,
        name='test',
        storage_path='s3://<bucket-name>/tr-plain/')
    trainer = TorchTrainer(train_loop_per_worker=train_loop_per_worker,
                           train_loop_config=train_loop_config,
                           datasets={"train": train_dataset},
                           scaling_config=ScalingConfig(
                               num_workers=num_workers, use_gpu=use_gpu),
                           run_config=run_config)
    result = trainer.fit()

    # result.path is printed here
    print("result path is", result.path)

The result.path above is printed as <bucket-name>/tr-plain/test/TorchTrainer_84119_00000_0_2024-03-03_11-02-24. The annoying part is that I need to know the trial directory name TorchTrainer_84119_00000_0_2024-03-03_11-02-24 to inspect the result offline after training, especially when there are multiple such runs. Is there a way to get the ray.train.Result using only the storage_path and name specified in ray.train.RunConfig?
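
To make the pain point concrete, what I would like to avoid is having to discover that trial directory myself, e.g. by listing the bucket with something like the sketch below (this is not a Ray API, just raw pyarrow listing, and sorting by directory name is only a heuristic):

import posixpath

from pyarrow import fs

storage_path = "s3://<bucket-name>/tr-plain/"
name = "test"

# List the experiment directory on S3 and guess the trial directory from its prefix.
s3, exp_dir = fs.FileSystem.from_uri(posixpath.join(storage_path, name))
trial_dirs = [
    info.path
    for info in s3.get_file_info(fs.FileSelector(exp_dir))
    if info.type == fs.FileType.Directory and info.base_name.startswith("TorchTrainer")
]
# Trial directory names end with a timestamp, so sorting by name roughly picks the latest run.
latest_trial_dir = sorted(trial_dirs)[-1]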

@bveeramani apologies for tagging, but is there any solution for this?

@arunppsg Hey, there’s currently not a very nice way to do this due to the extra nested trial directory.

Here’s a workaround that you could use for now:

import os

from ray.tune import ResultGrid
from ray.tune.analysis import ExperimentAnalysis

# Same values as in your RunConfig
storage_path = "s3://<bucket-name>/tr-plain/"
name = "test"

exp_dir = os.path.join(storage_path, name)
# Assuming 1 "trial" if you're using Ray Train
result = ResultGrid(ExperimentAnalysis(exp_dir))[0]

# Use result
result.checkpoint
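
Once you have the result, you can restore the model from its checkpoint. Here is a rough sketch that assumes the checkpoint file was saved as ckpt.pt inside the checkpoint directory, as in your training loop above:

import os

import torch
from transformers import RobertaConfig, RobertaForMaskedLM
from transformers.models.roberta.tokenization_roberta_fast import RobertaTokenizerFast

# Download (or locally access) the reported checkpoint and load the saved state.
with result.checkpoint.as_directory() as ckpt_dir:
    state = torch.load(os.path.join(ckpt_dir, "ckpt.pt"), map_location="cpu")

# prepare_model() wrapped the model in DDP during training, so the saved keys
# may carry a "module." prefix; strip it before loading into a bare model.
state_dict = {k.removeprefix("module."): v for k, v in state["model"].items()}

tokenizer = RobertaTokenizerFast.from_pretrained("FacebookAI/roberta-base")
model = RobertaForMaskedLM(RobertaConfig(vocab_size=tokenizer.vocab_size))
model.load_state_dict(state_dict)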