Using static variables to control a Trainable subclass in ray.tune

I’m using the Trainable API to implement a modified version of PBT in Python. I currently hard-code the number of epochs per call to step inside the Trainable source code (epochs_per_generation), but I’d like to control it from a separate run script, which I already use to define all of the other run parameters compactly. I thought of using the config that Trainable.setup takes as input, but that isn’t quite the right place because I use it for passing model-specific parameters.

Since Ray creates the class instances itself, I tried defining a default static variable in the class definition (Subclass.epochs_per_generation) and overwriting it in the run script before passing the class into tune.run. This works outside of Ray: in simple tests, updating the static variable on the class and then creating an instance from the updated class behaves as expected. However, when I run the updated Trainable in Ray, the pre-update value of the static variable is used. Is Ray doing anything that would prevent updating static variables of Trainable subclasses to influence instance behavior like this?
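For reference, the standalone check that behaved as expected was essentially this (a minimal sketch with placeholder names, no Ray involved):

class Subclass:
    epochs_per_generation = 50  # default set in the class definition

Subclass.epochs_per_generation = 25  # overwrite before creating any instance
print(Subclass().epochs_per_generation)  # prints 25, as expected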

Hey @arsedler9, can you provide an example of what you’re trying to do? I think the problem is that Ray serializes the class definition, and not necessarily the class-assigned variables.

Sure! Here’s a really basic example that captures the gist of it; let me know if you need more details. I thought it was odd because I traced a little way through the Ray source and found that when I serialized and then immediately deserialized here, I saw the correct (updated) epochs_per_generation.

# In the Trainable source code
from ray import tune

class TrainableModel(tune.Trainable):
    epochs_per_generation = 50  # default

    def step(self):
        # Run several training epochs per Tune "step" (i.e. per generation)
        for _ in range(self.epochs_per_generation):
            metrics = self.model.train_epoch()
        return metrics  # report the last epoch's metrics to Tune

# In the run script
TrainableModel.epochs_per_generation = 25  # overwrite the class default
tune.run(TrainableModel, ...)
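Incidentally, the serialize/deserialize check I mentioned was essentially a round trip like the sketch below (assuming Ray's vendored cloudpickle module), and it did show the updated value:

from ray import cloudpickle  # Ray's vendored cloudpickle

TrainableModel.epochs_per_generation = 25
restored = cloudpickle.loads(cloudpickle.dumps(TrainableModel))
print(restored.epochs_per_generation)  # showed the updated value, 25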

Got it; I would try something like:

def create_trainable(epochs_per_generation):
    class NewTrainableModel(TrainableModel):
        epochs_per_generation = epochs_per_generation
    return NewTrainableModel

tune.run(create_trainable(N), ...)

Thanks! I tried that and it seems to work, but I had to add a workaround due to scope issues with defining a class inside a function: a class body can’t read names from the enclosing function’s scope when the same name is assigned in the class body, so epochs_per_generation = epochs_per_generation raises a NameError. Not the prettiest, but I like it better than the hard-coding.

def create_trainable(epochs_per_generation):
    # Stash the value in a module-level global so the class body can see it,
    # since the class body skips the enclosing function's local scope.
    global global_epg
    global_epg = epochs_per_generation

    class NewTrainableModel(TrainableModel):
        epochs_per_generation = global_epg

    return NewTrainableModel
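A variant that avoids the global entirely (just a sketch, using the same names as above) is to assign the attribute after the class statement, since ordinary function-scope lookup applies there:

def create_trainable(epochs_per_generation):
    class NewTrainableModel(TrainableModel):
        pass

    # Assigning outside the class body uses normal function-scope lookup,
    # so no global is needed.
    NewTrainableModel.epochs_per_generation = epochs_per_generation
    return NewTrainableModel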