Save and reuse Checkpoints in Ray 2.0 version

High: It blocks me from completing my task.

I want to achieve the following two steps while training a TensorFlow model using Ray Train:

  1. Save the checkpoint with model weights in a pickle file
  2. Use this saved checkpoint pickle file later to load the model with its weights
    (for prediction or for retraining the model).

In earlier versions of Ray, the first step could be done with the code below:

train.save_checkpoint(epoch=epoch, accuracy=history.history['acc'][0],

This would save the checkpoints for each epoch in a folder named "run_001/checkpoints/".

The second step could be done with this code:

import pickle
checkpoint = pickle.load(open("run_001/checkpoints/checkpoint_000002/dict_checkpoint.pkl", "rb"))

This would load the checkpoint dictionary, including the model weights, saved at the 2nd epoch.
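For reference, the legacy layout described above can be reproduced with only the standard library. This is a minimal sketch, not Ray code: the helper names `save_dict_checkpoint` and `load_dict_checkpoint` are hypothetical, and a small list stands in for `model.get_weights()`.

```python
import os
import pickle
import tempfile


def save_dict_checkpoint(run_dir, checkpoint_id, data):
    """Pickle `data` to <run_dir>/checkpoints/checkpoint_<id>/dict_checkpoint.pkl."""
    ckpt_dir = os.path.join(run_dir, "checkpoints", f"checkpoint_{checkpoint_id:06d}")
    os.makedirs(ckpt_dir, exist_ok=True)
    path = os.path.join(ckpt_dir, "dict_checkpoint.pkl")
    with open(path, "wb") as f:
        pickle.dump(data, f)
    return path


def load_dict_checkpoint(path):
    """Load a checkpoint dict previously written by save_dict_checkpoint."""
    with open(path, "rb") as f:
        return pickle.load(f)


run_dir = os.path.join(tempfile.mkdtemp(), "run_001")
for epoch in range(1, 4):
    # Placeholder "weights"; in real code this would be model.get_weights().
    fake_weights = [[epoch, epoch * 10]]
    save_dict_checkpoint(run_dir, epoch, {"epoch": epoch, "model_weights": fake_weights})

checkpoint = load_dict_checkpoint(
    os.path.join(run_dir, "checkpoints", "checkpoint_000002", "dict_checkpoint.pkl")
)
print(checkpoint["epoch"])  # 2
```

To actually restore the model you would then call something like `model.set_weights(checkpoint["model_weights"])` on a freshly built model.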

In the latest version of Ray (>= 2.0), how can I implement these steps?

As I understand it, the checkpoints are now created using{},checkpoint=Checkpoint.from_dict(dict(epoch=epoch,

But this doesn’t save any pickle file like the one above.
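This behaviour is probably expected: as far as I understand the Ray 2.x AIR `Checkpoint`, `Checkpoint.from_dict(...)` only wraps the dict in memory, and files appear on disk when Ray persists a reported checkpoint or when you call `to_directory(path)` yourself. The toy class below is a hypothetical stand-in (not Ray's implementation; the file name it writes is an assumption that mirrors the legacy layout) that imitates the `from_dict` / `to_directory` / `from_directory` / `to_dict` round trip so the difference is easy to see.

```python
import os
import pickle
import tempfile


class MiniCheckpoint:
    """Toy stand-in for a dict-based checkpoint (NOT ray.air.Checkpoint)."""

    FILE_NAME = "dict_checkpoint.pkl"  # assumed file name, mirrors the legacy layout

    def __init__(self, data):
        self._data = dict(data)

    @classmethod
    def from_dict(cls, data):
        # Only keeps the data in memory; nothing is written to disk here.
        return cls(data)

    def to_dict(self):
        return dict(self._data)

    def to_directory(self, path):
        # This is the step that actually creates a file.
        os.makedirs(path, exist_ok=True)
        with open(os.path.join(path, self.FILE_NAME), "wb") as f:
            pickle.dump(self._data, f)
        return path

    @classmethod
    def from_directory(cls, path):
        with open(os.path.join(path, cls.FILE_NAME), "rb") as f:
            return cls(pickle.load(f))


target = os.path.join(tempfile.mkdtemp(), "my_checkpoint")
ckpt = MiniCheckpoint.from_dict({"epoch": 2, "model_weights": [[1, 2], [3]]})
print(os.path.exists(target))  # False: from_dict wrote nothing
ckpt.to_directory(target)
restored = MiniCheckpoint.from_directory(target).to_dict()
print(restored["epoch"])  # 2
```

With the real class, the equivalent calls would be `Checkpoint.from_dict(d).to_directory(path)` and later `Checkpoint.from_directory(path).to_dict()`.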

Please guide me through this and help me find a solution.

This is now provided by the TensorflowCheckpoint class: Ray AIR API — Ray 2.0.1

Besides weights, you can also save and load models in the saved_model and h5 formats.

Thanks for the response @xwjiang2010.
But for my use case, I need to save the checkpoints (not the model) as files in my local storage, and use these saved files later to load the model.
I didn’t find any function for that in the documentation.
Is there any way to do that?

Do you mean you want to load the model inside Ray AIR, or in your own function?

To save and load a model in a Ray Trainer, one can use, checkpoint=ckpt) and session.get_checkpoint().
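Inside the training function, the usual pattern is to report a checkpoint at the end of each epoch and, on startup, resume from `session.get_checkpoint()` if it returns one. The sketch below models only that control flow in plain Python: `FakeSession` is a hypothetical stand-in for `ray.air.session` (its `report`/`get_checkpoint` methods just mimic the names used in this thread), and the checkpoint is a plain dict rather than a Ray `Checkpoint`.

```python
class FakeSession:
    """Hypothetical stand-in for ray.air.session, for illustration only."""

    def __init__(self, checkpoint=None):
        self._checkpoint = checkpoint
        self.reported = []  # (metrics, checkpoint) pairs

    def get_checkpoint(self):
        return self._checkpoint

    def report(self, metrics, checkpoint=None):
        self.reported.append((metrics, checkpoint))
        if checkpoint is not None:
            self._checkpoint = checkpoint


def train_loop(session, num_epochs):
    # Resume from the last checkpoint if the session provides one.
    ckpt = session.get_checkpoint()
    start_epoch = ckpt["epoch"] + 1 if ckpt else 0
    weights = ckpt["model_weights"] if ckpt else 0.0
    for epoch in range(start_epoch, num_epochs):
        weights += 1.0  # stand-in for one epoch of training
        session.report(
            {"epoch": epoch},
            checkpoint={"epoch": epoch, "model_weights": weights},
        )


first = FakeSession()
train_loop(first, num_epochs=2)  # runs epochs 0 and 1
resumed = FakeSession(first.get_checkpoint())
train_loop(resumed, num_epochs=4)  # continues with epochs 2 and 3
print([m["epoch"] for m, _ in resumed.reported])  # [2, 3]
```

With real Ray objects, `session.get_checkpoint()` returns a `Checkpoint` (or `None`), so you would call `.to_dict()` on it before reading `"epoch"` or `"model_weights"`.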

To load it outside of Ray AIR, one could do something like this:
ckpt = Checkpoint.from_directory(directory_path)
my_dict = ckpt.to_dict()
model_weights = my_dict["model_weights"]
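If `directory_path` is not known up front, one option is to pick the newest `checkpoint_<N>` folder under the trial directory. The helper below (`latest_checkpoint_dir`) is hypothetical and assumes the zero-padded `checkpoint_000002`-style naming mentioned earlier in this thread; the path it returns is what you would then pass to `Checkpoint.from_directory`.

```python
import os
import re
import tempfile

_CKPT_RE = re.compile(r"^checkpoint_(\d+)$")


def latest_checkpoint_dir(parent):
    """Return the checkpoint_<N> subdirectory of `parent` with the largest N, or None."""
    best = None
    for name in os.listdir(parent):
        match = _CKPT_RE.match(name)
        if match and os.path.isdir(os.path.join(parent, name)):
            idx = int(match.group(1))  # compare numerically, not as strings
            if best is None or idx > best[0]:
                best = (idx, os.path.join(parent, name))
    return best[1] if best else None


trial_dir = tempfile.mkdtemp()
for i in (1, 2, 10):
    os.makedirs(os.path.join(trial_dir, f"checkpoint_{i:06d}"))
print(os.path.basename(latest_checkpoint_dir(trial_dir)))  # checkpoint_000010
```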

Hi @xwjiang2010, I saw this was just released. I am trying to get it to checkpoint on every epoch. I’ve been calling TensorflowCheckpoint.from_model(model) inside a LambdaCallback and passing that callback to my .fit() call, but I have not been able to get it to work.

What I’d like to see is an example that (1) checkpoints the model on every epoch, and (2) checkpoints the model at a given frequency. Do you have examples like that?
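Both cases come down to the same rule: report a checkpoint from `on_epoch_end` whenever `(epoch + 1) % frequency == 0`, so `frequency=1` means every epoch. The sketch below shows only that rule in plain Python. `EpochCheckpointer` and its `report_fn` argument are hypothetical; in real code the class would subclass `tf.keras.callbacks.Callback`, and `report_fn` might wrap something like `session.report(logs, checkpoint=TensorflowCheckpoint.from_model(self.model))`.

```python
class EpochCheckpointer:
    """Keras-style callback sketch: calls report_fn every `frequency` epochs.

    Real code would subclass tf.keras.callbacks.Callback; this version only
    mimics the on_epoch_end(epoch, logs) hook so it runs without TensorFlow.
    """

    def __init__(self, report_fn, frequency=1):
        if frequency < 1:
            raise ValueError("frequency must be >= 1")
        self.report_fn = report_fn
        self.frequency = frequency

    def on_epoch_end(self, epoch, logs=None):
        # Keras passes a 0-based epoch index, so epoch + 1 epochs are done.
        if (epoch + 1) % self.frequency == 0:
            self.report_fn(epoch, dict(logs or {}))


reported = []
every_epoch = EpochCheckpointer(lambda e, logs: reported.append(("every", e)))
every_2nd = EpochCheckpointer(lambda e, logs: reported.append(("every_2nd", e)), frequency=2)
for epoch in range(5):  # simulate model.fit(..., epochs=5)
    logs = {"loss": 1.0 / (epoch + 1)}
    every_epoch.on_epoch_end(epoch, logs)
    every_2nd.on_epoch_end(epoch, logs)
print([e for tag, e in reported if tag == "every_2nd"])  # [1, 3]
```

The AIR Keras Callback recommended later in this thread is meant to cover this; the sketch just makes the scheduling logic explicit.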

This is the train function I am testing, based on the examples you provide:

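import json
import os

import tensorflow as tf
from tensorflow.keras.callbacks import LambdaCallback
from ray.train.tensorflow import TensorflowCheckpoint
# mnist_dataset() and build_cnn_model() are helpers from the Ray example, defined elsewhere.

def train_func(config):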
    per_worker_batch_size = config.get("batch_size", 64)
    epochs = config.get("epochs", 3)
    steps_per_epoch = config.get("steps_per_epoch", 70)
    tf_config = json.loads(os.environ["TF_CONFIG"])
    num_workers = len(tf_config["cluster"]["worker"])
    strategy = tf.distribute.MultiWorkerMirroredStrategy()

    global_batch_size = per_worker_batch_size * num_workers
    multi_worker_dataset = mnist_dataset(global_batch_size)

    with strategy.scope():
        # Model building/compiling need to be within `strategy.scope()`.
        multi_worker_model = build_cnn_model()

        learning_rate = config.get("lr", 0.001)

    custom_cb_checkpoint_model = LambdaCallback(
        on_epoch_end=lambda epoch, logs:{},checkpoint=TensorflowCheckpoint.from_model(multi_worker_model))
    history =
    results = history.history
    return results


I see. The documentation update is not included in the 2.1 release, but you can still find it here: Convert existing Tensorflow/Keras code to Ray AIR — Ray 3.0.0.dev0

Please use this provided Callback (from ray.air.integrations.keras import Callback) together with, as shown in this example.

Hi @xwjiang2010. Apologies, but I don’t see how to create a TensorflowCheckpoint of my model in the example you posted.

My goal is to do what you’ve done here:

(from example)

    for epoch in range(EPOCH):
        # This will make sure that the training workers will get their own
        # share of batch to work on.
        # See `ray.train.tensorflow.prepare_dataset_shard` for more information.
        tf_dataset = to_tf_dataset(dataset=dataset_shard, batch_size=BATCH_SIZE), verbose=0)
        # This saves checkpoint in a way that can be used by Ray Serve coherently.

But I’d like to pass a callback instead of iterating over the epochs myself, something like this:

model_checkpoint = TensorflowCheckpoint.from_model(multi_worker_model)

custom_cb_checkpoint_model = LambdaCallback(
    on_epoch_end=lambda epoch, logs:{"my_metric": 1},checkpoint=model_checkpoint)

history =

However, that is not working.

How do I wrap my to correctly call that checkpoint?
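One likely reason the `LambdaCallback` above does not behave as intended: `model_checkpoint` is created once, before `fit()` runs, so (assuming `TensorflowCheckpoint.from_model` copies the weights at the moment it is called) every epoch would report the same untrained weights. The plain-Python sketch below illustrates the difference with a hypothetical `ToyModel`; `snapshot()` plays the role of `from_model(...)`.

```python
import copy


class ToyModel:
    """Hypothetical stand-in for a Keras model with mutable weights."""

    def __init__(self):
        self.weights = [0.0]

    def get_weights(self):
        return copy.deepcopy(self.weights)

    def train_one_epoch(self):
        self.weights[0] += 1.0


def snapshot(model):
    # Plays the role of TensorflowCheckpoint.from_model(model): copies weights now.
    return {"model_weights": model.get_weights()}


model = ToyModel()
created_once = snapshot(model)  # like model_checkpoint = ...from_model(...) before fit()
stale, fresh = [], []
for epoch in range(3):
    model.train_one_epoch()
    stale.append(created_once["model_weights"][0])  # same snapshot every epoch
    fresh.append(snapshot(model)["model_weights"][0])  # built inside on_epoch_end
print(stale, fresh)  # [0.0, 0.0, 0.0] [1.0, 2.0, 3.0]
```

In other words, build the checkpoint inside the callback (as `train_func` above already does in its `LambdaCallback`) rather than once outside it.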

I am also interested in using TensorflowCheckpoint.from_h5("my_model.h5") to save the entire model, if you have an example.


Hi @xwjiang2010 ,

Yes, I want to load the model in other custom function.

I understand that we can load a checkpoint outside of Ray using:

ckpt = Checkpoint.from_directory(directory_path)
my_dict = ckpt.to_dict()
model_weights = my_dict["model_weights"]

My issue is with saving the checkpoint to my local folder on each epoch.
To save it, I used checkpoint.to_directory("folder_path"), but it doesn’t seem to work: nothing shows up in the “folder_path” location.

Please let me know if I am missing anything here.

Hi @suraj-gade,

One thing to check for checkpoint.to_directory("folder_path"): is the folder path specified as an absolute path? By default, Tune changes the working directory to the trial directory (or to the worker directory under the trial directory if you are using Train), so you’ll need to look for it under the experiment log directory (e.g., ~/ray_results/experiment-name/trial-dir/).
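To see why a relative path can land somewhere unexpected: relative paths are resolved against the current working directory at the time of the call. The sketch below simulates the working-directory change with `os.chdir` into a temporary "trial" directory (the directory names are made up) and shows that resolving the path with `os.path.abspath` before the change keeps it pointing where you intended.

```python
import os
import tempfile

original_cwd = os.getcwd()
launch_dir = os.path.realpath(tempfile.mkdtemp())  # where the script was started
trial_dir = os.path.realpath(tempfile.mkdtemp())  # simulated Tune trial directory

os.chdir(launch_dir)
anchored = os.path.abspath("folder_path")  # resolved before the directory change

os.chdir(trial_dir)  # roughly what Tune does for each trial/worker by default
relative = os.path.abspath("folder_path")  # now resolves under the trial dir

os.chdir(original_cwd)  # restore for the rest of the script
print(relative.startswith(trial_dir), anchored.startswith(launch_dir))  # True True
```

So computing the absolute path up front (for example on the driver, then passing it in through `config`) avoids writing into the trial directory by accident.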

Another way to access the checkpoints from another script is to get them from the experiment results:

from ray.tune import Tuner
results = Tuner.restore(path).get_results()
results[0].best_checkpoints  # [(checkpoint, metrics), ...]

See Analyzing Tune Experiment Results — Ray 3.0.0.dev0 for more info.
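Since `best_checkpoints` is a list of `(checkpoint, metrics)` pairs, choosing one by a metric is ordinary Python. In this sketch, strings stand in for the `Checkpoint` objects, `pick_best` is a hypothetical helper, and the `accuracy` metric name is an assumption; with real results you would then call `.to_dict()` (or `.to_directory(...)`) on the chosen checkpoint.

```python
def pick_best(best_checkpoints, metric, mode="max"):
    """Return the (checkpoint, metrics) pair with the best value of `metric`."""
    if mode not in ("max", "min"):
        raise ValueError("mode must be 'max' or 'min'")
    candidates = [pair for pair in best_checkpoints if metric in pair[1]]
    if not candidates:
        raise ValueError(f"no checkpoint reports {metric!r}")
    chooser = max if mode == "max" else min
    return chooser(candidates, key=lambda pair: pair[1][metric])


# Strings stand in for Checkpoint objects in this illustration.
best_checkpoints = [
    ("checkpoint_000001", {"epoch": 0, "accuracy": 0.91}),
    ("checkpoint_000002", {"epoch": 1, "accuracy": 0.95}),
    ("checkpoint_000003", {"epoch": 2, "accuracy": 0.93}),
]
checkpoint, metrics = pick_best(best_checkpoints, "accuracy")
print(checkpoint, metrics["epoch"])  # checkpoint_000002 1
```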

Hi @max_ronda,

This definitely needs to be surfaced better in the docs, but the AIR Keras Callback creates a TensorflowCheckpoint and uses to report the checkpoints at some configurable frequency. I think this is the functionality you are looking for?

When you try to define the LambdaCallback, what is the error that you are seeing? The Keras Callback above is doing something similar (wrapping in the callback handlers).