How severely does this issue affect your experience of using Ray?
- Medium: It causes significant difficulty in completing my task, but I can work around it.
Hi there, I’m wondering what the intended way to design a Ray Tune experiment with TensorFlow looks like when using the Trainable Class API for the tuning procedure.
When using the Class API, one focuses on the `setup` and `step` methods. The `setup` method is called once during initialization of the actor, and the `step` method is called `training_iteration` times (assuming no stopping is used).
In TensorFlow, one generally calls `fit` on the model (created in `setup`) and specifies the number of training iterations through the `epochs` argument. There are two straightforward ways to integrate this with Ray Tune:
- Using the Class API, one could call `model.fit(epochs=100)` once in the `step` method and set `training_iteration` to 1. This way, we would run `num_samples` trials and, for each trial, train the model for 100 epochs. Note that we go through a single `training_iteration` in Ray Tune here.
- Alternatively, one could call `model.fit(epochs=1)` in the `step` method and set `training_iteration` to 100. This way, we would also run `num_samples` trials and train each model for 100 epochs. Note that the number of epochs here is determined by `training_iteration`.
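For concreteness, here is a minimal sketch of the two approaches. To keep it runnable without Ray or TensorFlow installed, the model and the driver loop are tiny stand-ins of my own (`FakeModel`, `run_trial`); in real code you would subclass `tune.Trainable` and call Keras's `model.fit`.

```python
# Minimal stand-ins so the sketch runs without Ray or TensorFlow.
# In real code: subclass ray.tune.Trainable and call keras_model.fit(...).

class FakeModel:
    """Stand-in for a Keras model; just counts epochs trained."""
    def __init__(self):
        self.epochs_trained = 0

    def fit(self, epochs):
        self.epochs_trained += epochs


class ApproachOne:
    """Approach #1: fit(epochs=100) once; stop after 1 training_iteration."""
    def setup(self):
        self.model = FakeModel()

    def step(self):
        self.model.fit(epochs=100)
        return {"epochs": self.model.epochs_trained}


class ApproachTwo:
    """Approach #2: fit(epochs=1) per step; stop after 100 iterations."""
    def setup(self):
        self.model = FakeModel()

    def step(self):
        self.model.fit(epochs=1)
        return {"epochs": self.model.epochs_trained}


def run_trial(trainable_cls, training_iterations):
    """Mimics Tune's driver loop: call step() until the stop criterion."""
    trainable = trainable_cls()
    trainable.setup()
    result = None
    for _ in range(training_iterations):
        result = trainable.step()
    return result


# Both trials end up with 100 epochs of training, but approach #2 gives
# Tune 100 intermediate results (decision points for schedulers), while
# approach #1 gives it only one.
print(run_trial(ApproachOne, training_iterations=1))    # {'epochs': 100}
print(run_trial(ApproachTwo, training_iterations=100))  # {'epochs': 100}
```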
Approach #1 works well in that we fully train the network only once; however, many Ray Tune features (e.g., schedulers) seem to rely on `training_iteration` as the time attribute, so this approach does not integrate easily with the features that depend on it.
Therefore, I leaned more towards approach #2; however, this makes it harder to use the `EarlyStopping` callback that ships with TensorFlow. The reason for preferring this callback over the `TrialPlateauStopper` is that it can restore the model with the best performance instead of keeping the model from the most recent epoch. Furthermore, this callback measures patience as improvement over previous epochs, whereas the `TrialPlateauStopper` uses the standard deviation of recent results.
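One workaround I considered (a sketch of my own, not an API from Tune or Keras): replicate the best-weights bookkeeping of `EarlyStopping(restore_best_weights=True)` inside `step`, so approach #2 keeps per-iteration reporting while still recovering the best model. The `BestWeightsTracker` name and the patience logic below are mine, mirroring Keras's improvement-based counter.

```python
class BestWeightsTracker:
    """Tracks the best validation loss across step() calls and keeps the
    weights from that epoch, mimicking what Keras's
    EarlyStopping(restore_best_weights=True) does inside model.fit."""

    def __init__(self, patience):
        self.patience = patience
        self.best_loss = float("inf")
        self.best_weights = None
        self.wait = 0  # epochs since the last improvement

    def update(self, val_loss, weights):
        """Record one epoch; return True if patience is exhausted."""
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.best_weights = weights
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


# Example: losses improve, then plateau; with patience=2 we stop after
# two non-improving epochs and keep the weights from the best epoch.
tracker = BestWeightsTracker(patience=2)
losses = [0.9, 0.7, 0.6, 0.65, 0.66, 0.64]
stopped_at = None
for epoch, loss in enumerate(losses):
    if tracker.update(loss, weights=f"weights@epoch{epoch}"):
        stopped_at = epoch
        break

print(tracker.best_loss)     # 0.6
print(tracker.best_weights)  # 'weights@epoch2'
print(stopped_at)            # 4
```

Inside an actual `step`, the idea would be to call `tracker.update(...)` with the epoch's validation loss and model weights, and include `"done": True` in the returned result dict once it fires, which ends the trial at iteration granularity.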
Naturally, it would be possible to write our own callback, but this whole question of how best to integrate the TensorFlow fitting procedure with Ray Tune raises some issues that more experienced users may be able to address. What is the cleanest way to integrate the two, especially when it comes to combining TensorFlow's `model.fit` with the `step` method in Ray Tune (when using the Class API)?