Hello, I'm trying to use Ray Train to train my PyTorch model in a multi-node setup.
With a smoke test on a very small sample the model trains fine, but when I run on the full dataset,
training stops with the following error:
RuntimeError: Some workers returned results while others didn't. Make sure that
train.report() and train.checkpoint() are called the same number of times on all workers.
I was originally calling train.checkpoint(), but I removed it after seeing another person hit this problem; the error persists. If I also stop calling train.report(), my logging becomes a mess.
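As far as I understand, the requirement is only that every rank reaches train.report() the same number of times, so reporting unconditionally once per epoch on every worker should satisfy it. A minimal sketch of the pattern I think is expected (toy metric values, not my real code):

from ray import train

def train_func(config):
    # every worker runs the same number of epochs and reports exactly once
    # per epoch, so the report counts stay in sync across ranks
    for epoch in range(config["epochs"]):
        loss = 1.0 / (epoch + 1)  # placeholder metric
        train.report(epoch=epoch, loss=loss)

This is what I believe my loop below is doing as well, since train.report() is called once per epoch on every rank.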
Other things I tried: adding torch.distributed.barrier() calls around the phases that differ between ranks (the problem persists; see the barriers in the code below).

My code:
def train_model(config):
    # (imports of torch, nn, np, gc, ray/train and my own modules are omitted here)
    loss_fn = nn.BCEWithLogitsLoss()
    model = IntegrativeModel(config, warm=False)
    prot_transformer = TransformerProt()
    prot_transformer = ray.get(config["trans_id"])  # fetch the pretrained transformer from the object store
    model.prot_transformer = prot_transformer.transformer  # loading the pretrained network
    model.set_cold()
    model = train.torch.prepare_model(model)

    params_to_optimize = [
        {'params': model.parameters()}
    ]
    optim = torch.optim.Adam(
        params_to_optimize, lr=config["lr"], weight_decay=config["weight_decay"])

    # loading and preparing dataloaders; the non-distributed dataloaders are put
    # into the object store and I fetch them here
    X_train = ray.get(config["train_data"])  # I load the raw data because I resample at each epoch
    # boilerplate loading of some data from the file store
    results = {}  # per-epoch metrics collected on rank 0
    torch.distributed.barrier()

    for epoch in range(0, 40):
        if epoch == 1:
            model.module.set_warm()  # activate training of the pretrained network
        torch.distributed.barrier()

        train_dataloader = torch.utils.data.DataLoader(...)  # generate dataloader with negative sampling (args omitted)
        train_dataloader = train.torch.prepare_data_loader(train_dataloader)

        train_met = train_epoch(model, train_dataloader,
                                loss_fn, optim, batch_acc)

        if (epoch % 5) == 0:  # evaluate some partitions only every few epochs
            test1_met = test_epoch(model, test_dataloader_1, loss_fn)
            test2_met = test_epoch(model, test_dataloader_2, loss_fn)
        else:
            test1_met = (np.nan, np.nan, np.nan, np.nan)
            test2_met = (np.nan, np.nan, np.nan, np.nan)
        test3_met = test_epoch(model, test_dataloader_3, loss_fn)

        metrics = {"train_loss": train_met[0], "train_roc": train_met[1],
                   "test1_loss": test1_met[0], "test1_roc": test1_met[1], "test1_acc0": test1_met[2], "test1_acc1": test1_met[3],
                   "test2_loss": test2_met[0], "test2_roc": test2_met[1], "test2_acc0": test2_met[2], "test2_acc1": test2_met[3],
                   "test3_loss": test3_met[0], "test3_roc": test3_met[1], "test3_acc0": test3_met[2], "test3_acc1": test3_met[3]}

        gc.collect()  # maybe this is problematic?
        train.report(epoch=epoch, **metrics)
        torch.distributed.barrier()

        if train.world_rank() == 0:  # here I'm trying to save manually to avoid the checkpoint problem
            state_dict = model.module.state_dict()
            torch.save(state_dict, "final_model.pkl")
            results[epoch] = metrics
        torch.distributed.barrier()
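For completeness, this is roughly how I launch the function on the cluster (I'm on the older ray.train.Trainer API; the worker count, GPU flag, and config values below are placeholders, not my exact setup):

from ray.train import Trainer

config = {"lr": 1e-4, "weight_decay": 1e-5}  # plus the "trans_id" and "train_data" object refs in the real run

trainer = Trainer(backend="torch", num_workers=4, use_gpu=True)  # placeholder resources
trainer.start()
run_results = trainer.run(train_model, config=config)
trainer.shutdown()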