RuntimeError: Some workers returned results while others didn't. Make sure that `train.report()` and `train.checkpoint()` are called the same number of times on all workers

Hello, I’m trying to use Ray Train to train my PyTorch model in a multi-node setup.
With a smoke test on a very small sample size, the model trains fine. But when I try the full dataset,
training stops with the following error:
RuntimeError: Some workers returned results while others didn't. Make sure that train.report() and train.checkpoint() are called the same number of times on all workers.

I was using train.checkpoint(), but I stopped because I saw another person with this problem. The error persists anyway. And if I also stop calling train.report(), my logging will be a mess.
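
My understanding is that both calls are collective, so every worker has to call them the same number of times. A minimal sketch (hypothetical, not my actual code) of the pattern that would trigger this error versus one that would not:

from ray import train

def train_func_wrong():
    for epoch in range(3):
        # WRONG: only rank 0 reports, so the report counts diverge
        # across workers and Ray raises the RuntimeError above.
        if train.world_rank() == 0:
            train.report(loss=0.0)

def train_func_right():
    for epoch in range(3):
        # RIGHT: every worker reports once per epoch, in the same order.
        train.report(loss=0.0)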

Other things I tried:
torch.distributed.barrier() (the problem persists); roughly where I placed it is sketched below:
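
An illustrative sketch only, not my exact code (the helper name is made up):

import torch.distributed as dist

def end_of_epoch_sync():
    # Illustrative only: block until every worker reaches this point
    # before the next epoch starts. It did not fix the error for me.
    if dist.is_available() and dist.is_initialized():
        dist.barrier()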

my code:

import gc

import numpy as np
import ray
import torch
from ray import train
from torch import nn

def train_model(config):
    loss_fn = nn.BCEWithLogitsLoss()
    model = IntegrativeModel(config)
    prot_transformer = TransformerProt()
    prot_transformer = ray.get(config["trans_id"])
    model.prot_transformer = prot_transformer.transformer  # loading the pretrained transformer
    model = train.torch.prepare_model(model)
    params_to_optimize = [
        {'params': model.parameters()}
    ]
    optim = torch.optim.Adam(
        params_to_optimize, lr=config["lr"], weight_decay=config["weight_decay"])
    # Load and prepare the dataloaders; the non-distributed dataloaders are
    # put into the file store and I fetch them here.
    X_train = ray.get(config["train_data"])  # I reload the data because I resample at each iteration
    # (boilerplate loading of some other data from the file store)
    results = {}
    for epoch in range(0, 40):
        if epoch == 1:
            model.module.set_warm()  # activate training on the pretrained network
        # fresh dataloader with negative sampling each epoch
        # (build_train_dataloader is pseudocode; see the sketch after the code)
        train_dataloader = build_train_dataloader(X_train)
        train_dataloader = train.torch.prepare_data_loader(train_dataloader)

        train_met = train_epoch(model, train_dataloader,
                          loss_fn, optim, batch_acc)
        if (epoch % 5) == 0:  # evaluate these partitions only every few epochs
            test1_met = test_epoch(model, test_dataloader_1, loss_fn)
            test2_met = test_epoch(model, test_dataloader_2, loss_fn)
        else:
            test1_met = (np.nan, np.nan, np.nan, np.nan)
            test2_met = (np.nan, np.nan, np.nan, np.nan)
        test3_met = test_epoch(model, test_dataloader_3, loss_fn)
        metrics = {"train_loss":train_met[0], "train_roc":train_met[1],
                  "test1_loss":test1_met[0], "test1_roc":test1_met[1], "test1_acc0":test1_met[2], "test1_acc1":test1_met[3],
                  "test2_loss":test2_met[0], "test2_roc":test2_met[1], "test2_acc0":test2_met[2], "test2_acc1":test2_met[3],
                  "test3_loss":test3_met[0], "test3_roc":test3_met[1], "test3_acc0":test3_met[2], "test3_acc1":test3_met[3]}
        gc.collect()  # maybe this is problematic?
        train.report(**metrics)
        if ray.train.world_rank() == 0:  # here I'm trying to save manually to avoid the problem
            state_dict = model.module.state_dict()
            torch.save(state_dict, "final_model.pkl")
        results[epoch] = metrics
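
For completeness, build_train_dataloader above stands in for my real negative-sampling code; a minimal sketch of the idea, assuming X_train is a tensor (all names and details here are placeholders):

import torch
from torch.utils.data import DataLoader, TensorDataset

def build_train_dataloader(X_train, batch_size=64):
    # Placeholder for my actual logic: pair each positive row with a
    # freshly sampled random "negative" row on every call, so each
    # epoch trains on a different sample.
    neg_idx = torch.randint(0, len(X_train), (len(X_train),))
    dataset = TensorDataset(X_train, X_train[neg_idx])
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)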

Hey @P_ac, thanks for raising this issue! The way you're calling train.report looks good to me, so this behavior is definitely odd. Does this work if you remove the gc.collect() line?