Disk spilling causes an out-of-disk error in a job

Hi folks! I have been running into disk issues with spilling and can't seem to get around them. My use case is to take a bunch of image URLs/references, run image download + preprocessing, then model inference, and finally save the results to S3 at a custom location. I created 3 actors for preprocessing, inference, and saving output, and used map_batches to stitch them into a data pipeline.

Executing DAG InputDataBuffer[Input] → TaskPoolMapOperator[ReadText] → ActorPoolMapOperator[MapBatches(PreProcessActor)] → ActorPoolMapOperator[MapBatches(InferenceActor)] → ActorPoolMapOperator[MapBatches(SaveOutputActor)] → TaskPoolMapOperator[Write]

Preprocessing takes only 40-60 ms, inference takes < 300 ms, and saving takes < 120 ms. The input to preprocessing is very small per row (~120M image URLs in total), but the output tensor is big and stays big for the rest of the pipeline. I specify a read parallelism of 10000 in the data.read_xxx call to decrease the block size; otherwise the output from preprocessing leads to OOM. I see a lot of disk spilling happening, and ultimately the job runs into an out-of-disk-space error. I believe the backpressure in the data pipeline should have prevented the disk spilling, but it does not seem to be working. Also, I don't think I was hitting this disk spilling issue until today with the same code (I don't know if I picked up a new Ray change in a new cluster install). Could somebody suggest a fix? I have attached my code for reference.
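
For context, my understanding is that the streaming executor's memory budget can be capped so that backpressure kicks in before intermediate blocks start spilling. A minimal sketch of what I mean (assuming the Ray 2.x DataContext API exposes resource_limits; the 2 GB figure is just a placeholder, and my attached script below does not set this):

import ray

# Sketch: cap how much object store memory the Ray Data streaming executor may
# use, so upstream operators get backpressured instead of spilling blocks to disk.
# (Assumption: Ray 2.x; the 2 GB budget is only illustrative.)
ctx = ray.data.DataContext.get_current()
ctx.execution_options.resource_limits.object_store_memory = 2 * 1024 ** 3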

"""
Extract YellowRanger image sequence features. The supported batch size is 1.
The sequence length is 257 for an image. The embedding size is 256. The first
token is the [CLS] token.
"""

import io
import os
import base64
import time
import urllib
import hashlib

import argparse
import numpy as np
from PIL import Image, ImageFile
import torch
from torch import nn
from torch.backends import cudnn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms as T
import ray
from utils.copy_files_from_s3_to_all_nodes import process_all_s3_files_on_all_nodes, process_all_s3_files_list_at_path
from ray.data.extensions import TensorArray
from ray.data import ActorPoolStrategy
import errno
import shutil
from botocore.exceptions import ClientError
import boto3
from ray.autoscaler.sdk import request_resources
from ray.util.metrics import Counter, Gauge, Histogram
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
from pyarrow.fs import AwsDefaultS3RetryStrategy, S3FileSystem

ImageFile.LOAD_TRUNCATED_IMAGES = True

from transformers import AutoTokenizer
from torchvision.transforms import (
    Compose,
    Normalize,
    ToTensor,
    Resize,
    CenterCrop,
    InterpolationMode,
)

UNK_TOKEN = "[UNK]"
PAD_TOKEN = "[PAD]"
MASK_TOKEN = "[MASK]"
SEP_TOKEN = "[SEP]"
CLS_TOKEN = "[CLS]"

SPM_UNK_TOKEN = "<unk>"
SPM_PAD_TOKEN = "<pad>"

WPIECE_UNDERLINE = "##"
SPIECE_UNDERLINE = "▁"
default_image_size = 224
DEFAULT_IMAGE = Image.fromarray(np.zeros([default_image_size, default_image_size, 3], dtype=np.uint8)).convert("RGB")
clip_normalize = Compose(
    [Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])



# Let's explicitly not create a new cluster; connect to the existing one.
NUM_WORKER_NODES = 200
NUM_CUSTOM_CPU_VCPU_PER_NODE = 4
CUSTOM_CPU_VCPU = NUM_CUSTOM_CPU_VCPU_PER_NODE * NUM_WORKER_NODES
TOTAL_NUM_NODES = NUM_WORKER_NODES + 1

ray.init(address="auto")

"""
# This is not needed
# can specify cpu and gpu
request_resources(bundles=[{"custom_gpu_vcpu_inference": NUM_CUSTOM_CPU_VCPU_PER_NODE}] * NUM_WORKER_NODES)
# waiting for cluster nodes to come up.
while len(ray.nodes()) < TOTAL_NUM_NODES:
    print(f"waiting for nodes to start up: {len(ray.nodes())}/{TOTAL_NUM_NODES}")
    time.sleep(5)
"""


class PreProcessActor:
    def __init__(self):
        self.image_size=default_image_size
        self.transform =  Compose(
            [
                Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC),
                ToTensor(),
                clip_normalize,
            ]
        )

        # Metrics
        self.name = "preprocess"
        self._curr_count = 0

        self.counter = Counter(
            "num_requests",
            description="Number of requests processed by the actor.",
            tag_keys=("actor_name",),
        )
        self.counter.set_default_tags({"actor_name": self.name})


        self.gauge = Gauge(
            "curr_input_count",
            description="Current input rows count held by the actor. Goes up and down.",
            tag_keys=("actor_name",),
        )
        self.gauge.set_default_tags({"actor_name": self.name})


        self.histogram = Histogram(
            "request_latency",
            description="Latencies of requests in ms.",
            boundaries=[0.1, 10, 20, 40, 60, 80, 100, 200],
            tag_keys=("actor_name",),
        )
        self.histogram.set_default_tags({"actor_name": self.name})

    def process_image(self, img_path):
        """Read an image from path and transfrom it accordingly."""
        # img_path is a panda row
        isvalid = True
        try:
            img_url = 'https://m.media-amazon.com/images/I/%s' % (img_path.img_id.replace('.', '._AC_SL320_.'))
            img = Image.open(urllib.request.urlopen(img_url)).convert('RGB')
        except Exception:
            try:
                #print("trying to load the image from local path...")
                img = Image.open(img_path).convert('RGB')
                #print('Succeed in loading the image from local path.')
            except Exception:
                #print("failed loading image: ", img_url)
                img = Image.fromarray(np.zeros((default_image_size, default_image_size, 3), dtype=np.uint8))
                isvalid = False
        img = self.transform(img)
        img = img.unsqueeze(0)  # self.transform already returns a tensor; just add the batch dimension
        return img, isvalid

    def __call__(self, batch):
        #img_ids = batch['text'].to_list()
        #batch['text'].map(lambda x: self.read_image(x, self.transform))
        # text is the first column

        # Metrics
        start = time.time()
        num = batch.shape[0]
        self._curr_count += num
        # Increment the total request count.
        self.counter.inc()
        # Update the gauge to the new value.
        self.gauge.set(self._curr_count)



        batch.rename(columns={'text': 'img_id'}, inplace=True)
        batch[['img', 'isvalid']] = batch.apply(self.process_image, axis=1, result_type='expand')
        batch['img_id'] = batch['img_id'].map(lambda x: x.split('/')[-1])

        # Record the latency for this request in ms.
        self.histogram.observe(1000 * (time.time() - start))

        time.sleep(0.1)
        return batch



class InferenceActor:
    def __init__(self, bucket, prefix, download_path,  model_path):


        """
        self.model = ray.get(model_ref)
        self.model.eval()
        self.model.to("cuda")
        """
        # copy model locally
        refs = []
        node_id = ray.get_runtime_context().node_id
        scheduling_strategy = NodeAffinitySchedulingStrategy(
            node_id=node_id, soft=False
        )
        refs.append(
            process_all_s3_files_list_at_path.options(scheduling_strategy=scheduling_strategy).remote(bucket, prefix, download_path, parallelism=1)
        )
        ray.get(refs)
        print("model copied on node:", node_id)



        if torch.cuda.is_available():
            self.device = 'cuda'
        else:
            self.device = 'cpu'
        self.dummy_token_ids = [0, 2]
        self.DUMMY_TEXT_TOKENS = torch.tensor(self.dummy_token_ids, dtype=torch.int64, device="cpu").unsqueeze(0)
        self.DUMMY_TEXT_LEN = torch.tensor(len(self.dummy_token_ids), dtype=torch.int64, device="cpu").unsqueeze(0)
        self.model = torch.jit.load(model_path, map_location=self.device)

        # Metrics
        self.name = "inference"
        self._curr_count = 0

        self.counter = Counter(
            "num_requests",
            description="Number of requests processed by the actor.",
            tag_keys=("actor_name",),
        )
        self.counter.set_default_tags({"actor_name": self.name})


        self.gauge = Gauge(
            "curr_input_count",
            description="Current input count held by the actor. Goes up and down.",
            tag_keys=("actor_name",),
        )
        self.gauge.set_default_tags({"actor_name": self.name})


        self.histogram = Histogram(
            "request_latency",
            description="Latencies of requests in ms.",
            boundaries=[0.1, 50, 100, 150, 200, 250, 300, 350, 400, 450, 500],
            tag_keys=("actor_name",),
        )
        self.histogram.set_default_tags({"actor_name": self.name})

    def __call__(self, batch):
        # Metrics
        start = time.time()
        num = batch.shape[0]
        self._curr_count += num
        # Increment the total request count.
        self.counter.inc()
        # Update the gauge to the new value.
        self.gauge.set(self._curr_count)

        #print("Inference called")
        with torch.inference_mode():
            features_list = []
            inference_start = time.time()
            for row in batch.itertuples():
                data = torch.as_tensor(row.img, device=self.device)
                #print("input ", data.shape)
                output = self.model(self.DUMMY_TEXT_TOKENS.to(self.device), data, self.DUMMY_TEXT_LEN.to(self.device))
                output = output.cpu().numpy().squeeze()
                #print("outut ", output.shape)
                features_list.append(TensorArray(output))
            batch["features"] = features_list
            # Original preprocessed image tensors are no longer needed
            batch = batch.drop(columns=['img'])
            #print("time cost, inference: ", time.time() - inference_start)

            # Record the latency for this request in ms.
            self.histogram.observe(1000 * (time.time() - start))

            return batch






class SaveOutputActor:

    def assert_dir_exists(self, path):
        """
        Checks if the directory tree in path exists. If not, it creates it.
        :param path: the path to check if it exists
        """
        try:
            os.makedirs(path)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise

    def __init__(self, s3_bucket, s3_prefix, local_download_path):
        self.s3_client = boto3.client('s3')
        self.s3_bucket = s3_bucket
        self.s3_prefix = s3_prefix
        rtc = ray.get_runtime_context()
        actor_id = rtc.get_actor_id()
        self.local_download_path_images = os.path.join(local_download_path, actor_id + '-features')
        self.assert_dir_exists(self.local_download_path_images)

        # Metrics
        self.name = "saveoutput"
        self._curr_count = 0

        self.counter = Counter(
            "num_requests",
            description="Number of requests processed by the actor.",
            tag_keys=("actor_name",),
        )
        self.counter.set_default_tags({"actor_name": self.name})


        self.gauge = Gauge(
            "curr_input_count",
            description="Current input count held by the actor. Goes up and down.",
            tag_keys=("actor_name",),
        )
        self.gauge.set_default_tags({"actor_name": self.name})


        self.histogram = Histogram(
            "request_latency",
            description="Latencies of requests in ms.",
            boundaries=[0.1, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 120, 140, 160, 180, 200, 250, 300, 350, 400, 450, 500],
            tag_keys=("actor_name",),
        )
        self.histogram.set_default_tags({"actor_name": self.name})

    def calculate_checksum(self, string):
        # Create a new SHA-256 hash object
        sha256_hash = hashlib.sha256()

        # Convert the string to bytes and update the hash object
        sha256_hash.update(string.encode('utf-8'))

        # Get the hexadecimal representation of the checksum
        checksum = sha256_hash.hexdigest()

        return checksum

    def upload_file(self, file_name, bucket, object_name=None):
        """Upload a file to an S3 bucket

        :param file_name: File to upload
        :param bucket: Bucket to upload to
        :param object_name: S3 object name. If not specified then file_name is used
        :return: True if file was uploaded, else False
        """

        # If S3 object_name was not specified, use file_name
        if object_name is None:
            object_name = os.path.basename(file_name)

        # Upload the file
        try:
            response = self.s3_client.upload_file(file_name, bucket, object_name)
        except ClientError as e:
            print("Error : {}".format(str(e)))
            return False
        return True

    def __call__(self, batch):
        # Metrics
        start = time.time()
        num = batch.shape[0]
        self._curr_count += num
        # Increment the total request count.
        self.counter.inc()
        # Update the gauge to the new value.
        self.gauge.set(self._curr_count)

        self.assert_dir_exists(self.local_download_path_images)
        return_dict = {}
        img_id_list = []
        cs_id_list = []
        for row in batch.itertuples():
            # Calculate the checksum of the image id for saving
            try:
                if row.isvalid:
                    cs_id = self.calculate_checksum(row.img_id)  # Mind the batch size
                    file_name = cs_id + '.npy'
                    s3_folder = cs_id[0] + '/' + cs_id[1] + '/' + cs_id[2] + '/' + cs_id[3] + '/' + cs_id[4] + '/'
                    s3_folder = os.path.join(self.s3_prefix, s3_folder)
                    s3_file = os.path.join(s3_folder, file_name)
                    save_file = os.path.join(self.local_download_path_images, file_name)
                    with open(save_file, 'wb') as f:
                        np.save(f, row.features)
                    self.upload_file(save_file, self.s3_bucket, s3_file)
                    img_id_list.append(row.img_id)
                    cs_id_list.append(cs_id)
            except Exception as e:
                print("Caught Exception in writing S3 file {}: {}".format(s3_file, str(e)))
                # Just ignore excetions
                continue
                pass
        shutil.rmtree(self.local_download_path_images)
        return_dict["img_id"] = img_id_list
        return_dict["cs"] = cs_id_list

        # Record the latency for this request in ms.
        self.histogram.observe(1000 * (time.time() - start))

        return return_dict

import sys


if __name__ == "__main__":
    # argparser
    parser = argparse.ArgumentParser(description="pytorch feature extraction.")
    parser.add_argument("--s3_bucket", dest="s3_bucket", type=str, required=True, help="bucket name.")
    parser.add_argument("--file_path", dest="file_path", type=str, required=True, help="the file path of image list.")
    parser.add_argument("--batch_size", dest="batch_size", type=int, default=1, help="batch size.")
    parser.add_argument("--num_workers", dest="num_workers", type=int, default=32,
                        help="num of workers in pytorch dataloader.")
    parser.add_argument("--save_prefix", dest="save_prefix", required=True, help="prefix for saving features.")
    parser.add_argument("--model_path", dest="model_path", required=True, help="model path. in s3")
    args = parser.parse_args()

    start_time = time.time()

    # https://docs.ray.io/en/latest/data/data-internals.html?highlight=lazy#actor-locality-optimization-ml-inference-use-case
    #ctx = ray.data.DataContext.get_current()
    ##By default, this is set to True already.
    #ctx.execution_options.actor_locality_enabled = True

    # copy input to /data on all nodes <= not needed
    #process_all_s3_files_on_all_nodes(args.s3_bucket, args.file_path, "/data/input", if_gpu_only_nodes=False, parallelism=1)

    # copy model to /data on all gpu nodes
    #process_all_s3_files_on_all_nodes(args.s3_bucket, args.model_path, "/data/model", if_gpu_only_nodes=True, parallelism=1)

    # read the dataset
    # You can manually specify the number of read tasks, but the final parallelism is always capped by the number of files in the underlying dataset.
    #ds = ray.data.read_text("s3://" + args.s3_bucket + "/" + args.file_path, ray_remote_args={"resources": {"head_workers": 1}})
    ds = ray.data.read_text("s3://" + args.s3_bucket + "/" + args.file_path, parallelism=100000, ray_remote_args={"resources": {"custom_gpu_vcpu_inference": 1}})
    print(ds)
    # Because the cluster reports 104 CPUs, the number of blocks gets reduced to 208 despite the large number of files.
    #print("repartitioning to increase parallelism")
    #ds = ds.repartition(1000000)
    #print("done reading input:", ds)
    #print("schmea: ", ds.schema())
    #sys.exit(0)

    # If it contains only 1 block, repartition it into more blocks.
    #ds = ds.repartition(2000)
    #ds = ds.repartition(2)


    """
    The size of the batches provided to fn may be smaller than the provided batch_size if batch_size doesn't evenly divide the block(s) sent to a given map task. When batch_size is specified, each map task will be sent a single block if the block is equal to or larger than batch_size, and will be sent a bundle of blocks up to (but not exceeding) batch_size if blocks are smaller than batch_size.
    """
    # https://docs.ray.io/en/latest/ray-core/api/doc/ray.actor.ActorClass.options.html?highlight=max_concurrency#ray.actor.ActorClass.options
    # Let's keep things deterministic: use compute=ActorPoolStrategy(size=10) instead of compute=ActorPoolStrategy(min_size=10, max_size=200)
    ds = ds.map_batches(PreProcessActor, batch_size=2,  batch_format="pandas", compute=ActorPoolStrategy(size=200, max_tasks_in_flight_per_actor=10), num_cpus=0, resources={"custom_gpu_vcpu_inference": 1}, scheduling_strategy="SPREAD")

    ds = ds.map_batches(InferenceActor, batch_size=1, compute=ActorPoolStrategy(size=200, max_tasks_in_flight_per_actor=2), num_gpus=1,
                        batch_format="pandas", fn_constructor_kwargs={"bucket": args.s3_bucket, "prefix": args.model_path, "download_path": "/data/model", "model_path": "/data/model/xlzhu/yellowranger/model-binaries/yr_image_seq_ckpt.pt"}, max_concurrency=1, resources={"custom_gpu_vcpu_inference": 1}, scheduling_strategy="SPREAD")

    ds = ds.map_batches(SaveOutputActor, batch_size=1, compute=ActorPoolStrategy(size=200, max_tasks_in_flight_per_actor=2), num_cpus = 0,
                        batch_format="pandas", fn_constructor_kwargs={"s3_bucket": "a9vs-photon-us-east-1-alpha-training-exp-bucket", "s3_prefix": "xlzhu/redranger/features/image2", "local_download_path": "/data/img_features"}, max_concurrency=1, resources={"custom_gpu_vcpu_inference": 1}, scheduling_strategy="SPREAD")

    #ds.write_csv("s3://a9vs-photon-us-east-1-alpha-training-exp-bucket/xlzhu/redranger/features/mapping-img-to-file", filesystem=S3FileSystem(retry_strategy=AwsDefaultS3RetryStrategy(max_attempts=20)))
    ds.write_csv("s3://a9vs-photon-us-east-1-alpha-training-exp-bucket/xlzhu/redranger/features/mapping-img-to-file")
    #ds.materialize()

    # Watch for stage fusion https://docs.ray.io/en/latest/data/data-internals.html?highlight=lazy#stage-fusion-optimization
    print(ds.stats())
    print(ds)
    """
    # some random model stuff not tried
    model = resnet50(weights=ResNet50_Weights.DEFAULT)
    model_ref = ray.put(model)

    #model = torch.jit.load(args.model_path, map_location=device)
    model = torch.jit.load(args.model_path, map_location='cpu')
    model_ref = ray.put(model)
    #queryloader = val_loader(
    #    args.file_path, args.batch_size, args.num_workers)

    #run_inference(model_ref, queryloader, args.save_prefix, torch.cuda.is_available())
    run_inference(model_ref, args.save_prefix, True)
    """
    # This is not needed
    #request_resources(num_cpus=0)
    print("ALL DONE")