Hi! I'm using Ray Tune to search for hyperparameters for my model. I want every trial to start training from the same initialized parameters, and I want to reuse actors (`reuse_actors=True`) to speed up training. I also use wandb to visualize every trial. My code looks like this:
```python
from copy import deepcopy

import wandb
from ray.tune import Trainable
from ray.tune.integration.wandb import WandbTrainableMixin


class My_Train(WandbTrainableMixin, Trainable):
    def setup(self, config):
        # I want to restart training from the same initial model parameters.
        # Re-initializing the model is time-consuming, so I save the initial
        # state dict in self.model_state_dict.
        self.model = ...
        self.dataloader = ...
        self.model_state_dict = deepcopy(self.model.state_dict())
        self.hyper_params = config['hyper_params']
        self.epoch = 0

    def step(self):
        avg_loss = 0.0
        for batch_idx, batch in enumerate(self.dataloader):
            loss = self.model(batch)
            avg_loss = (avg_loss * batch_idx + loss) / (batch_idx + 1)
        wandb.log({'loss': avg_loss, 'epoch': self.epoch})
        return {'loss': avg_loss}

    def reset_config(self, new_config):
        # Restart from the same initialized model with a different config.
        self.hyper_params = new_config['hyper_params']
        self.model.load_state_dict(self.model_state_dict)
        self.epoch = 0
        return True  # tell Tune the in-place reset succeeded
```
But I found that the later trial's `{'loss': avg_loss, 'epoch': self.epoch}` logs overwrite the earlier trial's run instead of creating a new wandb run.
I know I can avoid this by not setting `reuse_actors=True`, but I'd really like to know whether there is a way to solve it while keeping actor reuse. Thanks a lot!
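One idea I'm considering (I'm not sure it's the intended API, so treat this as a sketch): since the mixin only initializes wandb once per actor in `setup()`, call `wandb.finish()` and `wandb.init()` again inside `reset_config()` so each reused actor starts a fresh run. The snippet below illustrates the run lifecycle with a toy `FakeWandb` stand-in (purely for illustration; in real code these would be the actual `wandb.finish()` / `wandb.init()` calls, and the trainable would subclass the mixin):

```python
# Toy stand-in for wandb, just to show the run lifecycle.
class FakeWandb:
    def __init__(self):
        self.runs = []     # each run is a list of logged metric dicts
        self.active = None

    def init(self, **kwargs):
        self.active = []
        self.runs.append(self.active)

    def log(self, metrics):
        self.active.append(metrics)

    def finish(self):
        self.active = None


wandb = FakeWandb()


class My_Train:
    def setup(self, config):
        wandb.init(project="demo")  # the mixin does this once per actor
        self.hyper_params = config["hyper_params"]
        self.epoch = 0

    def step(self):
        wandb.log({"loss": 0.1, "epoch": self.epoch})
        self.epoch += 1

    def reset_config(self, new_config):
        self.hyper_params = new_config["hyper_params"]
        self.epoch = 0
        # Without these two calls, the next trial keeps logging into the
        # previous trial's run -- the overwriting behavior described above.
        wandb.finish()
        wandb.init(project="demo")
        return True


t = My_Train()
t.setup({"hyper_params": 1})
t.step()
t.reset_config({"hyper_params": 2})
t.step()
print(len(wandb.runs))  # prints 2: each trial gets its own run
```

With the `finish()`/`init()` pair removed from `reset_config`, all logs land in a single run, which matches the behavior I'm seeing.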