[Train] How to use TensorflowCheckpoint to save a model (.h5)

Hi @xwjiang2010, I realized I hijacked another conversation to ask how to use the newly released TensorflowCheckpoint support, so I'm asking here to get a better answer and avoid disrupting the other user's question.

I'm curious whether you can provide an example that calls TensorflowCheckpoint.from_h5("my_model.h5") within a callback using the session API. Basically, I would like to save a checkpoint of my model at the end of each epoch.

As I mentioned in the other conversation, I have something like this:

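from ray.air import session
from ray.train.tensorflow import TensorflowCheckpoint
from tensorflow.keras.callbacks import LambdaCallback

# Saved once, before fit() starts
multi_worker_model.save("my_model.h5")
model_checkpoint = TensorflowCheckpoint.from_h5("my_model.h5")

# Report the checkpoint at the end of every epoch
custom_cb_checkpoint_model = LambdaCallback(
    on_epoch_end=lambda epoch, logs: session.report({}, checkpoint=model_checkpoint)
)

history = multi_worker_model.fit(
    multi_worker_dataset,
    epochs=epochs,
    steps_per_epoch=steps_per_epoch,
    callbacks=[custom_cb_checkpoint_model],
    verbose=0,
)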

But that is not working.
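One likely reason the snippet above misbehaves: `multi_worker_model.save(...)` and `TensorflowCheckpoint.from_h5(...)` run only once, before `fit()`, so the lambda reports the same checkpoint (a copy of the untrained model) at every epoch. A minimal sketch of an alternative, assuming `TensorflowCheckpoint.from_h5` is available in the installed Ray version, saves and wraps the model inside the callback instead. The helper name `make_on_epoch_end`, the per-epoch temporary directory, and reporting Keras's `logs` as the metrics dict are illustrative assumptions, not part of Ray's or Keras's API. The model, `session.report`, and `TensorflowCheckpoint.from_h5` are passed in as arguments so the logic can also be exercised with stand-ins.

```python
import os
import tempfile
from typing import Any, Callable, Dict, Optional


def make_on_epoch_end(
    model: Any,
    report: Callable[..., None],
    checkpoint_from_h5: Callable[[str], Any],
    filename: str = "my_model.h5",
) -> Callable[[int, Optional[Dict[str, float]]], None]:
    """Hypothetical helper (not part of Ray or Keras).

    Builds an ``on_epoch_end(epoch, logs)`` hook for Keras's LambdaCallback.
    The model is saved *inside* the hook, so every epoch produces a fresh
    H5 file and a fresh checkpoint built from it.
    """

    def on_epoch_end(epoch: int, logs: Optional[Dict[str, float]] = None) -> None:
        # A new temporary directory per call avoids overwriting the previous
        # epoch's file and keeps concurrent workers from sharing one path.
        path = os.path.join(tempfile.mkdtemp(), filename)
        # With tf.keras 2.x, a ".h5" suffix selects the HDF5 save format.
        model.save(path)
        # Wrap the file that was just written, not one saved before training.
        checkpoint = checkpoint_from_h5(path)
        # Report this epoch's metrics together with the fresh checkpoint.
        report(dict(logs or {}), checkpoint=checkpoint)

    return on_epoch_end
```

To wire it into the original snippet, replace the lambda with `LambdaCallback(on_epoch_end=make_on_epoch_end(multi_worker_model, session.report, TensorflowCheckpoint.from_h5))` and drop the save/`from_h5` calls that ran before `fit()`. Passing `logs` instead of `{}` means the reported metrics include Keras's per-epoch values such as `loss`.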

I followed your example and also this example, but it's still not clear to me. Any help would be appreciated!

Thanks