RuntimeError: Some workers returned results while others didn't. Make sure that `train.report()` and `train.checkpoint()` are called the same number of times on all workers

Hello, I’m trying to use Ray Train to train my PyTorch model in a multi-node setup.
If I run a smoke test with a very small sample size, the model trains fine. But when I try on the full dataset,
training stops with the following error:
RuntimeError: Some workers returned results while others didn’t. Make sure that train.report() and train.checkpoint() are called the same number of times on all workers.

I was using train.checkpoint(), but I stopped after I saw another person with this problem; the error persists anyway. If I also stop using train.report(), my logging will be a mess.
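From the error message, my understanding is that every worker must hit train.report() (and train.checkpoint(), if used) exactly the same number of times, so any rank-dependent or data-dependent guard around the call can trigger this. As a minimal sketch of the pattern I think is required (toy_train_func and the metric value are just placeholders, using the same ray.train API as my code below):

from ray import train

def toy_train_func():
    for epoch in range(3):
        loss = float(epoch)  # placeholder metric

        # BAD: only rank 0 reports, so call counts diverge across workers
        # if train.world_rank() == 0:
        #     train.report(epoch=epoch, loss=loss)

        # OK: every worker reports once per epoch; rank-specific work
        # (e.g. saving a checkpoint manually) can still be guarded separately
        train.report(epoch=epoch, loss=loss)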

Other things I tried:
torch.distributed.barrier() (problem persists)

My code:

import gc

import numpy as np
import torch
import torch.nn as nn

import ray
from ray import train
import ray.train.torch  # needed so that train.torch.prepare_* is available


def train_model(config):
    loss_fn = nn.BCEWithLogitsLoss()  # instantiate the loss, not just reference the class
    model = IntegrativeModel(warm=False)  # the config argument is currently commented out

    # pull the pretrained transformer from the Ray object store and plug it into the model
    prot_transformer = ray.get(config["trans_id"])
    model.prot_transformer = prot_transformer.transformer  # loading pretrained weights
    model.set_cold()
    model = train.torch.prepare_model(model)
    params_to_optimize = [
        {'params': model.parameters()}
    ]
    optim = torch.optim.Adam(
        params_to_optimize, lr=config["lr"], weight_decay=config["weight_decay"])
    # load and prepare the dataloaders; the non-distributed dataloaders are put into
    # the object store and I fetch them here
    X_train = ray.get(config["train_data"])  # I reload the data because I resample at each epoch
    # boilerplate loading of some data from the file store (omitted)
    torch.distributed.barrier()
    results = {}  # holds the per-epoch metrics collected at the end of the loop
    for epoch in range(40):
        if epoch == 1:
            model.module.set_warm()  # activate training of the pretrained network after the first epoch
            torch.distributed.barrier()
        # generate a fresh dataloader with negative sampling each epoch (dataset construction omitted)
        train_dataloader = torch.utils.data.DataLoader(...)
        train_dataloader = train.torch.prepare_data_loader(train_dataloader)

        train_met = train_epoch(model, train_dataloader,
                                loss_fn, optim, batch_acc)
        if (epoch % 5) == 0:
            # test partitions 1 and 2 are only evaluated every 5 epochs
            test1_met = test_epoch(model, test_dataloader_1, loss_fn)
            test2_met = test_epoch(model, test_dataloader_2, loss_fn)
        else:
            test1_met = (np.nan, np.nan, np.nan, np.nan)
            test2_met = (np.nan, np.nan, np.nan, np.nan)
        test3_met = test_epoch(model, test_dataloader_3, loss_fn)
        metrics = {"train_loss":train_met[0], "train_roc":train_met[1],
                  "test1_loss":test1_met[0], "test1_roc":test1_met[1], "test1_acc0":test1_met[2], "test1_acc1":test1_met[3],
                  "test2_loss":test2_met[0], "test2_roc":test2_met[1], "test2_acc0":test2_met[2], "test2_acc1":test2_met[3],
                  "test3_loss":test3_met[0], "test3_roc":test3_met[1], "test3_acc0":test3_met[2], "test3_acc1":test3_met[3]}
        gc.collect() # maybe this is problematic?
        train.report(epoch=epoch, **metrics)
        torch.distributed.barrier()
        if ray.train.world_rank() == 0:  # save manually on rank 0 instead of using train.checkpoint()
            state_dict = model.module.state_dict()
            torch.save(state_dict, "final_model.pkl")
        results[epoch] = metrics
        torch.distributed.barrier()

Hey @P_ac, thanks for raising this issue! The way that you are calling train.report looks good to me, so this behavior is definitely odd. Does this work if you remove the gc.collect() line?
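To be concrete, I just mean dropping (or commenting out) that single line and keeping the rest of the loop unchanged, roughly:

        # gc.collect()  # temporarily removed to check whether it is the culprit
        train.report(epoch=epoch, **metrics)
        torch.distributed.barrier()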