Thanks for your reply.
Here is my Trainable class:
# Ray Tune Trainable that trains a ResNet-18 on a subset of the training data
# built in setup(). Depending on flags in the trial config, step() trains,
# delegates to self.call_once(), or reports a "call_twice" result.
# NOTE(review): this paste lost all indentation, and several statements are
# truncated (the call_twice/call_once bodies, both DataLoader(...) calls and
# the torch.save(...) call in save_checkpoint), so it is not valid Python as
# shown.
class PytorchTrainble(tune.Trainable):
# setup(config): build the device, model, optimizer, LR scheduler, datasets
# and bookkeeping state. The settings may arrive nested under
# config['config'] (successive_halving writes data_ratio there), so that
# level is unwrapped first.
def setup(self, config):
if 'config' in config:
config = config['config']
# NOTE(review): as written this guard is inverted. On a fresh instance
# `initialized` does not exist, so none of the attributes below would be
# created, and step() would then fail on self.config. Presumably the intent
# was `if not getattr(self, 'initialized', False):` (initialise only once).
# Even then, if Ray restarts the trial's actor, setup() runs on a new
# instance where `initialized` is unset, so this flag cannot prevent that.
if hasattr(self, 'initialized') and self.initialized:
# NOTE(review): the device is computed, but the visible code never moves
# the model (or batches) to it.
self.device = device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if config['dataset'] == 'ImageNet':
model = torchvision.models.resnet18()
# Replace the final layer so the output size is config['num_classes'].
del model.fc
model.fc = torch.nn.Linear(512, config['num_classes'])
elif config['dataset'] in ['CIFAR10', 'CIFAR100']:
# NOTE(review): num_classes is hard-coded to 100 here, even for CIFAR10,
# whereas the ImageNet branch uses config['num_classes']. For any other
# dataset name `model` is never assigned, so the optimizer line fails.
model = resnet18(num_classes=100, ratio=config["width_ratio"])
# model = vgg16_bn()
optimizer = optim.SGD(
model.parameters(), lr=config["lr"], momentum=config["momentum"], weight_decay = config["wd"])
# NOTE(review): self.scheduler.step() is never called in the visible
# step(), so this StepLR schedule (LR x0.5 every 10 scheduler steps) has no
# effect.
self.scheduler = StepLR(optimizer, step_size=10, gamma=0.5)
self.model = model
self.optimizer = optimizer
# Only referenced by the commented-out data_ratio schedule in step().
self.epoches = 50
self.result_record = Trajector("ResNet18", config["dataset"], str(config))
self.data_ratio_init = config['data_ratio']
self.cur_epoch = 0
self.config = config
# Best test accuracy seen so far; step() reports it as "mean_accuracy".
self.max_acc = 0
# The training subset is built once, here, from self.config. The training
# branch of step() reuses self.sub_trainset, so changing
# config['data_ratio'] later does not change the data trained on unless
# this code (or the not-shown call_once) rebuilds it.
self.sub_trainset, self.rest_trainset, self.testset = build_sub_rest_dataset(self.config)
self.proportion_list = get_expand_landmarks(self.sub_trainset, self.rest_trainset)
self.proportion = self.proportion_list[0]
self.abs_loss_dir = self.config['abs_loss_dir']
self.initialized = True
# NOTE(review): the bodies of call_twice and call_once were not included in
# this post. step() returns call_once()'s return value as the trial result.
def call_twice(self, ):
def call_once(self, ):
# step(): run one Tune iteration. It first unwraps a nested config,
# replacing self.config, and then dispatches on the config flags:
#   call_once  -> return self.call_once()
#   call_twice -> report the best accuracy so far plus 'not_increase': True
#   otherwise  -> run train() on self.sub_trainset, evaluate with test() on
#                 the training and test sets, and report the best test
#                 accuracy so far as "mean_accuracy".
def step(self):
# import pdb; pdb.set_trace()
ori_config = self.config
if 'config' in self.config:
self.config = self.config['config']
# data_ratio = self.cur_epoch / self.epoches * (1.0 - self.data_ratio_init) + self.data_ratio_init
# self.config['data_ratio'] = data_ratio
if 'data_ratio' not in self.config:
raise ValueError("data_ratio is to {} // {}".format(self.config, ori_config))
# NOTE(review): data_ratio is read but not used in the visible code.
data_ratio = self.config['data_ratio']
# if data_ratio > 0.1:
# raise ValueError(f"this is not correct. {data_ratio}")
# train_loader, test_loader = build_dynamic_sub_dataloader(self.config)
# NOTE(review): nothing here clears call_once/call_twice after use, so once
# a flag is set, every later step() takes the same branch.
if 'call_once' in self.config and self.config['call_once']:
return self.call_once()
if 'call_twice' in self.config and self.config['call_twice']:
return {"mean_accuracy": self.max_acc, 'not_increase': True}
# `if True:` is a leftover toggle; this branch always runs.
if True:
# train_loader, test_loader = build_dataloader(self.config, self.sub_trainset, self.testset)
# NOTE(review): both DataLoader(...) calls are missing their closing
# parenthesis in this paste. Shuffling the test loader is not needed for
# evaluation.
train_loader = DataLoader(self.sub_trainset, batch_size=min((int)(self.config["batch_size"]), len(self.sub_trainset)),
shuffle=True, drop_last=False, num_workers=4,
test_loader = DataLoader(self.testset, batch_size=256,
shuffle=True, drop_last=False, num_workers=4,
train(self.model, self.optimizer, train_loader)
# Training-set accuracy/loss and the test loss are computed but unused below.
acc_train, loss_train = test(self.model, train_loader)
acc, loss = test(self.model, test_loader)
self.cur_epoch += 1
# if self.cur_epoch == 2:
# raise ValueError('this is not correct in trial {} // {}'.format(self.config, ori_config))
# Reports the running maximum of test accuracy, not this iteration's value.
self.max_acc = max(acc, self.max_acc)
return {"mean_accuracy": self.max_acc}
# save_checkpoint(): intended to write model/optimizer state, the epoch
# counter, the scheduler and the result record to <checkpoint_dir>/model.pth,
# returning that path.
# NOTE(review): the statement below is truncated. `key_info = {` is closed by
# `}, checkpoint_path)`, so it was presumably torch.save({...}, checkpoint_path).
# It pickles the scheduler object itself rather than scheduler.state_dict(),
# and self.max_acc is not saved.
def save_checkpoint(self, checkpoint_dir):
checkpoint_path = os.path.join(checkpoint_dir, "model.pth")
key_info = {
'model': self.model.state_dict(),
'optim': self.optimizer.state_dict(),
'cur_epoch': self.cur_epoch,
'scheduler': self.scheduler,
'result_record': self.result_record,
}, checkpoint_path)
return checkpoint_path
# load_checkpoint(): restore training state from a file written by
# save_checkpoint().
# NOTE(review): only scheduler, cur_epoch and result_record are restored. The
# saved 'model' and 'optim' state dicts are never loaded back into self.model
# or self.optimizer. The unpickled scheduler is presumably bound to a copy of
# the optimizer rather than to self.optimizer. Prefer state_dict() and
# load_state_dict() for the model, optimizer and scheduler.
def load_checkpoint(self, checkpoint_path):
key_info = torch.load(checkpoint_path)
self.scheduler = key_info['scheduler']
self.cur_epoch = key_info['cur_epoch']
self.result_record = key_info['result_record']
# reset_config(new_config): Ray Tune hook for applying a new config to a
# running trial without restarting it. Per the Tune docs, an actor is reused
# (reuse_actors=True) only when this returns True; otherwise the trial is
# restarted and setup() runs again. Confirm this for your Ray version.
# NOTE(review): new_config['data_ratio'] is read at the top level, but
# successive_halving writes it under trial.config['config'], and setup/step
# unwrap a nested 'config' key. If new_config has that nested shape, this
# lookup raises KeyError. The check also raises ValueError for any
# data_ratio > 0.1, which is what adding 0.1 to a positive ratio produces.
# Neither the training subset nor the optimizer is rebuilt here.
def reset_config(self, new_config):
if new_config['data_ratio'] > 0.1:
raise ValueError("this step is not correct in reset {}".format(new_config))
self.config = new_config
return True
I also modified the `successive_halving` function of Ray Tune's HyperBand scheduler (ray-project/ray, master branch, on GitHub):
def successive_halving(
    self, metric: str, metric_op: float, trial_runner: "trial_runner.TrialRunner"
) -> Tuple[List["Trial"], List["Trial"]]:
    """Grow the data ratio of all live trials once their accuracies converge.

    This replaces HyperBand's pruning step: no trial is ever stopped. The
    update happens only when all of these hold:

    * progressive mode is enabled (``self.disable_progressive`` is falsy);
    * no ``call_once``/``call_twice`` phase has started yet;
    * the variance of the live trials' ``mean_accuracy`` values is below
      ``self.threshold``.

    When it does, each live trial's nested config (``trial.config['config']``)
    gets ``data_ratio`` raised by 0.1, ``call_once`` set to True and
    ``trial_id`` copied from its latest result. The trial is then pushed to
    the runner with ``self.update_trial_config(trial, trial_runner)``.

    Args:
        metric: Unused; the convergence check reads ``'mean_accuracy'``.
        metric_op: Unused.
        trial_runner: Forwarded to ``self.update_trial_config``.

    Returns:
        ``(self._live_trials, [])``: every live trial is kept.
    """
    if self._halves == 0 and not self.stop_last_trials:
        return self._live_trials, []

    # The progressive update must happen exactly once, in a single pass.
    # Nothing may set after_call_once before this guard, or the call_once
    # branch below becomes unreachable.
    if self.after_call_once or self.after_call_twice or self.disable_progressive:
        return self._live_trials, []

    # self._live_trials maps each live trial to its latest result dict.
    acc_list = [result["mean_accuracy"] for result in self._live_trials.values()]
    if len(acc_list) > 1 and np.var(acc_list) < self.threshold:
        # Converged: switch every trial to the call_once phase.
        self.after_call_once = True
        for trial, result in self._live_trials.items():
            trial_config = trial.config["config"]
            trial_config["data_ratio"] += 0.1
            trial_config["call_once"] = True
            trial_config["trial_id"] = result["trial_id"]
            # NOTE(review): rolls the reported time attribute back by one;
            # the reason is not visible here -- confirm it is intended.
            result[self._time_attr] -= 1
            print("processing trial id {}".format(trial_config["trial_id"]))
            print(result[self._time_attr], flush=True)
            self.update_trial_config(trial, trial_runner)
    return self._live_trials, []
What I want to do is set `call_once` and `call_twice` in the config via `update_trial_config`. Then, when `step` is executed, it knows which operation to run: training the model, the `call_once` operation, or the `call_twice` operation. However, I find that when I modify the config in `successive_halving`, the `setup` function is called again. This is not what I expected.