[Ray Train] Memory overloading rapidly while training TensorFlow model

How severe does this issue affect your experience of using Ray?

  • High: It blocks me to complete my task.

Hi, I have been experimenting with Ray for some time now, first with datasets, then for distributing the training of Scikit-Learn models, and now for distributing a TensorFlow model. However, memory grows rapidly while fitting this TF model. I am aware that deep-learning models do not distribute in the same manner as Scikit-Learn models and that the two should be treated as separate problems.

My problem is that I am trying to scale the training of a TensorFlow model built and compiled using Keras. Training works fine with a small dataset of ~30 observations, but when scaling to 1000 observations (which should not be that much), my cluster’s memory fills up and the training job crashes.

The dataset I am using consists of a features column containing tensors of length ~1,000,000 and multiple label columns, which are encoded using ray.data.preprocessors.LabelEncoder and then one-hot encoded into a tensor using a custom version of ray.data.preprocessors.OneHotEncoder.
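
As a rough check on scale, the raw size of the features column can be estimated from the numbers above. This is only a sketch: the element dtype is not stated in the post, so both float32 and float64 are shown as assumptions, and dense_bytes is a hypothetical helper name.

```python
def dense_bytes(n_rows: int, n_cols: int, bytes_per_value: int) -> int:
    """Bytes needed to hold a dense n_rows x n_cols array."""
    return n_rows * n_cols * bytes_per_value


GIB = 1024 ** 3
n_features = 1_000_000  # approximate feature length from the post

# The dtype is an assumption: 4 bytes for float32, 8 bytes for float64
for n_rows in (30, 1000):
    for dtype_name, width in (("float32", 4), ("float64", 8)):
        size_gib = dense_bytes(n_rows, n_features, width) / GIB
        print(f"{n_rows:>5} rows, {dtype_name}: {size_gib:.2f} GiB")
```

At 1000 observations the features alone come to about 4 GB in float32 or 8 GB in float64. Every additional copy held at the same time (preprocessed output, pandas batches, TensorFlow tensors) adds to that, so peak usage can be a multiple of the raw size.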

I have tried to reduce the memory usage of this training, with partial success. To do so, I used Streaming Ingest as well as splitting auxiliary datasets, both described in the documentation for configuring training datasets. I also used the Dummy Trainer to debug data ingest when I first distributed the training with Ray.

Here is the code for the trainer definition:
The fit_model() method is part of a class used to preprocess data and to build, train and predict with a TensorFlow model. The datasets used for training and validation are stored as Ray Dataset objects but are slightly different. Both contain a __value__ tensor column of features, but the training dataset also contains a labels tensor column. The dataset used for prediction only contains the __value__ tensor column of features.

import tensorflow as tf
from ray.air import session
from ray.air.integrations.keras import Callback
from ray.air.config import ScalingConfig, DatasetConfig
from ray.train.tensorflow import TensorflowTrainer, TensorflowCheckpoint

from ray.air.config import RunConfig

def fit_model(self, datasets):
    # Transform the datasets using a pre-fitted preprocessor
    # Chained Min-Max scaling, Label Encoding and One-Hot Encoding
    # Min-Max scaling of features and One-Hot Encoding of labels are custom classes to work on tensors directly
    for name, ds in datasets.items():
        ds = self._preprocessor.transform(ds)
        datasets[name] = ds

    # Training parameters
    self._train_params = {
        'batch_size': self.batch_size, # 64
        'epochs': self._training_epochs, # 10
        'size': self._nb_kmers, # Number of features : ~ 1 000 000
        'nb_cls':self._nb_classes, # Number of classes : ~ 48 000
        'model': self.classifier # Model name
    }

    # Define TF trainer
    self._trainer = TensorflowTrainer(
        train_loop_per_worker = train_func, # Training function defined lower
        train_loop_config = self._train_params, # Training parameters defined above
        scaling_config = ScalingConfig(
            trainer_resources={'CPU': 1}, # Default ray training resources https://docs.ray.io/en/latest/ray-air/package-ref.html#ray.air.config.ScalingConfig
            num_workers = self._n_workers, # 3
            use_gpu = self._use_gpu, # Default for testing is False
            resources_per_worker={'CPU': self._nb_CPU_per_worker} # 17
        ),
        dataset_config = {
            'train': DatasetConfig(
                fit = False, # Don't fit a preprocessor since none is passed and it is already fitted
                transform = False, # Don't transform the dataset since it is already transformed
                split = True, # Split the dataset across training workers
                use_stream_api = True # Use the stream API to use DatasetPipeline in training function
            ),
            'validation': DatasetConfig(
                fit = False, # Don't fit a preprocessor since this is the validation dataset
                transform = False, # Don't transform the dataset since it is already transformed
                split = True, # Split the dataset across training workers
                use_stream_api = True # Use the stream API to use DatasetPipeline in training function
            )
        },
        run_config = RunConfig(
            name = self.classifier, # Name of the model
            local_dir = self._workdir, # Path to a directory for spilling data
        ),
        datasets = datasets, # {train: Ray dataset, validation: Ray dataset}
    )

    training_result = self._trainer.fit() # Train the model
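
For a rough sense of how the model itself contributes to memory, the feature and class counts in the parameters above can be turned into weight counts. build_model is not shown, so the two architectures below are purely hypothetical, as is the hidden width of 128; only the ~1,000,000 features, ~48,000 classes and 3 workers come from the post. MultiWorkerMirroredStrategy keeps a full copy of the model variables on every worker, so the per-replica figure is multiplied by the number of workers.

```python
def dense_layer_params(n_inputs: int, n_units: int) -> int:
    """Weights plus biases of one fully connected layer."""
    return n_inputs * n_units + n_units


n_features = 1_000_000  # from the post (approximate)
n_classes = 48_000      # from the post (approximate)
n_workers = 3           # from the post
hidden_units = 128      # hypothetical, not from the post

# Hypothetical architecture 1: features -> hidden layer -> classes
with_hidden = (dense_layer_params(n_features, hidden_units)
               + dense_layer_params(hidden_units, n_classes))
# Hypothetical architecture 2: features -> classes directly
direct = dense_layer_params(n_features, n_classes)

for label, params in (("with hidden layer", with_hidden), ("direct", direct)):
    per_replica_gb = params * 4 / 1e9  # assumes float32 weights
    total_gb = per_replica_gb * n_workers
    print(f"{label}: {per_replica_gb:.1f} GB per replica, "
          f"{total_gb:.1f} GB across {n_workers} workers")
```

Under these assumptions, a layer that maps the features directly to the classes would need on the order of 192 GB of float32 weights per replica, before counting gradients or optimizer state (Adam, for example, keeps two extra values per weight). If the real model has any layer of comparable shape, the weights rather than the data could be what fills the cluster.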

The training function is defined outside of the class, as mentioned in this Ray discussion thread.

import pandas as pd

def train_func(config):
    # Get parameters from config dict
    epochs = config.get('epochs', 10)
    size = config.get('size')
    nb_cls = config.get('nb_cls')
    model = config.get('model')

    # Build/compile model in a distributed manner
    strategy = tf.distribute.MultiWorkerMirroredStrategy()
    with strategy.scope():
        model = build_model(model, nb_cls, size)

    # Get dataset shard equivalent to a window from DatasetPipeline passed to the trainer
    train_data = session.get_dataset_shard('train')
    val_data = session.get_dataset_shard('validation')

    # Function to convert a dataset shard into a tf dataset
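    def to_tf_dataset(data):
        # NOTE: the arguments were cut off in the original post. This reconstruction
        # assumes a pandas DataFrame batch with '__value__' and 'labels' columns.
        ds = tf.data.Dataset.from_tensors((
            tf.convert_to_tensor(list(data['__value__'])),
            tf.convert_to_tensor(list(data['labels']))
        ))
        return ds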

    results = []
    # Convert the validation dataset into a tf dataset
    # This must be executed over one epoch/batches because it is a DatasetPipeline shard
    batch_val = pd.DataFrame(columns=['__value__', 'labels'])
    for epoch in val_data.iter_epochs(1):
        for batch in epoch.iter_batches():
            batch_val = pd.concat([batch_val, batch])
    batch_val = to_tf_dataset(batch_val)

    # Fit the model using training DatasetPipeline shard
    for epoch_train in train_data.iter_epochs(epochs): # Iterate over epochs
        for batch_train in epoch_train.iter_batches(): # Iterate over batches
            batch_train = to_tf_dataset(batch_train) # Convert the batch into a tf dataset
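            history = model.fit(
                batch_train,
                validation_data=batch_val, # Assumed: needed for the val_* metrics reported below
                callbacks=[Callback()], # Default ray.air.integrations.keras.Callback
            )
            # Report metrics and checkpoint to the trainer
            # NOTE: the call itself was cut off in the original post; session.report
            # with a TensorflowCheckpoint is an assumed reconstruction.
            session.report({
                'accuracy': history.history['accuracy'][0],
                'loss': history.history['loss'][0],
                'val_accuracy': history.history['val_accuracy'][0],
                'val_loss': history.history['val_loss'][0],
            }, checkpoint=TensorflowCheckpoint.from_model(model))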

I am using a cluster of 64 CPU cores with 249G RAM.
My Python 3.8.10 environment is located in a Singularity container and uses Ray 2.2.0.

Is there a way to reduce the memory usage of this training function that I missed? Or should I put certain objects into shared memory using ray.put()?

@nicdemon thanks for submitting

Do you know if the memory issue is related to the node or the GPU memory? Have you tried reducing the number of trainers that are running in parallel, by setting num_workers to a smaller value than 3?
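
One way to narrow this down on the host side is to log each worker's peak resident memory from inside the training function, for example before and after every batch. The sketch below uses only the Python standard library; peak_rss_mib and log_memory are hypothetical helper names, the resource module is Unix-only, and it reports the peak rather than the current usage, so it shows when memory jumped but not whether it was released afterwards.

```python
import resource
import sys


def peak_rss_mib() -> float:
    """Peak resident set size of the current process, in MiB."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is reported in kilobytes on Linux and in bytes on macOS
    divisor = 1024 * 1024 if sys.platform == "darwin" else 1024
    return peak / divisor


def log_memory(tag: str) -> float:
    """Print and return the peak host memory of this process."""
    rss = peak_rss_mib()
    print(f"[{tag}] peak host RSS: {rss:.1f} MiB")
    return rss


before = log_memory("before batch")
buffer = bytearray(50 * 1024 * 1024)  # stand-in for loading one batch
after = log_memory("after batch")
```

Calling log_memory at the start of train_func, after building the model, and inside the batch loop should show whether the growth comes from model creation or from data ingest. It only covers host RAM, so GPU memory would need a separate check.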