How to make accuracy and loss plots per epoch from a checkpointed Ray Tune model

How severely does this issue affect your experience of using Ray?

  • High: It blocks me from completing my task.

I ran Ray Tune to find the model with the optimal hyperparameters, and I have saved that model to a file.

I want to load that checkpointed model and plot the training and validation accuracy and loss over epochs for that best model.

I've worked out how to use Ray Tune for HPO, how to save the best model, and how to load it back in, but I'm stuck on the last part.

Is there a simple way (it's my first time using both Ray Tune and PyTorch) to add 'make accuracy and loss plots of training' for the checkpointed model at some point? In case it matters, I used PyTorch Lightning to train a graph-based model with PyTorch Geometric.

The code:

from ray import tune

# train_graph_classifier, GraphLevelGNN, scheduler_asha and reporter
# are defined elsewhere in the script (not shown here).

def train_fn(config):
    # Each trial trains one GraphConv model with the sampled hyperparameters
    train_graph_classifier(
        model_name="GraphConv",
        layer_name="GraphConv",
        **config)

analysis = tune.run(
    train_fn,
    config={
        "c_hidden": tune.choice([64, 128, 256]),
        "dp_rate_linear": tune.choice([0.1, 0.3, 0.5]),  # could switch to tune.quniform(lower, upper, q)
        "num_layers": tune.choice([3, 4, 5, 6]),
        "dp_rate": tune.choice([0.1, 0.3, 0.5, 0.7]),
    },
    local_dir='/home/test_predictor/ray_ckpt2',  # path for saving checkpoints
    metric="val_loss",
    mode="min",
    num_samples=16,
    scheduler=scheduler_asha,
    progress_reporter=reporter,
    name="test")


# Load the best model found by Tune
path = analysis.best_checkpoint + '/' + "ray_ckpt"
model = GraphLevelGNN.load_from_checkpoint(path)

# ...how do I plot training and validation accuracy and loss per epoch here?

Thanks.

Hey @SlowatKela1, this should be possible via the Ray Tune analysis API, which returns pandas DataFrames containing the metrics reported by the trials. The metrics don't come from the checkpointed model itself; Tune records every result a trial reports in that trial's log directory. You can then plot those DataFrames with any visualization tool you like (for example, matplotlib).
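A minimal sketch of the plotting step, assuming the best trial's results are already in a DataFrame with one row per reported epoch. The x-axis uses `training_iteration`, which Tune adds to every reported result; the metric names `train_loss`, `val_loss`, `train_acc` and `val_acc` are hypothetical placeholders for whatever keys your training code actually reports to Tune, and `plot_training_curves` is an illustrative helper, not part of Ray.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend, so this also works on a headless server
import matplotlib.pyplot as plt
import pandas as pd


def plot_training_curves(
    df: pd.DataFrame,
    x: str = "training_iteration",
    loss_cols=("train_loss", "val_loss"),  # hypothetical metric names
    acc_cols=("train_acc", "val_acc"),  # hypothetical metric names
    out_path=None,
):
    """Plot loss and accuracy per epoch for a single trial.

    `df` is expected to hold one row per result the trial reported to Tune.
    Metric columns that are missing from `df` are skipped.
    """
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 4))
    for ax, cols, title in ((ax_loss, loss_cols, "Loss"), (ax_acc, acc_cols, "Accuracy")):
        for col in cols:
            if col in df.columns:
                ax.plot(df[x], df[col], marker="o", label=col)
        ax.set_xlabel(f"epoch ({x})")
        ax.set_ylabel(title.lower())
        ax.set_title(title)
        if ax.get_lines():  # only draw a legend if something was plotted
            ax.legend()
    fig.tight_layout()
    if out_path is not None:
        fig.savefig(out_path)
    return fig
```

With the `tune.run` call from the question (which sets `metric="val_loss"` and `mode="min"`), something like `df = analysis.trial_dataframes[analysis.best_logdir]` should give the best trial's history, followed by `plot_training_curves(df, out_path="best_trial_curves.png")`. This only yields per-epoch rows if `train_graph_classifier` reports metrics to Tune every epoch (for example via `TuneReportCallback` from `ray.tune.integration.pytorch_lightning`), and the training loss and accuracy only appear if they are among the reported metrics.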

If you are using the new Ray 2.0 Tuner API, you would do this through the ResultGrid.get_dataframe() API (ResultGrid (tune.ResultGrid) — Ray 2.0.0).

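If you are using the tune.run API from Ray 1.13 or earlier, you can do this via the ExperimentAnalysis API (Analysis (tune.analysis) — Ray 1.13.0). Note that analysis.dataframe() returns one row per trial, while analysis.trial_dataframes maps each trial's log directory to a DataFrame with every result that trial reported, which is what you need for per-epoch curves.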
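Because ASHA stops weak trials early, it can also help to see the best trial next to the others. The sketch below overlays one metric for every trial from a dict shaped like `ExperimentAnalysis.trial_dataframes` (trial log directory mapped to a DataFrame of that trial's results) and highlights the best trial. The function name, the default `val_loss` metric and the styling are illustrative choices, not part of Ray.

```python
from typing import Mapping, Optional

import matplotlib

matplotlib.use("Agg")  # non-interactive backend, so this also works on a headless server
import matplotlib.pyplot as plt
import pandas as pd


def plot_metric_across_trials(
    trial_dfs: Mapping[str, pd.DataFrame],
    metric: str = "val_loss",
    x: str = "training_iteration",
    best_logdir: Optional[str] = None,
    out_path: Optional[str] = None,
):
    """Overlay one metric for every trial and highlight the best one.

    `trial_dfs` maps a trial's log directory to a DataFrame of the results
    that trial reported (the shape of ExperimentAnalysis.trial_dataframes).
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    for logdir, df in trial_dfs.items():
        if df.empty or metric not in df.columns or x not in df.columns:
            continue  # this trial never reported the metric
        if logdir == best_logdir:
            style = {"color": "tab:red", "linewidth": 2.5, "zorder": 3, "label": "best trial"}
        else:
            style = {"color": "lightgray", "linewidth": 1.0, "zorder": 1}
        ax.plot(df[x], df[metric], **style)
    ax.set_xlabel(f"epoch ({x})")
    ax.set_ylabel(metric)
    handles, _ = ax.get_legend_handles_labels()
    if handles:  # legend only when the best trial was actually drawn
        ax.legend()
    fig.tight_layout()
    if out_path is not None:
        fig.savefig(out_path)
    return fig
```

For the run in the question this would be called roughly as `plot_metric_across_trials(analysis.trial_dataframes, best_logdir=analysis.best_logdir, out_path="val_loss_all_trials.png")`. Curves for trials that ASHA stopped early are simply shorter, so no padding or alignment is needed.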
