Ray/Plasma backed array

I’m trying to reduce the CPU RAM usage of a PyTorch dataset, namely fairseq’s TokenBlockDataset, by writing arrays to shared memory and then reading individual array entries from shared memory.
Without shared memory, this would be something like:

class InMemoryArray:
    def __init__(self, array):
        self.array = array
    def read_one(self, i):
        return self.array[i]

but then self.array would be stored in full by every GPU process (rank).

Is there a way to use Ray so that only one rank needs to write the array and all workers can read it? Or some other way to use Ray to manage memory better?

import ray

class RayBackedArray:
    def __init__(self, array, rank):
        if rank == 0:
            self.refs = [ray.put(x) for x in array]  # one object per entry
        else:
            self.refs = magically_gather_refs()  # placeholder: how do the other ranks get the refs?

    def read_one(self, i):
        return ray.get(self.refs[i])
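For comparison, here is a minimal sketch of the same one-writer, many-readers shape using Python’s standard-library multiprocessing.shared_memory instead of Ray (a swapped-in technique, not Ray’s API). Rank 0 writes the whole array once into a named block; every other process attaches by a name it can compute in advance, so nothing like magically_gather_refs is needed. The helper names, the int64 dtype, and the 8-byte length header are assumptions for illustration.

```python
from multiprocessing import shared_memory
import struct

# Assumption: entries are int64 token ids; the first 8 bytes store the entry count.
HEADER = struct.Struct("<q")
ITEM = struct.Struct("<q")


def write_array(name, values):
    """Rank 0 only (hypothetical helper): copy `values` into a new named block, once."""
    size = HEADER.size + ITEM.size * len(values)
    shm = shared_memory.SharedMemory(name=name, create=True, size=size)
    HEADER.pack_into(shm.buf, 0, len(values))
    for i, v in enumerate(values):
        ITEM.pack_into(shm.buf, HEADER.size + i * ITEM.size, v)
    return shm  # the writer keeps this alive, then calls close() and unlink() when done


class SharedArrayReader:
    """Any process on the same machine (hypothetical helper): attach by name, read entries."""

    def __init__(self, name):
        self._shm = shared_memory.SharedMemory(name=name)  # attaches; does not copy the data
        (self._n,) = HEADER.unpack_from(self._shm.buf, 0)

    def __len__(self):
        return self._n

    def read_one(self, i):
        if not 0 <= i < self._n:
            raise IndexError(i)
        # Decodes only the 8 bytes that hold entry i
        return ITEM.unpack_from(self._shm.buf, HEADER.size + i * ITEM.size)[0]

    def close(self):
        self._shm.close()
```

The block name could be derived from the data path so every rank knows it without a broadcast. Unlike Ray’s object store, this only shares memory between processes on one machine.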

For reference, in raw pyarrow.plasma you can invent a hashing scheme where clients know the object id without calling put, but the writes are (I think) slow and block each other. I have deleted some code for brevity, but I’m happy to share a working (slow) plasma example if that’s helpful!
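A minimal sketch of such a hashing scheme, assuming ids are derived from the data path, a caller-chosen array number, and the entry index. Every rank computes the same 20-byte id for entry i without any put or broadcast; the digest could then be wrapped in plasma.ObjectID, as the PlasmaView code does with its own 20-byte blake2b digest. entry_object_id is a hypothetical name.

```python
import hashlib


def entry_object_id(data_path: str, array_id: int, index: int) -> bytes:
    """Hypothetical helper: a deterministic 20-byte id for entry `index` of one array."""
    h = hashlib.blake2b(digest_size=20)  # 20 bytes, matching the PlasmaView digest size
    h.update(data_path.encode("utf-8"))
    h.update(b"\0")  # separator: paths cannot run into the fixed-width fields below
    h.update(array_id.to_bytes(8, "big"))
    h.update(index.to_bytes(8, "big"))  # per-entry component, so every index gets its own id
    return h.digest()
```

Because the path is followed by a separator and both integers are fixed-width, two different (path, array_id, index) triples never feed identical bytes into the hash.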

import hashlib

import pyarrow.plasma as plasma
from tqdm import tqdm


class PlasmaView:
    """New callers of this class should read https://tinyurl.com/25xx7j7y to avoid hash collisions."""

    def __init__(self, array, data_path: str, object_id: int):
        self.object_id = self.get_object_id(data_path, object_id)  # hash the data path and a number declared by the caller
        self._client = None  # Initialize lazily for pickle, (TODO: needed?)
        self.use_lock = True
        self.full_shape = array.shape
        self._partial_hash = None
        if array is None:
            return  # another rank will write to the plasma store
        if not self.client.contains(self.object_id):
            try:
                for i, x in tqdm(list(enumerate(array))):
                    oid = self._index_to_object_id(i)
                    self.client.put(x, oid)
            except plasma.PlasmaObjectExists:
                self.msg(f"PlasmaObjectExists {oid}")

    @property
    def partial_hash(self):
        # Hash of the array shape; copied and extended with the entry index for each object id
        if self._partial_hash is None:
            hash = hashlib.blake2b(b'0', digest_size=20)
            for dim in self.full_shape:
                hash.update(dim.to_bytes(4, byteorder='big'))
            self._partial_hash = hash
        return self._partial_hash

    @property
    def client(self):
        if self._client is None:
            self._client = plasma.connect('/tmp/plasma', num_retries=200)
        return self._client

    def read_one(self, index):
        oid = self._index_to_object_id(index)
        return self._get_with_lock(oid)

    def _index_to_object_id(self, index):
        hash = self.partial_hash.copy()
        hash.update(index.to_bytes(4, byteorder='big'))  # without this, every index maps to the same id
        return plasma.ObjectID(hash.digest())

    # get_object_id, _get_with_lock and msg are omitted for brevity.

Thanks in advance!

Hey @Sam_Shleifer great to see you again!

RE: the example you shared, it looks like the writes only happen upon init? Is your main problem with that example that the tqdm + client.put loop is too slow?

If I understand correctly, you have a numerical array. If you simply call ray/plasma.put(array), and on other workers you call ray/plasma.get(object_id), you will not duplicate the memory usage.

I have a question though; doesn’t TokenBlockDataset already leverage plasma to reduce memory usage?

If you simply call ray/plasma.put(array) , and on other workers you call ray/plasma.get(object_id) , you will not duplicate the memory usage.

If I call plasma.put(array, object_id), I only get one object_id, so I can’t randomly index into the array.

Great Q. The existing PlasmaArray used by TokenBlockDataset only writes to plasma when a DataLoader pickles the dataset. Each worker’s instance then reads back the full array whenever it is needed, so with DataLoader(num_workers=0) there are no savings. Pickling the np.array was also very expensive (or maybe there was more pickling than actual usage). Finally, the plasma store is only shared by workers on the same CUDA rank, not across CUDA ranks.
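To make the "only writes when pickled" behaviour concrete, here is a minimal sketch of that pattern. It is not fairseq’s PlasmaArray: ShareOnPickleArray is hypothetical, and a plain dict stands in for the plasma store (a real store would be visible to other processes). The point is that __getstate__ is the only place data is shared, so if nothing ever pickles the object, nothing is ever shared.

```python
import pickle

# Hypothetical stand-in for the plasma store: a dict in this process only.
FAKE_STORE = {}


class ShareOnPickleArray:
    """Keeps the data local; moves it to the store only when the object is pickled."""

    def __init__(self, data, key):
        self.data = data
        self.key = key

    def __getstate__(self):
        # Runs only when the object is pickled (e.g. sent to another process).
        # If that never happens, the data is never placed in the store.
        FAKE_STORE.setdefault(self.key, self.data)
        return {"key": self.key}  # serialize the key, not the data

    def __setstate__(self, state):
        self.key = state["key"]
        # A real shared store would hand back a shared, read-only view here.
        self.data = FAKE_STORE[self.key]
```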


If you do array_view = plasma.get(object_id) after the put, then you get a read-only view of the array (assuming it is a numpy array) where you could randomly index and would not use extra worker ram, right? (Maybe I’m missing something here).
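A small standard-library sketch of why indexing a read-only view need not copy the whole array: a memoryview shares the underlying buffer, so reading view[i] touches only that entry, and writes through the view are rejected. Here a bytearray stands in for the shared-memory segment (an assumption for illustration); per this thread, the np.array that plasma.get returns is likewise read-only.

```python
import struct


def readonly_int64_view(buffer):
    """Return a read-only int64 view over `buffer` without copying it."""
    return memoryview(buffer).cast("q").toreadonly()


if __name__ == "__main__":
    segment = bytearray(4 * 8)              # stand-in for a shared-memory segment
    view = readonly_int64_view(segment)
    struct.pack_into("q", segment, 8, 42)   # the writer updates entry 1 in place
    print(view[1])                          # prints 42: the view sees it, no copy was made
    try:
        view[0] = 7                         # readers cannot modify the data
    except TypeError as exc:
        print("read-only:", exc)
```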

plasma.get(object_id) in this example returns a full deserialized np.array, which makes me think that plasma.get(object_id)[0] uses lots of worker memory. Is that a bad assumption?

Hypothetically it shouldn’t use a lot of worker memory (since it’s just a read-only view backed by shared memory). Can you check top/htop to see if that’s the case?

Correction: the np.array returned is read-only. Working on benchmarking.

My benchmarking (durbango/benchmark_plasma_reads.py in the sshleifer/durbango repo on GitHub) suggests that:

  • client.put uses a lot of worker memory that I can’t manually garbage collect
  • the reading of the full array consumes 0 memory.

So I guess a “read-only view into shared memory” is some sick thing that I should use more of, and I don’t need my tqdm loop that writes each entry separately.

Do you know whether plasma can handle simultaneous reads? I get a segfault and a terrible traceback when too many workers read the same id from plasma. I can fix it with a lock, but that seems suboptimal.
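The lock workaround mentioned here can be sketched as a small wrapper that serializes every read across the processes sharing one lock. LockedGetter is hypothetical, and get_fn stands for whatever read call is crashing (for example a plasma get). It avoids concurrent reads rather than fixing the underlying crash, and it serializes all reads, which is why it costs throughput.

```python
import multiprocessing as mp


class LockedGetter:
    """Hypothetical helper: serialize calls to `get_fn` across processes sharing `lock`."""

    def __init__(self, get_fn, lock=None):
        self._get_fn = get_fn
        # Create this before starting the workers so that they all share the same lock.
        self._lock = lock if lock is not None else mp.Lock()

    def __call__(self, key):
        with self._lock:  # only one reader at a time
            return self._get_fn(key)
```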

Anyway, thanks for your help Richard, phenomenally useful. Are you on GitHub Sponsors? Alternatively, I can just owe you some retweets.

I get segfault+ terrible traceback when too many workers read the same id from plasma.

Oh uhh that seems bad. I’ve never seen this before - could you provide a traceback here?

No problem! Retweets are an acceptable currency (and probably the most important one for me!)


@Sam_Shleifer I’d love to hear what issue you were facing.


Here are two tracebacks (one for 1 GPU, another for 2 GPUs) that demonstrate various errors.

I might also be leaving too many plasma connections open.
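If too many connections are the problem, one common pattern is to cache a single connection per process and reconnect only when the process id changes, so a connection inherited through fork or pickling is never reused. A sketch, assuming connect is whatever zero-argument factory you already use (for example a function wrapping the plasma.connect('/tmp/plasma', ...) call in PlasmaView); PerProcessClient is a hypothetical name.

```python
import os


class PerProcessClient:
    """Hypothetical helper: open at most one connection per process and reuse it."""

    def __init__(self, connect):
        self._connect = connect  # zero-argument factory returning a connection
        self._client = None
        self._pid = None

    def get(self):
        pid = os.getpid()
        if self._client is None or self._pid != pid:
            # First use in this process, or this object came from another process:
            # open a fresh connection instead of reusing one that belongs elsewhere.
            self._client = self._connect()
            self._pid = pid
        return self._client

    def __getstate__(self):
        # Never pickle a live connection; the receiving process reconnects lazily.
        state = self.__dict__.copy()
        state["_client"] = None
        state["_pid"] = None
        return state
```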

Hmm, ok. BTW, Ray has moved off of the Arrow version of Plasma and now vendors its own, and many of these issues may be resolved there (@sangcho and @pcmoritz may have more context here).

You can reproduce the plasma issue without fairseq or CUDA:


  • Save this script as plasma_demo.py
  • Install dependencies: pip install pyarrow torch numpy
  • Run python plasma_demo.py --num-workers 2

Traceback here

@Sam_Shleifer could you try using the Ray version of Plasma instead?

I don’t understand the link you sent. For Ray’s plasma, do you just mean ray.put and ray.get, or something more direct?

Hmm, I meant try using the Ray-vendored Plasma (ray/services.py at commit 1d2136959ff8f8273dfbe5288bb58baa3ff21139 in ray-project/ray on GitHub) instead of the standard Plasma executable.