Cleanest integration of TensorFlow `model.fit()` in the Ray Tune Class API `step` method

How severe does this issue affect your experience of using Ray?

  • Medium: It contributes to significant difficulty to complete my task, but I can work around it.

Hi there, I’m wondering what the intended way to design a Ray Tune experiment with TensorFlow looks like when using the Trainable Class API for the tuning procedure.

When using the Class API, one focuses on the setup and step methods. The setup method is called once during initialization of the actor, and the step method is called training_iteration times (assuming no stopping criterion is used).

In TensorFlow, one generally calls fit on the model (created in setup) and specifies the number of training iterations through the epochs argument. There are two straightforward ways to integrate this into Ray Tune.

  1. Using the Class API, it would be possible to call model.fit(epochs=100) once in the step method and set the training_iteration stop criterion to 1. This way, we would run num_samples trials and, for each trial, train the model for 100 epochs. Note that we go through a single training_iteration in Ray Tune here.
  2. Using the Class API, it would be possible to call model.fit(epochs=1) in the step method and set the training_iteration stop criterion to 100. This way, we would also run num_samples trials and, for each trial, train the model for 100 epochs. Note that the number of epochs here is determined by training_iteration.

Approach #1 works well as we fully train the network only once; however, many Ray Tune features (e.g., schedulers) seem to rely on training_iteration as the time attribute. Hence, this approach does not allow for easy integration with these additional features.

Therefore, I leaned more towards approach #2; however, this makes it more challenging to use the EarlyStopping callback that comes with TensorFlow. The reason for using this callback instead of the TrialPlateauStopper is that it returns the model with the best performance instead of the model from the most recent epoch. Furthermore, this callback specifies patience through improvement compared to previous epochs, whereas the TrialPlateauStopper uses the standard deviation of the metric.

Naturally, it would be possible to create our own callback, but this entire approach of how to best integrate the TensorFlow fitting procedure with Ray Tune raises some questions that I figured some more-experienced users may be able to address. What is the cleanest way to integrate these two, especially when it comes to fitting the model in TensorFlow and the step method in Ray Tune (when using the Class API)?

Hey @Cysto- this is a great question. Using TensorFlow/Keras model.fit() does not really work well with the Class API. It is much simpler with Tune’s function API and our Keras callback, as in this example: Using Keras & TensorFlow with Tune — Ray 2.0.0

Is there a reason you need to use the Class API?

Hi @amogkam,

Thank you for your fast response. There’s no necessity for me to use the Class API, it’s a personal preference (feels a bit cleaner/more organized to me personally).

If I understand correctly, you’d suggest that TensorFlow users use the function API? If so, that’s probably what I’ll do.

When using the function API, what would be the easiest way to add metrics to the Ray trial at different points in time? For instance, what would be the easiest way to add additional metrics after I have trained the model (i.e., called .fit()) with the Keras Callback?

I tried returning the additional metrics at the end of the function; however, that does not actually add the metrics to the trials (most likely because I’ve already used the Keras callback?). I’ve also tried calling session.report(additional_metrics) after computing the additional metrics, but I noticed that this overrides the metrics, and then I have to somehow figure out the initial metrics again.

Hence, my question is: Is there a straightforward way in which I can report metrics at different points in time for the same training iteration (and trial)? Preferably, once through the Keras Callback and once with a simple dictionary that contains my metrics computed after training?
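One workaround (a sketch, not an official Tune API) is to avoid reporting twice for the same iteration at all: hold on to the final Keras logs, merge in the metrics computed after fit(), and make a single report call with the combined dictionary. The merging itself is plain dictionary logic:

```python
def merge_metrics(keras_logs, extra_metrics):
    """Combine final Keras logs with metrics computed after fit().

    Later keys win on collision, so name the extras distinctly
    (e.g. prefix them with "post_").
    """
    merged = dict(keras_logs or {})
    merged.update(extra_metrics)
    return merged


final_logs = {"loss": 0.31, "val_accuracy": 0.90}  # e.g. history.history, last epoch
extras = {"post_test_f1": 0.87}  # computed after training
report = merge_metrics(final_logs, extras)
# then make a single call: session.report(report)
```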

Hey @Cysto, is the additional metrics only needed at the end of training? If so, I would just recommend implementing your own Keras callback rather than using our built in one.

It can be something like this:

from tensorflow.keras.callbacks import Callback
from ray import tune

NUM_EPOCHS = 100  # same value passed to model.fit(epochs=...)

class MyTuneCallback(Callback):
    def on_epoch_end(self, epoch, logs=None):
        # Don't report on the last epoch (epochs are 0-indexed);
        # it will get logged in on_train_end instead
        if epoch < NUM_EPOCHS - 1:
            tune.report(**(logs or {}))

    def on_train_end(self, logs=None):
        logs = process_logs(logs)  # calculate additional metrics here
        tune.report(**logs)