Object store spilling terabytes of data

How severely does this issue affect your experience of using Ray?

  • Medium: It contributes to significant difficulty in completing my task, but I can work around it.

In an attempt to speed up some RLlib algorithms, I’ve been experimenting with using Ray core to do some decompression work with Ray tasks. The relevant part of my code looks like this:

import queue
from queue import Queue
from threading import Thread
from typing import List

import ray
from ray._raylet import ObjectRef
from ray.rllib.algorithms.dqn.dqn import DQN
from ray.rllib.utils.typing import SampleBatchType

@ray.remote
def decompress(sample_batch: SampleBatchType) -> SampleBatchType:
    sample_batch.decompress_if_needed()
    return sample_batch

class FastDQN(DQN):
    def __init__(self, *args, **kwargs):
        DQN.__init__(self, *args, **kwargs)

        self.decompress_in_queue: Queue[SampleBatchType] = Queue()
        self.decompress_out_queue: Queue[SampleBatchType] = Queue()
        decompress_thread = Thread(
            target=self.run_decompress,
            daemon=True,
        )
        decompress_thread.start()

   
    def run_decompress(self):
        batch_refs: List[ObjectRef[SampleBatchType]] = []
        while True:
            while True:
                try:
                    compressed_batch = self.decompress_in_queue.get(
                        block=len(batch_refs) == 0
                    )
                    batch_refs.append(decompress.remote(compressed_batch))
                except queue.Empty:
                    break
            ready_refs, batch_refs = ray.wait(batch_refs, num_returns=1, timeout=0.1)
            for batch_ref in ready_refs:
                self.decompress_out_queue.put(ray.get(batch_ref))

Various sample batches are passed into decompress_in_queue; the decompress thread starts a decompress task for each, waits for tasks to finish with ray.wait, and puts the results into decompress_out_queue.

This works fine except that it constantly spills from the object store:

(raylet) Spilled 6168 MiB, 4 objects, write throughput 143 MiB/s. Set RAY_verbose_spill_logs=0 to disable this message.                     
(raylet) Spilled 15128 MiB, 6 objects, write throughput 342 MiB/s.                                                                          
(raylet) Spilled 15691 MiB, 7 objects, write throughput 347 MiB/s.                                                                          
(raylet) Spilled 16944 MiB, 58 objects, write throughput 361 MiB/s.                                                                         
(raylet) Spilled 34301 MiB, 270 objects, write throughput 517 MiB/s.
(raylet) Spilled 66594 MiB, 627 objects, write throughput 632 MiB/s.
(raylet) Spilled 133549 MiB, 1096 objects, write throughput 800 MiB/s.

Note that this is hundreds of gigabytes, and it eventually reaches multiple terabytes. I've looked in the spill directory under /tmp, and the spilled objects definitely seem to be the SampleBatch objects that decompress consumes and produces. The problem also goes away when I run the decompression without Ray, so the Ray tasks definitely seem to be the source of the problem.

Can anyone help me figure out why the code is spilling so much data from the object store? It seems like once the decompress task is done, it should be able to clean the SampleBatch object out of the object store. Are there any tools that could help determine why it isn’t?

Are you using Ray > 2.0?

Ray automatically reference counts all objects generated by .remote calls in the cluster (details: Memory Management — Ray 3.0.0.dev0). An object stays in Ray's shared-memory store until its ref goes out of scope in your Python program.
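For example, here is a minimal sketch of the scoping rule described above (the 100 MB payload is arbitrary):

import ray

ray.init()

# The object is pinned in the shared-memory store while `ref` is alive.
ref = ray.put(b"\0" * 100_000_000)
data = ray.get(ref)

# Once the last Python reference goes away, Ray's reference counting
# can free the object from the store.
del ref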

The spilling usually happens when you have more objects than the capacity of the shared memory (30% of node memory by default).
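If the default capacity is too small for your workload, it can also be set explicitly at startup; a minimal sketch (the 10 GB figure is an arbitrary example, not a recommendation):

import ray

# Size the object store explicitly instead of relying on the default
# fraction of node memory. The value is in bytes; 10 GB is arbitrary here.
ray.init(object_store_memory=10 * 1024**3)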

So when you see excessive spilling, it usually means some of Ray's object references are not GC'ed in time and stay in memory (leaked) longer than they need to. You can figure out where the leaks are coming from using the ray summary objects API:

  1. Start Ray and your Python script with the env var RAY_record_ref_creation_sites=1, e.g., RAY_record_ref_creation_sites=1 python <script.py>, and RAY_record_ref_creation_sites=1 ray start ... if you use the ray start API.
  2. Call ray summary objects (Ray State CLI — Ray 3.0.0.dev0). This will aggregate all references in the cluster and show you the callsites with the most leaked memory.

Example output

======== Object Summary: 2022-12-20 18:17:44.397197 ========
Stats:
------------------------------------
callsite_enabled: true
total_objects: 1
total_size_mb: 1.430511474609375e-05


Table (group by callsite)
------------------------------------
(task call)  
| a.py:<module>:7
    REF_TYPE_COUNTS     TASK_STATE_COUNTS    TOTAL_NUM_NODES    TOTAL_NUM_WORKERS    TOTAL_OBJECTS    TOTAL_SIZE_MB
--  ------------------  -------------------  -----------------  -------------------  ---------------  ---------------
0   LOCAL_REFERENCE: 1  FINISHED: 1          1                  1                    1                1.43051e-05

To see more details about the reference types, see Memory Management — Ray 3.0.0.dev0 (ray summary objects is a better aggregation API than ray memory).

I am using Ray 3.0.0.dev0. I have created a minimal example that reproduces the issue I am having:

import ray

ray.init()

@ray.remote
def process(data):
    return b"\0" * 100_000_000

while True:
    data = b"\0" * 100_000_000
    data = ray.get(process.remote(data))

The output of ray memory is

======== Object references status: 2022-12-29 15:34:37.079805 ========
Grouping by node address...        Sorting by object size...        Display all entries per group...


--- Summary for node address: 128.32.175.174 ---
Mem Used by Objects  Local References  Pinned        Used by task   Captured in Objects  Actor Handles
0.200000003 GB       1, (0.0 GB)       1, (0.1 GB)   1, (0.100000003 GB)  0, (0.0 GB)          0, (0.0 GB)  

--- Object references for node address: 128.32.175.174 ---
IP Address       PID    Type    Call Site               Status          Size    Reference Type      Object Ref                                              
128.32.175.174   88760  Driver  (task call)  | /home/c  SUBMITTED_TO    ?       LOCAL_REFERENCE     5e87988ffd705f70ffffffffffffffffffffffff0100000001000000
                                assidy/sh/Programs/Pyt  _WORKER                                                                                             
                                hon/rl-theory/test.py:                                                                                                      
                                <module>:13                                                                                                                 

128.32.175.174   88976  Worker  (deserialize task arg)  -               0.1 GB  PINNED_IN_MEMORY    00ffffffffffffffffffffffffffffffffffffff01000000e1020000
                                 test.process                                                                                                               

128.32.175.174   88760  Driver  (task call)  | /home/c  FINISHED        0.100000003 GB  USED_BY_PENDING_TASK  00ffffffffffffffffffffffffffffffffffffff01000000e1020000
                                assidy/sh/Programs/Pyt                                                                                                      
                                hon/rl-theory/test.py:                                                                                                      
                                <module>:13                                                                                                                 

To record callsite information for each ObjectRef created, set env variable RAY_record_ref_creation_sites=1

--- Aggregate object store stats across all nodes ---
Plasma memory usage 10013 MiB, 105 objects, 98.59% full, 90.14% needed
Spilled 744152 MiB, 7803 objects, avg write throughput 6676 MiB/s
Objects consumed by Ray tasks: 140476 MiB.

and the output of ray summary objects is


======== Object Summary: 2022-12-29 15:33:24.340710 ========
Stats:
------------------------------------
callsite_enabled: true
total_objects: 3
total_size_mb: 190.73486614227295


Table (group by callsite)
------------------------------------
(deserialize task arg) test.process
    REF_TYPE_COUNTS      TASK_STATE_COUNTS    TOTAL_NUM_NODES    TOTAL_NUM_WORKERS    TOTAL_OBJECTS    TOTAL_SIZE_MB
--  -------------------  -------------------  -----------------  -------------------  ---------------  ---------------
0   PINNED_IN_MEMORY: 1  '-': 1               1                  1                    1                95.3674



(task call)  
| /home/cassidy/sh/Programs/Python/rl-theory/test.py:<module>:13
    REF_TYPE_COUNTS          TASK_STATE_COUNTS       TOTAL_NUM_NODES    TOTAL_NUM_WORKERS    TOTAL_OBJECTS    TOTAL_SIZE_MB
--  -----------------------  ----------------------  -----------------  -------------------  ---------------  ---------------
0   LOCAL_REFERENCE: 1       FINISHED: 1             1                  1                    2                95.3674
    USED_BY_PENDING_TASK: 1  SUBMITTED_TO_WORKER: 1

From the first output, you can see that plasma memory usage is >10 GB, even though the program should only be creating 100 MB at a time. The memory usage keeps growing the longer the program runs and eventually triggers spilling.

This only happens if the process task both consumes and produces data. If it takes no arguments but returns data, or takes an argument but returns nothing, then plasma memory usage stays low as expected.
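For reference, here is a minimal sketch of those two non-leaking variants (the function names are mine, and it assumes the same 100 MB payloads and an already-initialized Ray as in the example above):

# Variant 1: no argument, returns data -- plasma usage stays low.
@ray.remote
def produce_only():
    return b"\0" * 100_000_000

# Variant 2: takes an argument, returns nothing -- plasma usage also stays low.
@ray.remote
def consume_only(data):
    return None

while True:
    ray.get(produce_only.remote())
    ray.get(consume_only.remote(b"\0" * 100_000_000))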

Do you know what’s going on here?

Actually, looking into this some more, it seems the problem is that I'm passing the data directly as an argument to the remote function instead of creating an ObjectRef first. If I change the above example to

import ray

ray.init()

@ray.remote
def process(data):
    return b"\0" * 100_000_000

while True:
    data = ray.put(b"\0" * 100_000_000)
    data = ray.get(process.remote(data))

then the problem seems to go away. It might be good to somehow describe this caveat in the documentation or throw an error instead of silently eating up all the system memory, though.

Ok, after a few more hours of debugging, I still have no sense of how memory management works in Ray. For instance, this does not result in a memory leak:

import ray

ray.init()

@ray.remote
def process(data):
    return b"\0" * 100_000_000

while True:
    data = ray.put(b"\0" * 100_000_000)
    ray.get(process.remote(data))

On the other hand, this does:

while True:
    data = ray.put(b"\0" * 100_000_000)
    ref = process.remote(data)
    ray.get(ref)

Here’s an even weirder example. This version does not result in a memory leak:

while True:
    data = ray.put(b"\0" * 100_000_000)
    ref = process.remote(data)
    ray.get(ref)
    del ref

However, adding a del data statement makes it start leaking memory:

while True:
    data = ray.put(b"\0" * 100_000_000)
    ref = process.remote(data)
    del data
    ray.get(ref)
    del ref

Can you explain what is going on in these cases?

I feel like I should just abandon trying to use Ray for this use case, since I can't avoid memory leaks caused by counterintuitive behavior like this.


Thanks for reporting this and the detailed reproductions. I filed a P0 issue here: [core] Object store memory leak in simple while loop · Issue #31421 · ray-project/ray · GitHub

It seems like an edge case in the reference counting. Interestingly, “ray memory” shows the reference counting layer correctly freeing the objects, but the objects are leaked at the object store layer nevertheless.


Update: the issue should be resolved in Ray 2.3 ([core] Fix a memory leak due to lineage counting by iycheng · Pull Request #31488 · ray-project/ray · GitHub). If you’d like, you can try out the nightly wheels to check that it fixes your issue.
