Checkpointing errors on complex models

How severe does this issue affect your experience of using Ray?

  • High: It blocks me to complete my task.

Hello, my model is built from two neural networks: a VNet, which supplies segmentation masks, and modelRegression, which counts the class instances in the segmentation map.
Both are kept as attributes of the Lightning model class:

        self.net = net
        self.modelRegression = UNetToRegresion(2,regression_channels)

However, when I try checkpointing, the load_state_dict function raises an error. What can I do?

trainer.lightning_module.load_state_dict(state_dict)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1604, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Model:
        Missing key(s) in state_dict: "modelRegression.model.0.weight_fake_quant.min_vals", "modelRegression.model.0.weight_fake_quant.max_vals", 

Strangely, the error occurs with both TuneReportCallback and TuneReportCheckpointCallback.

Full error:

2022-09-22 15:25:20,952 ERROR tune.py:754 -- Trials did not complete: [mainTrain_722c5_00000]
2022-09-22 15:25:20,952 INFO tune.py:758 -- Total run time: 569.10 seconds (568.96 seconds for the tuning loop).
At least one trial failed.
The trial had an error: ray::ImplicitFunc.train() (pid=277, ip=10.164.0.3, repr=mainTrain)
  File "/usr/local/lib/python3.8/dist-packages/ray/tune/trainable/trainable.py", line 347, in train
    result = self.step()
  File "/usr/local/lib/python3.8/dist-packages/ray/tune/trainable/function_trainable.py", line 417, in step
    self._report_thread_runner_error(block=True)
  File "/usr/local/lib/python3.8/dist-packages/ray/tune/trainable/function_trainable.py", line 589, in _report_thread_runner_error
    raise e
  File "/usr/local/lib/python3.8/dist-packages/ray/tune/trainable/function_trainable.py", line 289, in run
    self._entrypoint()
  File "/usr/local/lib/python3.8/dist-packages/ray/tune/trainable/function_trainable.py", line 362, in entrypoint
    return self._trainable_func(
  File "/usr/local/lib/python3.8/dist-packages/ray/tune/trainable/function_trainable.py", line 684, in _trainable_func
    output = fn()
  File "/usr/local/lib/python3.8/dist-packages/ray/tune/trainable/util.py", line 359, in inner
    trainable(config, **fn_kwargs)
  File "/home/sliceruser/data/piCaiCode/Three_chan_baseline.py", line 165, in mainTrain
    ThreeChanNoExperiment.train_model(label_name, dummyLabelPath, df,percentSplit,cacheDir
  File "/home/sliceruser/data/piCaiCode/ThreeChanNoExperiment.py", line 248, in train_model
    trainer.fit(model=model, datamodule=data)
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 770, in fit
    self._call_and_handle_interrupt(
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 721, in _call_and_handle_interrupt
    return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/ray_lightning/launchers/ray_launcher.py", line 64, in launch
    self._recover_results_in_main_process(ray_output, trainer)
  File "/usr/local/lib/python3.8/dist-packages/ray_lightning/launchers/ray_launcher.py", line 370, in _recover_results_in_main_process
    trainer.lightning_module.load_state_dict(state_dict)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1604, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Model:
        Missing key(s) in state_dict: "modelRegression.model.0.weight_fake_quant.min_vals", "modelRegression.model.0.weight_fake_quant.max_vals", "modelRegression.model.1.weight_fake_quant.min_vals", "modelRegression.model.1.weight_fake_quant.max_vals", "modelRegression.model.2.weight_fake_quant.min_vals", "modelRegression.model.2.weight_fake_quant.max_vals", "modelRegression.model.3.weight_fake_quant.min_vals", "modelRegression.model.3.weight_fake_quant.max_vals". 
        size mismatch for modelRegression.model.0.weight_fake_quant.min_val: copying a param with shape torch.Size([10]) from checkpoint, the shape in current model is torch.Size([0]).
        size mismatch for modelRegression.model.0.weight_fake_quant.max_val: copying a param with shape torch.Size([10]) from checkpoint, the shape in current model is torch.Size([0]).
        size mismatch for modelRegression.model.1.weight_fake_quant.min_val: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([0]).
        size mismatch for modelRegression.model.1.weight_fake_quant.max_val: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([0]).
        size mismatch for modelRegression.model.2.weight_fake_quant.min_val: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([0]).
        size mismatch for modelRegression.model.2.weight_fake_quant.max_val: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([0]).
        size mismatch for modelRegression.model.3.weight_fake_quant.min_val: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([0]).
        size mismatch for modelRegression.model.3.weight_fake_quant.max_val: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([0]).
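
One way to narrow down an error like the one above is to compare the names and shapes stored in the checkpoint with those the freshly built model expects. The following is only a minimal sketch and not part of the original code: `diff_state_shapes` is a hypothetical helper that works on plain name-to-shape mappings, which with PyTorch could be built as `{k: tuple(v.shape) for k, v in module.state_dict().items()}`. The example entries are modelled on two of the names in the traceback.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def diff_state_shapes(
    expected: Dict[str, Shape], saved: Dict[str, Shape]
) -> Dict[str, List[str]]:
    """Compare what a model expects with what a checkpoint contains."""
    report: Dict[str, List[str]] = {
        "missing": [],
        "unexpected": [],
        "size_mismatch": [],
    }
    for name, shape in expected.items():
        if name not in saved:
            # The model has this entry, the checkpoint does not.
            report["missing"].append(name)
        elif saved[name] != shape:
            report["size_mismatch"].append(
                f"{name}: checkpoint {saved[name]} vs model {shape}"
            )
    # Entries present in the checkpoint but unknown to the model.
    report["unexpected"] = [name for name in saved if name not in expected]
    return report


# Hypothetical shapes modelled on the error message above.
model_shapes = {
    "modelRegression.model.0.weight_fake_quant.min_vals": (0,),
    "modelRegression.model.0.weight_fake_quant.min_val": (0,),
}
checkpoint_shapes = {
    "modelRegression.model.0.weight_fake_quant.min_val": (10,),
}

report = diff_state_shapes(model_shapes, checkpoint_shapes)
for kind, entries in report.items():
    print(kind, entries)
```

In the traceback, the checkpoint holds buffers of size 10, 16, 32 and 1 while the fresh model's buffers are still empty (size 0), which may mean that these buffers only get their final size after the model has processed data.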

Hey @Jakub_Mitura, would you be able to share a reproducible script? Is this using the population-based training scheduler?

Thanks for responding! Yes, it does use the population-based scheduler. OK, I will work on a minimal example today.

Hello @amogkam, I have been working further on a minimal working example. In the simplified case I cannot yet reproduce the behaviour related to the two-network problem. However, I think I have identified some errors related to passing metrics through the callback. Below I have added two minimal working examples showing exactly when the problem occurs.
Thanks!

pytorch-lightning 1.6.5
ray 2.0.0
ray-lightning 0.3.0

full list of python packages

As marked by the starred (***) section of the code, TuneReportCheckpointCallback raises an error while TuneReportCallback does not.

Code where the metrics are not found because of checkpointing:

"""Simple example using RayAccelerator and Ray Tune"""
import os

import pytorch_lightning as pl
import ray
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule
from ray import tune
# TuneReportCheckpointCallback comes from Ray Tune's Lightning integration;
# TuneReportCallback and get_tune_resources come from ray_lightning.
from ray.tune.integration.pytorch_lightning import TuneReportCheckpointCallback
from ray_lightning import RayStrategy
from ray_lightning.tune import TuneReportCallback, get_tune_resources
ray.init(num_cpus=24)
data_dir = '/home/sliceruser/mnist'
MNISTDataModule(data_dir=data_dir).prepare_data()
num_cpus_per_worker=6
test_l_dir = '/home/sliceruser/test_l_dir'

class netaA(nn.Module):
    def __init__(self,
        config
    ) -> None:
        super().__init__()
        layer_1, layer_2 = config["layer_1"], config["layer_2"]
        self.model = nn.Sequential(
        torch.nn.Linear(28 * 28, layer_1),
        torch.nn.Linear(layer_1, layer_2),    
        torch.nn.Linear(layer_2, 10)
        )
    def forward(self, x):
        return self.model(x)



class LightningMNISTClassifier(pl.LightningModule):
    def __init__(self, config, data_dir=None):
        super(LightningMNISTClassifier, self).__init__()

        self.data_dir = data_dir or os.getcwd()
        self.lr = config["lr"]
        self.batch_size = config["batch_size"]

        self.accuracy = torchmetrics.Accuracy()
        self.netA= netaA(config)

    def forward(self, x):
        batch_size, channels, width, height = x.size()
        x = x.view(batch_size, -1)
        x= self.netA(x)

        x = F.log_softmax(x, dim=1)
        return x

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        logits = self.forward(x)
        loss = F.nll_loss(logits, y.long())
        acc = self.accuracy(logits, y)
        self.log("ptl/train_loss", loss)
        self.log("ptl/train_accuracy", acc)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        logits = self.forward(x)
        loss = F.nll_loss(logits, y.long())
        acc = self.accuracy(logits, y)
        return {"val_loss": loss, "val_accuracy": acc}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean()
        self.log("ptl/val_loss", avg_loss)
        self.log("ptl/val_accuracy", avg_acc)



def train_mnist(config,
                data_dir=None,
                num_epochs=10,
                num_workers=1,
                use_gpu=True,
                callbacks=None):

    model = LightningMNISTClassifier(config, data_dir)

    callbacks = callbacks or []
    print(" aaaaaaaaaa  ")
    trainer = pl.Trainer(
        max_epochs=num_epochs,
        callbacks=callbacks,
        progress_bar_refresh_rate=0,
        strategy=RayStrategy(
            num_workers=num_workers, use_gpu=use_gpu))#, init_hook=download_data
    dm = MNISTDataModule(
        data_dir=data_dir, num_workers=2, batch_size=config["batch_size"])
    trainer.fit(model, dm)


def tune_mnist(data_dir,
               num_samples=2,
               num_epochs=10,
               num_workers=2,
               use_gpu=True):
    config = {
        "layer_1": tune.choice([32, 64, 128]),
        "layer_2": tune.choice([64, 128, 256]),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([32, 64, 128]),
    }

    metrics = {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"}
   
    # ***********************************************
    # does not work
    callbacks = [TuneReportCheckpointCallback(metrics, on="validation_end", filename="checkpointtt")]

    # works
    # callbacks = [TuneReportCallback(metrics, on="validation_end")]

    # ***********************************************

 
 
    trainable = tune.with_parameters(
        train_mnist,
        data_dir=data_dir,
        num_epochs=num_epochs,
        num_workers=num_workers,
        use_gpu=use_gpu,
        callbacks=callbacks)
    analysis = tune.run(
        trainable,
        metric="loss",
        mode="min",
        config=config,
        num_samples=num_samples,
        resources_per_trial=get_tune_resources(
            num_workers=num_workers, use_gpu=use_gpu),
        name="tune_mnist")

    print("Best hyperparameters found were: ", analysis.best_config)

tune_mnist(data_dir)


error
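
Because the original script imported TuneReportCallback from two packages, the later import silently replaced the earlier one, while TuneReportCheckpointCallback was only imported from ray.tune.integration.pytorch_lightning. A small check like the sketch below (the helper name `qualified_name` is made up for illustration) prints the module each callback class actually comes from. As far as I know, ray_lightning ships its own Tune callbacks in ray_lightning.tune that are meant to be used together with RayStrategy; treat that as an assumption to verify against the installed version.

```python
from collections import OrderedDict


def qualified_name(obj) -> str:
    """Return 'module.ClassName' for a class or for an instance of it."""
    cls = obj if isinstance(obj, type) else type(obj)
    return f"{cls.__module__}.{cls.__qualname__}"


# Stand-in objects so the snippet runs without Ray installed.
# In the script above one would pass the entries of `callbacks` instead.
for candidate in [OrderedDict(), OrderedDict]:
    print(qualified_name(candidate))
```

For the failing script, `qualified_name(callbacks[0])` would show whether the checkpoint callback was defined in ray.tune.integration.pytorch_lightning or in ray_lightning.tune.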

A similar (although not identical) error happens when trying to use a population-based scheduler.

First the working example, and below it the one that raises the error:

"""Simple example using RayAccelerator and Ray Tune"""
import os

import pytorch_lightning as pl
import ray
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule
from ray import tune
from ray.tune.schedulers.pb2 import PB2
from ray_lightning import RayStrategy
# TuneReportCallback and get_tune_resources come from ray_lightning.
from ray_lightning.tune import TuneReportCallback, get_tune_resources

ray.init(num_cpus=24)
data_dir = '/home/sliceruser/mnist'
#MNISTDataModule(data_dir=data_dir).prepare_data()
num_cpus_per_worker=6
test_l_dir = '/home/sliceruser/test_l_dir'

class netaA(nn.Module):
    def __init__(self,
        config
    ) -> None:
        super().__init__()
        layer_1, layer_2 = config["layer_1"], config["layer_2"]
        self.model = nn.Sequential(
        torch.nn.Linear(28 * 28, layer_1),
        torch.nn.Linear(layer_1, layer_2),    
        torch.nn.Linear(layer_2, 10)
        )
    def forward(self, x):
        return self.model(x)

class netB(nn.Module):
    def __init__(self,
        config
    ) -> None:
        super().__init__()
        self.model = nn.Sequential(
        torch.nn.Linear(10, 24),
        torch.nn.Linear(24, 100),    
        torch.nn.Linear(100, 10)
        )
    def forward(self, x):
        return self.model(x)

class LightningMNISTClassifier(pl.LightningModule):
    def __init__(self, config, data_dir=None):
        super(LightningMNISTClassifier, self).__init__()

        self.data_dir = data_dir or os.getcwd()
        self.lr = config["lr"]
        self.batch_size = config["batch_size"]

        self.accuracy = torchmetrics.Accuracy()
        self.netA= netaA(config)
        self.netB= netB(config)

    def forward(self, x):
        batch_size, channels, width, height = x.size()
        x = x.view(batch_size, -1)
        x= self.netA(x)
        x= self.netB(x)

        x = F.log_softmax(x, dim=1)
        return x

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        logits = self.forward(x)
        loss = F.nll_loss(logits, y.long())
        acc = self.accuracy(logits, y)
        self.log("ptl/train_loss", loss)
        self.log("ptl/train_accuracy", acc)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        logits = self.forward(x)
        loss = F.nll_loss(logits, y.long())
        acc = self.accuracy(logits, y)
        return {"val_loss": loss, "val_accuracy": acc}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean()
        self.log("ptl/val_loss", avg_loss)
        self.log("ptl/val_accuracy", avg_acc)



def train_mnist(config,
                data_dir=None,
                num_epochs=10,
                num_workers=1,
                use_gpu=True,
                callbacks=None):

    model = LightningMNISTClassifier(config, data_dir)




    callbacks = callbacks or []
    print(" aaaaaaaaaa  ")
    trainer = pl.Trainer(
        max_epochs=num_epochs,
        callbacks=callbacks,
        progress_bar_refresh_rate=0,
        strategy=RayStrategy(
            num_workers=num_workers, use_gpu=use_gpu)
            
            )#, init_hook=download_data

    dm = MNISTDataModule(
        data_dir=data_dir, num_workers=2, batch_size=config["batch_size"])
    trainer.fit(model, dm)


def tune_mnist(data_dir,
               num_samples=2,
               num_epochs=10,
               num_workers=2,
               use_gpu=True):
    config = {
        "layer_1": tune.choice([32, 64, 128]),
        "layer_2": tune.choice([64, 128, 256]),
        "lr": 1e-3,
        "batch_size": tune.choice([32, 64, 128]),
    }

    metrics = {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"}
   

    callbacks = [TuneReportCallback(metrics, on="validation_end")]
 
    #***********************************************
    # pb2_scheduler = PB2(
    #     time_attr="training_iteration",
    #     metric='acc',
    #     mode='max',
    #     perturbation_interval=10.0,
    #     hyperparam_bounds={
    #         "lr": [1e-2, 1e-5],
    #     })
 
 
    trainable = tune.with_parameters(
        train_mnist,
        data_dir=data_dir,
        num_epochs=num_epochs,
        num_workers=num_workers,
        use_gpu=use_gpu,
        callbacks=callbacks)
    analysis = tune.run(
        trainable,
        #scheduler=pb2_scheduler,
        metric="acc",
        mode="max",
        config=config,
        num_samples=num_samples,
        resources_per_trial=get_tune_resources(
            num_workers=num_workers, use_gpu=use_gpu),
        name="tune_mnist")
   #***********************************************

    print("Best hyperparameters found were: ", analysis.best_config)

tune_mnist(data_dir)

Example that raises the error:

"""Simple example using RayAccelerator and Ray Tune"""
import os

import pytorch_lightning as pl
import ray
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule
from ray import tune
from ray.tune.schedulers.pb2 import PB2
from ray_lightning import RayStrategy
# TuneReportCallback and get_tune_resources come from ray_lightning.
from ray_lightning.tune import TuneReportCallback, get_tune_resources

ray.init(num_cpus=24)
data_dir = '/home/sliceruser/mnist'
#MNISTDataModule(data_dir=data_dir).prepare_data()
num_cpus_per_worker=6
test_l_dir = '/home/sliceruser/test_l_dir'

class netaA(nn.Module):
    def __init__(self,
        config
    ) -> None:
        super().__init__()
        layer_1, layer_2 = config["layer_1"], config["layer_2"]
        self.model = nn.Sequential(
        torch.nn.Linear(28 * 28, layer_1),
        torch.nn.Linear(layer_1, layer_2),    
        torch.nn.Linear(layer_2, 10)
        )
    def forward(self, x):
        return self.model(x)

class netB(nn.Module):
    def __init__(self,
        config
    ) -> None:
        super().__init__()
        self.model = nn.Sequential(
        torch.nn.Linear(10, 24),
        torch.nn.Linear(24, 100),    
        torch.nn.Linear(100, 10)
        )
    def forward(self, x):
        return self.model(x)

class LightningMNISTClassifier(pl.LightningModule):
    def __init__(self, config, data_dir=None):
        super(LightningMNISTClassifier, self).__init__()

        self.data_dir = data_dir or os.getcwd()
        self.lr = config["lr"]
        self.batch_size = config["batch_size"]

        self.accuracy = torchmetrics.Accuracy()
        self.netA= netaA(config)
        self.netB= netB(config)

    def forward(self, x):
        batch_size, channels, width, height = x.size()
        x = x.view(batch_size, -1)
        x= self.netA(x)
        x= self.netB(x)

        x = F.log_softmax(x, dim=1)
        return x

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        logits = self.forward(x)
        loss = F.nll_loss(logits, y.long())
        acc = self.accuracy(logits, y)
        self.log("ptl/train_loss", loss)
        self.log("ptl/train_accuracy", acc)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        logits = self.forward(x)
        loss = F.nll_loss(logits, y.long())
        acc = self.accuracy(logits, y)
        return {"val_loss": loss, "val_accuracy": acc}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean()
        self.log("ptl/val_loss", avg_loss)
        self.log("ptl/val_accuracy", avg_acc)



def train_mnist(config,
                data_dir=None,
                num_epochs=10,
                num_workers=1,
                use_gpu=True,
                callbacks=None):

    model = LightningMNISTClassifier(config, data_dir)




    callbacks = callbacks or []
    print(" aaaaaaaaaa  ")
    trainer = pl.Trainer(
        max_epochs=num_epochs,
        callbacks=callbacks,
        progress_bar_refresh_rate=0,
        strategy=RayStrategy(
            num_workers=num_workers, use_gpu=use_gpu)
            
            )#, init_hook=download_data

    dm = MNISTDataModule(
        data_dir=data_dir, num_workers=2, batch_size=config["batch_size"])
    trainer.fit(model, dm)


def tune_mnist(data_dir,
               num_samples=2,
               num_epochs=10,
               num_workers=2,
               use_gpu=True):
    config = {
        "layer_1": tune.choice([32, 64, 128]),
        "layer_2": tune.choice([64, 128, 256]),
        "lr": 1e-3,
        "batch_size": tune.choice([32, 64, 128]),
    }

    metrics = {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"}
   

    callbacks = [TuneReportCallback(metrics, on="validation_end")]
 
    #***********************************************
    pb2_scheduler = PB2(
        time_attr="training_iteration",
        metric='acc',
        mode='max',
        perturbation_interval=10.0,
        hyperparam_bounds={
            "lr": [1e-2, 1e-5],
        })
 
 
    trainable = tune.with_parameters(
        train_mnist,
        data_dir=data_dir,
        num_epochs=num_epochs,
        num_workers=num_workers,
        use_gpu=use_gpu,
        callbacks=callbacks)
    analysis = tune.run(
        trainable,
        scheduler=pb2_scheduler,
        # metric="acc",
        # mode="max",
        config=config,
        num_samples=num_samples,
        resources_per_trial=get_tune_resources(
            num_workers=num_workers, use_gpu=use_gpu),
        name="tune_mnist")
   #***********************************************

    print("Best hyperparameters found were: ", analysis.best_config)

tune_mnist(data_dir)

full error
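
Two details of the PB2 setup above may be worth double-checking. The Ray Tune documentation describes each entry of `hyperparam_bounds` as `[min, max]`, whereas the script passes `[1e-2, 1e-5]`. Population-based schedulers also exploit good trials by restoring their checkpoints, whereas the script only uses TuneReportCallback, which reports metrics without saving a checkpoint. Whether either of these causes the error is not confirmed here. The sketch below is a hypothetical pre-flight check for the first point; `check_bounds` is not a Ray API.

```python
from typing import Dict, List, Sequence


def check_bounds(bounds: Dict[str, Sequence[float]]) -> List[str]:
    """Return a list of problems found in [min, max] style bounds."""
    problems = []
    for name, pair in bounds.items():
        if len(pair) != 2:
            problems.append(f"{name}: expected [min, max], got {list(pair)}")
            continue
        low, high = pair
        if low >= high:
            problems.append(f"{name}: min {low} is not below max {high}")
    return problems


# Bounds as written in the script above, then with the order swapped.
print(check_bounds({"lr": [1e-2, 1e-5]}))
print(check_bounds({"lr": [1e-5, 1e-2]}))
```

Running such a check before constructing the scheduler makes a reversed pair visible immediately instead of leaving it to surface somewhere inside the tuning loop.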