Ray Tune Error "The actor ImplicitFunc is too large"

Hi,
I am using Ray Tune 2.0.0 with Pytorch Lightning 1.7.5 for hyperoptimization.

I set up a Resnet50 model in Pytorch Lightning and I followed the instruction to use Ray Tune as shown here: https://docs.ray.io/en/releases-1.11.1/tune/tutorials/tune-pytorch-lightning.html

I get the following error and I think it has to do something with passing the data into the ray object. I also saw different conversations on this forum dealing with the same issue, but e.g. tune.with_parameters() did not solve my problem.

ValueError: The actor ImplicitFunc is too large (122 MiB > FUNCTION_SIZE_ERROR_THRESHOLD=95 MiB). Check that its definition is not implicitly capturing a large array or other object in scope. Tip: use ray.put() to put large objects in the Ray object store.

Here is my code:


dataset_train = TensorDataset(X_train_T.float(),y_train_T.float())
dataset_test = TensorDataset(X_test_T.float(),y_test_T.float())
dataset_val = TensorDataset(X_val_T.float(),y_val_T.float())

###### ResNet Block ########

class ResBlock(pl.LightningModule):
    def __init__(self, in_channels, out_channels, downsample):
        super(ResBlock,self).__init__()
        if downsample:
            self.conv1 = nn.Conv2d(
                in_channels, out_channels, kernel_size=3, stride=2, padding=1)
            nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2),
                nn.BatchNorm2d(out_channels))
            
        else:
            self.conv1 = nn.Conv2d(
                in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            self.shortcut = nn.Sequential()

        self.conv2 = nn.Conv2d(out_channels, out_channels,
                               kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, input):
        shortcut = self.shortcut(input)
        input = nn.ReLU()(self.bn1(self.conv1(input)))
        input = nn.ReLU()(self.bn2(self.conv2(input)))
        input = input + shortcut
        return nn.ReLU()(input)


###### ResNet Bottleneck Block ########


class ResBottleneckBlock(pl.LightningModule):
    def __init__(self, in_channels, out_channels, downsample):
        super(ResBottleneckBlock,self).__init__()
        self.downsample = downsample
        self.conv1 = nn.Conv2d(in_channels, out_channels//4,
                               kernel_size=1, stride=1)
        self.conv2 = nn.Conv2d(
            out_channels//4, out_channels//4, kernel_size=3, stride=2 if downsample else 1, padding=1)
        self.conv3 = nn.Conv2d(out_channels//4, out_channels, kernel_size=1, stride=1)

        if self.downsample or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=2 if self.downsample else 1),
                nn.BatchNorm2d(out_channels))
            
        else:
            self.shortcut = nn.Sequential()

        self.bn1 = nn.BatchNorm2d(out_channels//4)
        self.bn2 = nn.BatchNorm2d(out_channels//4)
        self.bn3 = nn.BatchNorm2d(out_channels)

    def forward(self, input):
        shortcut = self.shortcut(input)
        input = nn.ReLU()(self.bn1(self.conv1(input)))
        input = nn.ReLU()(self.bn2(self.conv2(input)))
        input = nn.ReLU()(self.bn3(self.conv3(input)))
        input = input + shortcut
        return nn.ReLU()(input)


###### ResNet Module ########

class ResNet(pl.LightningModule):
    def __init__(self, in_channels, resblock, repeat, useBottleneck=True, outputs=1000):
        super(ResNet,self).__init__()
        self.layer0 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=5, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU())
      
        if useBottleneck:
            filters = [64, 256, 512, 1024, 2048, 4096]
        else:
            filters = [64, 64, 128, 256, 4096]

        self.layer1 = nn.Sequential()
        self.layer1.add_module('conv2_1', resblock(filters[0], filters[1], downsample=False))
        for i in range(1, repeat[0]):
                self.layer1.add_module('conv2_%d'%(i+1,), resblock(filters[1], filters[1], downsample=False))

        self.layer2 = nn.Sequential()
        self.layer2.add_module('conv3_1', resblock(filters[1], filters[2], downsample=True))
        for i in range(1, repeat[1]):
                self.layer2.add_module('conv3_%d' % (
                    i+1,), resblock(filters[2], filters[2], downsample=False))

        self.layer3 = nn.Sequential()
        self.layer3.add_module('conv4_1', resblock(filters[2], filters[3], downsample=True))
        for i in range(1, repeat[2]):
            self.layer3.add_module('conv2_%d' % (
                i+1,), resblock(filters[3], filters[3], downsample=False))

        self.layer4 = nn.Sequential()
        self.layer4.add_module('conv5_1', resblock(filters[3], filters[4], downsample=True))
        for i in range(1, repeat[3]):
            self.layer4.add_module('conv3_%d'%(i+1,),resblock(filters[4], filters[4], downsample=False))

        self.gap = nn.AdaptiveAvgPool2d(1)
         
        
        self.fc0 = torch.nn.Linear(filters[4], 4200)
       

    def forward(self, input):
        input = self.layer0(input)
        input = self.layer1(input)
        input = self.layer2(input)
        input = self.layer3(input)
        input = self.layer4(input)
        input = self.gap(input)
        input = torch.flatten(input, start_dim=1)
        input = self.fc0(input)
        input = torch.sigmoid(input)
       
        return input


class LightNet(pl.LightningModule):
    
    def __init__(self,config):

        super().__init__()
       
        self.model = model
        self.lr = config["lr"]
        self.pearson_r = PearsonCorrCoef()        
        self.criterion = F.mse_loss
        self.batch_size = 16

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        
        loss = self.criterion(y_hat, y)
        self.log("train_loss", loss,on_epoch=True,on_step=False,sync_dist=True)#, prog_bar=True, logger=True)
        return loss 
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss_ = self.criterion(y_hat, y)
        
        self.log("val_mse_loss",loss_, on_epoch=True, sync_dist=True) 
               
    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)

    def val_dataloader(self):
        return DataLoader(self.dataset_val,num_workers=0,pin_memory='False',batch_size=self.batch_size,drop_last=True) 
       
    def test_dataloader(self):
        return DataLoader(self.dataset_test, batch_size=self.batch_size,num_workers=0,pin_memory='False',drop_last=True))
      
    def train_dataloader(self): 
        return DataLoader(self.dataset_train, batch_size=self.batch_size,num_workers=0,pin_memory='False',drop_last=True)

     
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr = self.lr)


        return optimizer


