Why is my torch AOTInductor model inference class not serializable?

I’m trying to run the following model class with map_batches, but I get a pickling error: my model, which was exported as a C++ dynamic shared library, can’t be pickled.


2024-04-29 21:31:34,948	INFO worker.py:1749 -- Started a local Ray instance.
(36, 1024, 1024)
2024-04-29 19:31:23,827	INFO streaming_executor.py:112 -- Starting execution of Dataset. Full logs are in /tmp/ray/session_2024-04-29_19-27-04_177025_7517/logs/ray-data
2024-04-29 19:31:23,827	INFO streaming_executor.py:113 -- Execution plan of Dataset: InputDataBuffer[Input] -> LimitOperator[limit=20]
(12, 36, 1024, 1024)
(36, 1024, 1024)
2024-04-29 20:11:26,574	INFO streaming_executor.py:112 -- Starting execution of Dataset. Full logs are in /tmp/ray/session_2024-04-29_20-02-05_179416_32477/logs/ray-data
2024-04-29 20:11:26,574	INFO streaming_executor.py:113 -- Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[Map(deserialize_and_convert)] -> LimitOperator[limit=20]
(12, 36, 1024, 1024)
2024-04-29 21:34:31,514	INFO streaming_executor.py:112 -- Starting execution of Dataset. Full logs are in /tmp/ray/session_2024-04-29_21-31-33_294023_97768/logs/ray-data
2024-04-29 21:34:31,514	INFO streaming_executor.py:113 -- Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[Map(deserialize_and_convert)] -> ActorPoolMapOperator[MapBatches(MyPredictor)] -> LimitOperator[limit=20]
2024-04-29 21:34:31,536	ERROR exceptions.py:73 -- Exception occurred in Ray Data or Ray Core internal code. If you continue to see this error, please open an issue on the Ray project GitHub page with the full stack trace below: https://github.com/ray-project/ray/issues/new/choose
---------------------------------------------------------------------------
SystemException                           Traceback (most recent call last)
SystemException: 

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
/home/jovyan/work/notebooks/ray-dataloader.ipynb Cell 9 line 1
----> 1 pred = result.take_batch()

File /opt/conda/lib/python3.10/site-packages/ray/data/dataset.py:2320, in Dataset.take_batch(self, batch_size, batch_format)
   2317 limited_ds = self.limit(batch_size)
   2319 try:
-> 2320     res = next(
   2321         iter(
   2322             limited_ds.iter_batches(
   2323                 batch_size=batch_size,
   2324                 prefetch_batches=0,
   2325                 batch_format=batch_format,
   2326             )
   2327         )
   2328     )
   2329 except StopIteration:
   2330     raise ValueError("The dataset is empty.")

File /opt/conda/lib/python3.10/site-packages/ray/data/iterator.py:162, in DataIterator.iter_batches.<locals>._create_iterator()
    157 time_start = time.perf_counter()
    158 # Iterate through the dataset from the start each time
    159 # _iterator_gen is called.
    160 # This allows multiple iterations of the dataset without
    161 # needing to explicitly call `iter_batches()` multiple times.
--> 162 block_iterator, stats, blocks_owned_by_consumer = self._to_block_iterator()
    164 iterator = iter(
    165     iter_batches(
    166         block_iterator,
   (...)
    177     )
    178 )
    180 dataset_tag = self._get_dataset_tag()

File /opt/conda/lib/python3.10/site-packages/ray/data/_internal/iterator/iterator_impl.py:33, in DataIteratorImpl._to_block_iterator(self)
     25 def _to_block_iterator(
     26     self,
     27 ) -> Tuple[
   (...)
     30     bool,
     31 ]:
     32     ds = self._base_dataset
---> 33     block_iterator, stats, executor = ds._plan.execute_to_iterator()
     34     ds._current_executor = executor
     35     return block_iterator, stats, False

File /opt/conda/lib/python3.10/site-packages/ray/data/exceptions.py:86, in omit_traceback_stdout.<locals>.handle_trace(*args, **kwargs)
     84     raise e.with_traceback(None)
     85 else:
---> 86     raise e.with_traceback(None) from SystemException()

TypeError: Could not serialize the put value <ray.data._internal.execution.operators.map_transformer.MapTransformer object at 0x7ef1c0785fc0>:
================================================================================
Checking Serializability of <ray.data._internal.execution.operators.map_transformer.MapTransformer object at 0x7ef1c0785fc0>
================================================================================
!!! FAIL serialization: cannot pickle 'torch._C._aoti.AOTIModelContainerRunnerCuda' object
    Serializing '_init_fn' <function _parse_op_fn.<locals>.init_fn at 0x7ef1c4167b50>...
    !!! FAIL serialization: cannot pickle 'torch._C._aoti.AOTIModelContainerRunnerCuda' object
    Detected 1 global variables. Checking serializability...
        Serializing 'ray' <module 'ray' from '/opt/conda/lib/python3.10/site-packages/ray/__init__.py'>...
    Detected 3 nonlocal variables. Checking serializability...
        Serializing 'fn_constructor_args' ()...
        Serializing 'fn_constructor_kwargs' {}...
        Serializing 'op_fn' <class 'ray.data._internal.execution.util.make_callable_class_concurrent.<locals>._Wrapper'>...
        !!! FAIL serialization: cannot pickle 'torch._C._aoti.AOTIModelContainerRunnerCuda' object
            Serializing '__call__' <function make_callable_class_concurrent.<locals>._Wrapper.__call__ at 0x7ef1c4167490>...
            !!! FAIL serialization: cannot pickle 'torch._C._aoti.AOTIModelContainerRunnerCuda' object
    Serializing '_init_fn' <function _parse_op_fn.<locals>.init_fn at 0x7ef1c4167b50>...
    !!! FAIL serialization: cannot pickle 'torch._C._aoti.AOTIModelContainerRunnerCuda' object
    Detected 1 global variables. Checking serializability...
        Serializing 'ray' <module 'ray' from '/opt/conda/lib/python3.10/site-packages/ray/__init__.py'>...
    Detected 3 nonlocal variables. Checking serializability...
        Serializing 'fn_constructor_args' ()...
        Serializing 'fn_constructor_kwargs' {}...
        Serializing 'op_fn' <class 'ray.data._internal.execution.util.make_callable_class_concurrent.<locals>._Wrapper'>...
        !!! FAIL serialization: cannot pickle 'torch._C._aoti.AOTIModelContainerRunnerCuda' object
            Serializing '__call__' <function make_callable_class_concurrent.<locals>._Wrapper.__call__ at 0x7ef1c4167490>...
            !!! FAIL serialization: cannot pickle 'torch._C._aoti.AOTIModelContainerRunnerCuda' object
================================================================================
Variable: 

	FailTuple(__call__ [obj=<function make_callable_class_concurrent.<locals>._Wrapper.__call__ at 0x7ef1c4167490>, parent=<class 'ray.data._internal.execution.util.make_callable_class_concurrent.<locals>._Wrapper'>])

was found to be non-serializable. There may be multiple other undetected variables that were non-serializable. 
Consider either removing the instantiation/imports of these variables or moving the instantiation into the scope of the function/class. 
================================================================================
Check https://docs.ray.io/en/master/ray-core/objects/serialization.html#troubleshooting for more information.
If you have any suggestions on how to improve this error message, please reach out to the Ray developers on github.com/ray-project/ray/issues/
================================================================================

My model class is below:

import torch

class MyPredictor:
    def __init__(self):
        from .utils import load_s3_model
        self.model = load_s3_model("s3://mymodelhub/aot_inductor_gpu_tensor_cores.zip", "cuda")

    def __call__(self, batch):
        inputs = torch.from_numpy(batch).to("cuda").float() / 255
        with torch.no_grad():
            prediction = model(inputs)[0][0]
            if len(prediction.shape) == 4 and prediction.shape[1] > 1:
                output = torch.nn.functional.softmax(prediction, dim=1)
            elif len(prediction.shape) == 4 and prediction.shape[1] == 1:
                output = torch.nn.functional.sigmoid(prediction)
            else:
                raise ValueError(f"Model output shape {prediction.shape} not supported")
        return output.cpu().detach().numpy()

Why can’t the torch._C._aoti.AOTIModelContainerRunnerCuda object be serialized? For reference, I created this model with AOTInductor, which is new in PyTorch 2.2 (see "AOTInductor: Ahead-Of-Time Compilation for Torch.Export-ed Models" in the PyTorch main documentation). I can’t switch to a different model format: AOTInductor is what lets me deploy the model easily in a runtime I don’t have much control over.
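The TypeError in the trace has the same shape as the one Python raises for any object that wraps native state and has no pickle support. A minimal standard-library sketch that reproduces it, assuming threading.Lock as a stand-in for the AOTI runner (try_pickle and the dictionary keys are hypothetical names, not part of Ray or PyTorch):

```python
import pickle
import threading


def try_pickle(obj):
    """Return None if obj pickles cleanly, otherwise the TypeError message."""
    try:
        pickle.dumps(obj)
    except TypeError as exc:
        return str(exc)
    return None


# threading.Lock is only a stand-in for a native handle like the AOTI runner:
# neither object defines a way for pickle to rebuild its C-level state.
state_with_handle = {"model": threading.Lock()}
# Plain data (a hypothetical path string) pickles without any trouble.
state_with_path = {"model_path": "model.zip"}

print(try_pickle(state_with_handle))  # cannot pickle '_thread.lock' object
print(try_pickle(state_with_path))    # None
```

The usual way around this is to ship only plain data (paths, constructor arguments) and create the native object where it is used. With map_batches that means building the model inside the callable class's __init__, which Ray Data runs on each actor, passing any arguments through fn_constructor_args or fn_constructor_kwargs.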

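I was accidentally using a global model variable that had been instantiated outside my class. Changing model to self.model in MyPredictor.__call__ fixes the issue. Because __call__ referred to the global, that module-level AOTIModelContainerRunnerCuda object had to be serialized along with the class when Ray shipped MyPredictor to the actor pool (the trace shows the failure under '__call__'). With self.model, the model is only created in __init__, which runs on each actor, so nothing unpicklable has to be sent.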
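To catch this kind of slip before handing a class to map_batches, you can list the names a method loads from module scope. A minimal standard-library sketch using the dis module; global_names, BuggyPredictor and FixedPredictor are hypothetical names, and the identity lambda stands in for the loaded AOTInductor model. It only inspects one method's bytecode and is not how Ray or cloudpickle decide what to serialize.

```python
import dis


def global_names(func):
    """Return the set of names that func loads from module (global) scope."""
    return {
        ins.argval
        for ins in dis.get_instructions(func)
        if ins.opname == "LOAD_GLOBAL"
    }


class BuggyPredictor:
    def __init__(self):
        self.model = lambda x: x  # stand-in for the loaded AOTInductor model

    def __call__(self, batch):
        return model(batch)  # looks up a module-level `model`, not self.model


class FixedPredictor:
    def __init__(self):
        self.model = lambda x: x  # stand-in for the loaded AOTInductor model

    def __call__(self, batch):
        return self.model(batch)  # attribute lookup on the instance


print(global_names(BuggyPredictor.__call__))  # {'model'}
print(global_names(FixedPredictor.__call__))  # set()
```

BuggyPredictor can be defined without error because the global is only resolved when __call__ runs; that is why the mistake only surfaces once the class has to be serialized or called.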
