Changing the config in tune.scheduler calls the setup function of the Trainable class

How severely does this issue affect your experience of using Ray?

  • High: It blocks me from completing my task.

Here is a function to update the trial’s config

    # Re-sends a trial's current config to its running trainable ("reset").
    # NOTE(review): this is an excerpt as posted — the body of the
    # `if trainable is None` check, the `try:` that pairs with the `except`,
    # and the arguments of `ray.get(` are not shown, so it does not parse as-is.
    def update_trial_config(self, trial, trial_runner): 
        trainable = trial.runner
        # NOTE(review): guard body is missing from the excerpt; presumably it
        # returns early when the trial has no live runner — confirm.
        if trainable is None: 
        DEFAULT_GET_TIMEOUT = 60.0  # seconds
        # The config pushed to the trainable is whatever is currently stored on the trial.
        new_config = trial.config 
        with trial_runner.trial_executor._change_working_directory(trial):
            with warn_if_slow("reset"):
                    # NOTE(review): call is cut off here; presumably a remote
                    # reset with `new_config`, bounded by DEFAULT_GET_TIMEOUT — confirm.
                    reset_val = ray.get(
                except GetTimeoutError:
                    # A timed-out reset is logged with its traceback.
                    logger.exception("Trial %s: reset timed out.", trial)

When I call a function to change the config of a trial, I find that it calls the setup function (ray/ at master · ray-project/ray · GitHub). This reinitializes the model, the optimizer, etc. How can I change the config without triggering the setup function?

What is it that you are trying to achieve?
What kind of trial_config do you want to update?
Can you share your trainable/training function code?

Thanks for your reply.
Here is my trainable class

# Ray Tune class-API trainable that trains a ResNet-18 on a subset of a dataset.
# NOTE(review): the code is an excerpt as posted — several statements are
# visibly cut off (marked below), so it does not parse as-is.
class PytorchTrainble(tune.Trainable):
    # Builds the model, optimizer, LR scheduler and datasets from `config`.
    def setup(self, config):
        # Unwrap the hyperparameters when they arrive nested under a 'config' key.
        if 'config' in config: 
            config = config['config']
        # NOTE(review): the body of this "already initialized" guard is missing
        # from the excerpt; presumably it returns early — confirm.
        if hasattr(self, 'initialized') and self.initialized: 
        self.device = device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if config['dataset'] == 'ImageNet': 
            model = torchvision.models.resnet18()
            # Replace the classification head so its width follows config['num_classes'].
            del model.fc 
            model.fc = torch.nn.Linear(512, config['num_classes'])
        # NOTE(review): num_classes is hard-coded to 100 for both CIFAR10 and
        # CIFAR100 — confirm this is intended for CIFAR10.
        elif config['dataset'] in ['CIFAR10', 'CIFAR100']: 
            model = resnet18(num_classes=100, ratio=config["width_ratio"]) 
        # NOTE(review): for any other dataset value `model` is never bound and
        # the optimizer construction below fails.
        # model = vgg16_bn()
        optimizer = optim.SGD(
            model.parameters(), lr=config["lr"], momentum=config["momentum"], weight_decay = config["wd"])
        # NOTE(review): no scheduler.step() call is visible in step() below —
        # confirm the learning-rate schedule is actually advanced.
        self.scheduler = StepLR(optimizer, step_size=10, gamma=0.5)
        self.model = model 
        self.optimizer = optimizer 
        self.epoches = 50
        self.result_record = Trajector("ResNet18", config["dataset"], str(config))
        self.data_ratio_init = config['data_ratio']
        self.cur_epoch = 0
        self.config = config 
        # Best test accuracy seen so far; this is what step() reports.
        self.max_acc = 0 
        self.sub_trainset, self.rest_trainset, self.testset = build_sub_rest_dataset(self.config)
        self.proportion_list = get_expand_landmarks(self.sub_trainset, self.rest_trainset)
        self.proportion = self.proportion_list[0]
        self.abs_loss_dir = self.config['abs_loss_dir']
        # Flag read by the guard at the top of setup().
        self.initialized = True 

    # NOTE(review): the bodies of call_twice and call_once are omitted in the excerpt.
    def call_twice(self, ): 
    def call_once(self, ): 
    # One training iteration: dispatches on the 'call_once' / 'call_twice'
    # config flags, otherwise trains, evaluates and reports the best accuracy so far.
    def step(self):
        # import pdb; pdb.set_trace() 
        # Keep the raw config only so the error message below can show both forms.
        ori_config = self.config
        # Unwrap a nested config (it may be replaced via reset_config with a nested dict).
        if 'config' in self.config: 
            self.config = self.config['config']
        # data_ratio = self.cur_epoch / self.epoches * (1.0 - self.data_ratio_init) + self.data_ratio_init 
        # self.config['data_ratio'] = data_ratio
        if 'data_ratio' not in self.config: 
            raise ValueError("data_ratio is to {} // {}".format(self.config, ori_config))
        # NOTE(review): data_ratio is read but not used in the visible lines.
        data_ratio = self.config['data_ratio']
        # if data_ratio > 0.1: 
        #     raise ValueError(f"this is not correct. {data_ratio}")
        # train_loader, test_loader = build_dynamic_sub_dataloader(self.config)
        # Config flags set by the scheduler select an alternative action for this step.
        if 'call_once' in self.config and self.config['call_once']: 
            return self.call_once() 
        if 'call_twice' in self.config  and self.config['call_twice']: 
            # No training here: report the current best accuracy with a 'not_increase' marker.
            return {"mean_accuracy": self.max_acc, 'not_increase': True}
        if True: 
            # train_loader, test_loader = build_dataloader(self.config, self.sub_trainset, self.testset)
            # Batch size is capped at the subset size.
            # NOTE(review): both DataLoader(...) calls are cut off (no closing
            # parenthesis) in the excerpt.
            train_loader = DataLoader(self.sub_trainset, batch_size=min((int)(self.config["batch_size"]), len(self.sub_trainset)),
                                            shuffle=True, drop_last=False, num_workers=4,
            test_loader = DataLoader(self.testset, batch_size=256,
                                            shuffle=True, drop_last=False, num_workers=4,
            train(self.model, self.optimizer, train_loader)
            acc_train, loss_train = test(self.model, train_loader)
            acc, loss = test(self.model, test_loader)

        self.cur_epoch += 1
        # if self.cur_epoch == 2: 
        #     raise ValueError('this is not correct in trial {} // {}'.format(self.config, ori_config))
        # The reported metric is the running maximum of test accuracy, not the latest value.
        self.max_acc = max(acc, self.max_acc)
        return {"mean_accuracy": self.max_acc}

    # Writes the training state to "model.pth" inside checkpoint_dir and returns that path.
    def save_checkpoint(self, checkpoint_dir):
        checkpoint_path = os.path.join(checkpoint_dir, "model.pth")
        # NOTE(review): the scheduler and result_record objects themselves are
        # stored (not a state_dict), and self.max_acc is not included.
        key_info = {
            'model': self.model.state_dict(), 
            'optim': self.optimizer.state_dict(), 
            'cur_epoch': self.cur_epoch,
            'scheduler': self.scheduler,
            'result_record': self.result_record,
        # NOTE(review): statement is cut off — `}, checkpoint_path)` closes a
        # call (presumably torch.save) whose opening is not shown.
        }, checkpoint_path)
        return checkpoint_path

    # Restores training state from a file written by save_checkpoint.
    # NOTE(review): only scheduler, cur_epoch and result_record are restored in
    # the visible lines; loading of the 'model' and 'optim' entries is not shown.
    def load_checkpoint(self, checkpoint_path):
        key_info = torch.load(checkpoint_path)
        self.scheduler = key_info['scheduler']
        self.cur_epoch = key_info['cur_epoch']
        self.result_record = key_info['result_record']

    # Swaps in a new config on the existing instance without rebuilding the
    # model or optimizer. Per the Ray Tune Trainable docs, returning True
    # reports that the reset succeeded.
    def reset_config(self, new_config):
        # NOTE(review): looks like a debugging guard. It indexes 'data_ratio' at
        # the top level, while setup()/step() also accept a config nested under
        # 'config' — a nested new_config would raise KeyError here; confirm.
        if new_config['data_ratio'] > 0.1: 
            raise ValueError("this step is not correct in reset {}".format(new_config))
        self.config = new_config
        return True

I also modified the HyperBand successive halving function (ray/ at master · ray-project/ray · GitHub):

    # Modified HyperBand bracket halving: when the live trials' accuracies have
    # nearly converged (variance below self.threshold), bump each trial's
    # data_ratio and flag it for a 'call_once' step instead of halving.
    # NOTE(review): excerpt ends at the early return; the rest of the method
    # (the actual halving) is not shown.
    def successive_halving(
        self, metric: str, metric_op: float, trial_runner: "trial_runner.TrialRunner" 
    ) -> Tuple[List[Trial], List[Trial]]:
        # Nothing left to halve: keep every live trial, stop none.
        if self._halves == 0 and not self.stop_last_trials:
            return self._live_trials, []

        # self._live_trials maps each trial to its latest result dict.
        acc_list = [self._live_trials[trial]['mean_accuracy'] for trial in self._live_trials]
        # NOTE(review): this unguarded block duplicates the guarded one below.
        # As posted, whenever it fires it sets after_call_once, so the guarded
        # branch is skipped; when it does not fire, the identical inner condition
        # below is false too. The guarded body (call_once flag, trial_id,
        # update_trial_config) is therefore unreachable — possibly a paste artifact.
        if len(acc_list) > 1 and np.var(acc_list) < self.threshold: 
            # this incurs call_once function
            self.after_call_once = True 
            for trial in self._live_trials: 
                trial.config['config']['data_ratio'] += 0.1
        if not self.after_call_once and not self.after_call_twice and self.disable_progressive == False: 
            acc_list = [self._live_trials[trial]['mean_accuracy'] for trial in self._live_trials]
            if len(acc_list) > 1 and np.var(acc_list) < self.threshold: 
                # this incurs call_once function
                self.after_call_once = True 
                for trial in self._live_trials: 
                    # Mutate the trial's nested config in place; the trainable
                    # reads these keys in its step().
                    trial.config['config']['data_ratio'] += 0.1
                    trial.config['config']['call_once'] = True 
                    trial.config['config']['trial_id'] = self._live_trials[trial]['trial_id']
                    # Roll the recorded time attribute back by one for this trial.
                    self._live_trials[trial][self._time_attr] -= 1
                    print('processing trial id {}'.format(trial.config['config']['trial_id']))
                    print(self._live_trials[trial][self._time_attr], flush=True)
                    if True: 
                        # Push the mutated config to the trial's trainable.
                        self.update_trial_config(trial, trial_runner)
                # Keep all trials alive for the extra step; none are stopped.
                return self._live_trials, []

What I want to do is set call_once and call_twice in the config via update_trial_config. Then, when the step function runs, it knows which operation to execute, e.g., train the model, the call_once operation, or the call_twice operation. However, I find that when I modify the config in successive halving, the setup function is called. This is not what I expected.

Hi, I want to make my question simpler. The trial config changes based on a certain condition, and that condition occurs during successive halving. However, I find that when I modify trial.config via set_config or anything else while the trial is in the trial.PAUSED state, it may cause RayTrialExecutor to call setup and reinitialize the trial. Could you provide a more elegant way to update the trial config while the trial is in the trial.PAUSED state?

Hi @gaow0007, could you try using:

tuner = Tuner(..., tune_config=tune.TuneConfig(reuse_actors=True))

Then, it will not reinitialize the actor and only perform reset_config on an existing actor to unpause.

See here for more info: Training in Tune (tune.Trainable, session.report) — Ray 2.3.0