I’m using the `Trainable` API to implement a modified version of PBT in Python. I currently hard-code the number of epochs per call to `step` within the `Trainable` source code (`epochs_per_generation`). However, I’d like to control this from a separate run script, which I already use to define all other run parameters compactly. I thought of using the `config` that `Trainable.setup` takes as input, but that isn’t quite the right place because I use it for passing model-specific parameters. Since Ray creates the class instances itself, I tried defining a default static variable in the class definition (`Subclass.epochs_per_generation`) and overwriting it in the run script before passing the class to `tune.run`. This works outside of Ray: in simple tests, updating the static variable on the class and then creating an instance picks up the new value. However, when I run the updated `Trainable` in Ray, the pre-update value of the static variable is used. Is there anything Ray does that would prevent updating static variables of `Trainable` subclasses to influence instance behavior like this?
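For reference, the outside-of-Ray check I ran was essentially this (class name is illustrative):

```python
class Base:
    epochs_per_generation = 50  # class-level default

# Overwrite the static (class-level) variable before instantiating.
Base.epochs_per_generation = 25

# A new instance resolves the attribute on the class and sees the update.
instance = Base()
print(instance.epochs_per_generation)  # 25
```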
Hey @arsedler9, can you provide an example of what you’re trying to do? I think the problem is that Ray serializes the class definition, not necessarily the class-level variable assignments.
Sure! Here’s a really basic example that captures the gist of it; let me know if you need any more details. I thought it was odd because I traced through the Ray source a little way and found that when I serialized and then immediately deserialized the class here, I saw the correct (updated) `epochs_per_generation`.
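Here is a simplified stand-in for that round-trip check, using plain `pickle` (not exactly what Ray does — Ray uses cloudpickle, which can serialize class definitions by value rather than by reference):

```python
import pickle

class Demo:
    epochs_per_generation = 50  # default

Demo.epochs_per_generation = 25  # overwrite, as in the run script

# Plain pickle serializes a module-level class by reference, so the
# round trip hands back the very same class object, update included.
restored = pickle.loads(pickle.dumps(Demo))
print(restored is Demo)                # True
print(restored.epochs_per_generation)  # 25
```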
```python
# In trainable source code
class TrainableModel(tune.Trainable):
    epochs_per_generation = 50  # default

    def step(self):
        for i in range(self.epochs_per_generation):
            metrics = self.model.train_epoch()

# In run script
TrainableModel.epochs_per_generation = 25  # overwrite
tune.run(TrainableModel, ...)
```
Got it; I would try something like:
```python
def create_trainable(epochs_per_generation):
    class NewTrainableModel(TrainableModel):
        epochs_per_generation = epochs_per_generation
    return NewTrainableModel

tune.run(create_trainable(N), ...)
```
Thanks! I tried that and it seems to work, but I had to add a workaround due to scope issues with defining a class inside a function: a name that is also assigned in a class body is looked up in the global scope rather than the enclosing function’s, so the `epochs_per_generation = epochs_per_generation` line raises a `NameError`. Not the prettiest, but I do like it better than the hard-coding.
```python
def create_trainable(epochs_per_generation):
    global global_epg
    global_epg = epochs_per_generation

    class NewTrainableModel(TrainableModel):
        epochs_per_generation = global_epg
    return NewTrainableModel
```
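A variant that avoids the global, sketched here with a stand-in base class: only names that are *also assigned* in the class body skip the enclosing function scope, so a factory parameter with a different name from the class attribute is visible inside the class body.

```python
class TrainableModel:  # stand-in for the real tune.Trainable subclass
    epochs_per_generation = 50

def create_trainable(epg):
    # `epg` is not assigned inside the class body, so the body can read
    # it from the enclosing function scope -- no global needed.
    class NewTrainableModel(TrainableModel):
        epochs_per_generation = epg
    return NewTrainableModel

print(create_trainable(25).epochs_per_generation)  # 25
```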