model = ResNet(1, ResBottleneckBlock, [3, 4, 6, 3], useBottleneck=True, outputs=4200)


############ Ray Tuning #############################################

def train_resnet_tune(config,num_epochs=5, num_gpus=7,data = None):
    

    regressor = LightNet(config)
    trainer = pl.Trainer(
        max_epochs=num_epochs,
        # If fractional GPUs passed in, convert to int.
        gpus=math.ceil(num_gpus),
        logger=TensorBoardLogger(
            save_dir=tune.get_trial_dir(), name="", version="."),
        progress_bar_refresh_rate=0,
        callbacks=[
            TuneReportCallback(
                {
                    #"loss": "ptl/val_loss",
                    "mean_accuracy": "ptl/val_mse_loss"
                },
                on="validation_end")])
       
    trainer.fit(regressor)


def tune_resnet_hrw(num_samples=5, num_epochs=5, gpus_per_trial=7, data_dir="~/tomo/data"):
    config = {'lr': tune.loguniform(1e-4, 1e-2)}

    scheduler = ASHAScheduler(
        max_t=num_epochs,
        grace_period=1,
        reduction_factor=2)

    reporter = CLIReporter(parameter_columns=["lr"],metric_columns=['val_mse_loss'])

    train_fn_with_parameters = tune.with_parameters(train_resnet_tune,
                                                    num_epochs=num_epochs,
                                                    num_gpus=gpus_per_trial)
      
    resources_per_trial = {"cpu": 50, "gpu": gpus_per_trial}

    analysis = tune.run(train_fn_with_parameters,    
        resources_per_trial=resources_per_trial,
        metric="loss",
        mode="min",
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
        progress_reporter=reporter,
        name="tune_resnet_hrw")

    print("Best hyperparameters found were: ", analysis.best_config)


tune_resnet_hrw(num_samples=5, num_epochs=5, gpus_per_trial=7)

Thank you very much for any advice you can give

beste regards
Moritz