[Tune] Is there a need to add a re-setup method to the Trainable class?

Hi! I'm using Tune to find the hyperparameters for my model. I want every trial to train the model from the same initialized parameters, and I want to reuse actors to speed up training. I also use wandb to visualize every trial. I implemented my code as follows:

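from copy import deepcopy

import wandb
from ray.tune import Trainable
from ray.tune.integration.wandb import WandbTrainableMixin


class My_Train(WandbTrainableMixin, Trainable):
    def setup(self, config):
        # I want every trial to start from the same initial model parameters.
        # Re-initializing the model is time-consuming, so I keep a copy of
        # the initial state dict in self.model_state_dict.
        self.model = ...
        self.model_state_dict = deepcopy(self.model.state_dict())
        self.dataloader = ...  # the DataLoader used in step() (elided)

        self.hyper_params = config['hyper_params']
        self.epoch = 0

    def step(self):
        avg_loss = 0.0
        for batch_idx, batch in enumerate(self.dataloader):
            loss = self.model(batch)  # backward pass / optimizer step omitted
            # Running mean of the batch losses; .item() turns the scalar
            # loss tensor into a plain float.
            avg_loss = (avg_loss * batch_idx + loss.item()) / (batch_idx + 1)

        self.epoch += 1
        wandb.log({'loss': avg_loss, 'epoch': self.epoch})
        # Trainable.step() must return a dict of metrics for Tune.
        return {'loss': avg_loss, 'epoch': self.epoch}

    def reset_config(self, new_config):
        # Restart from the same initialized model with a different config.
        self.hyper_params = new_config['hyper_params']
        self.model.load_state_dict(self.model_state_dict)
        self.epoch = 0
        # Returning True tells Tune the in-place reset succeeded, which is
        # required for the actor to be reused (reuse_actors=True).
        return True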

But I found that when an actor is reused, the later trial's {'loss': avg_loss, 'epoch': self.epoch} is logged into the earlier trial's wandb run, overwriting it, instead of going to a newly initialized wandb run.

I know I can avoid this problem by not reusing actors (reuse_actors=False), but I'd really like to know whether there is a way to solve it while still reusing them. Thanks a lot!

Are you using wandb just to log the trial results? If so, you might be able to use the WandbLoggerCallback instead of the mixin.

See the "External library integrations (tune.integration)" page of the Ray v2.0.0.dev0 docs.
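Why the callback sidesteps the problem: as far as I understand the Ray 1.x implementation, WandbLoggerCallback runs on the driver and keeps one wandb run per trial, so it does not matter which (possibly reused) actor produced a result. With the callback, step() only returns its metrics dict and never calls wandb.log itself. The sketch below is a dependency-free toy model of that routing, not Ray's code: PerTrialLoggerCallback and its trial_id argument are hypothetical (Ray's real log_trial_result hook receives a Trial object), and no wandb calls are made.

```python
from collections import defaultdict


class PerTrialLoggerCallback:
    """Hypothetical toy stand-in for a driver-side logger callback."""

    def __init__(self):
        # One "run" (here just a list of logged results) per trial id.
        self.runs = defaultdict(list)

    def log_trial_result(self, iteration, trial_id, result):
        # Results are routed by trial, regardless of which actor produced them.
        self.runs[trial_id].append(result)


callback = PerTrialLoggerCallback()

# Simulate a single reused actor running trial "a" and then trial "b".
for trial_id in ("a", "b"):
    for epoch in range(2):
        metrics = {"loss": 1.0 / (epoch + 1), "epoch": epoch}  # what step() returns
        callback.log_trial_result(epoch, trial_id, metrics)

for trial_id, history in callback.runs.items():
    print(trial_id, history)
```

Each trial ends up with its own history even though both were produced by the same actor, which is the behavior the mixin loses when the actor is reused.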

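The mixin calls wandb.init only once, when the trainable is constructed, so a reused actor keeps logging to the same run. However, you should be able to re-initialize wandb yourself in your reset_config method, mirroring what the mixin does in its constructor: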

self.wandb = self._wandb.init(**wandb_init_kwargs)

See the source code here: ray.tune.integration.wandb (Ray v2.0.0.dev0).
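To make the lifecycle concrete, here is a dependency-free toy model; FakeWandb, FakeRun and ReusableTrainable are hypothetical stand-ins, not Ray's or wandb's real classes. The constructor plays the role of the mixin's one-time wandb.init. When Tune reuses the actor it calls reset_config instead of the constructor, so that is where the sketch finishes the old run and opens a new one. Finishing the old run first is an assumption worth checking against the wandb docs: as far as I know, a second wandb.init in the same process may otherwise hand back the still-active run unless reinit=True is passed. In real code you would also have to rebuild the init kwargs (project, name, group, ...) for the new trial, which this toy reduces to a single name.

```python
import itertools


class FakeRun:
    """Hypothetical stand-in for a wandb run: records what gets logged."""

    def __init__(self, run_id, name):
        self.run_id = run_id
        self.name = name
        self.history = []
        self.finished = False

    def log(self, data):
        self.history.append(data)

    def finish(self):
        self.finished = True


class FakeWandb:
    """Hypothetical stand-in for the wandb module: init() starts a new run."""

    _ids = itertools.count()

    def init(self, **kwargs):
        return FakeRun(run_id=next(self._ids), name=kwargs.get("name"))


class ReusableTrainable:
    """Toy model of a trainable whose actor may be reused across trials."""

    _wandb = FakeWandb()

    def __init__(self, config):
        # Like the mixin: a single init() when the actor is constructed.
        self.wandb = self._wandb.init(name=config["trial_name"])
        self.epoch = 0

    def step(self):
        self.epoch += 1
        self.wandb.log({"epoch": self.epoch})
        return {"epoch": self.epoch}

    def reset_config(self, new_config):
        # Actor reuse calls this instead of __init__, so start a fresh run here;
        # without these two lines the next trial would log into the old run.
        self.wandb.finish()
        self.wandb = self._wandb.init(name=new_config["trial_name"])
        self.epoch = 0
        return True  # signal that the in-place reset succeeded


actor = ReusableTrainable({"trial_name": "trial_0"})
first_run = actor.wandb
actor.step()
actor.step()
actor.reset_config({"trial_name": "trial_1"})  # same actor, next trial
actor.step()
print(first_run.history, actor.wandb.history)
```

The first trial's run keeps only its own two entries, and the second trial logs into a separate run.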