I’m trying to reduce the CPU RAM usage of a PyTorch dataset, namely fairseq’s `TokenBlockDataset`, by writing its arrays to shared memory and then reading individual array entries back from shared memory.
Without shared memory, this would look something like:
class InMemoryArray:
    """Wrap an array that lives entirely in the current process's memory.

    The wrapper only keeps a reference to ``array``. Nothing is shared between
    processes, so every worker that builds one holds its own full copy.
    """

    def __init__(self, array):
        self.array = array

    def read_one(self, i):
        """Return element ``i`` of the wrapped array."""
        return self.array[i]
but then `self.array` would be stored in full, in CPU RAM, by every GPU worker process.
Is there a way to use Ray so that only one rank actually needs to write the array and all workers can read it? Or is there some other way to use Ray to manage memory better? Something like:
class RayBackedArray:
    """Sketch: rank 0 puts the array into the Ray object store; all ranks read from it.

    ``magically_gather_refs`` is a placeholder. TODO: rank 0's refs still have
    to be delivered to the other ranks somehow.
    """

    def __init__(self, array, rank):
        if rank == 0:
            # NOTE(review): one ray.put per element creates len(array) separate
            # objects, and len(array) refs that must be shipped to every rank.
            # Putting the whole array once (a single ref) may be cheaper; confirm.
            self.refs = [ray.put(x) for x in array]
        else:
            self.refs = magically_gather_refs()

    def read_one(self, i):
        """Fetch element ``i`` from the Ray object store."""
        return ray.get(self.refs[i])
For reference, in raw `pyarrow.plasma` you can invent a hashing scheme so that clients know each object ID without calling `put`, but the writes are (I think) slow and block each other. I have deleted some code from the example below for brevity, but I’m happy to share the working (slow) plasma example if that’s helpful!
class PlasmaView:
    """Read-only view of an array whose elements are stored in a plasma store.

    Each element is stored under its own object id. The id is a blake2b hash of
    this view's base id, the array shape and the element index, so any process
    can compute the id of element ``i`` without asking the writer.

    New callers of this method should read https://tinyurl.com/25xx7j7y to avoid hash collisions
    """

    def __init__(self, array, data_path: str, object_id: int, shape=None):
        """Publish ``array`` to the plasma store, or attach to one another rank publishes.

        Args:
            array: array to write element-by-element, or ``None`` on ranks where
                another rank writes it.
            data_path: hashed with ``object_id`` (via ``get_object_id``) into the
                base id that namespaces every element id of this view.
            object_id: caller-chosen number distinguishing arrays that share
                ``data_path``.
            shape: shape of the array. It is required when ``array`` is ``None``,
                because per-element ids are derived from it; otherwise it is
                ignored in favour of ``array.shape``.

        Raises:
            ValueError: if both ``array`` and ``shape`` are ``None``.
        """
        self.object_id = self.get_object_id(data_path, object_id)  # hash the data path and a number declared by the caller
        # Connected lazily by the `client` property. Presumably this lets the
        # object be pickled before connecting (TODO: confirm). Note that a
        # writing instance already holds a live client once __init__ returns.
        self._client = None
        self.use_lock = True
        # Resolve the shape before the early return below: reader ranks pass
        # array=None, and they still need the shape to compute element ids.
        if array is not None:
            shape = array.shape
        elif shape is None:
            raise ValueError(
                "PlasmaView needs either `array` or `shape`: "
                "per-element object ids are derived from the array shape"
            )
        self.full_shape = tuple(shape)
        self._partial_hash = None
        if array is None:
            return  # another rank will write to the plasma store
        # NOTE(review): no visible code stores an object under self.object_id
        # itself (only per-element ids are written). This check therefore only
        # short-circuits if code outside this excerpt writes that id.
        if not self.client.contains(self.object_id):
            try:
                # Pass `total` instead of wrapping in list(): this avoids
                # materialising every (index, row) pair just to get a progress bar.
                for i, x in tqdm(enumerate(array), total=len(array)):
                    oid = self._index_to_object_id(i)
                    self.client.put(x, oid)
            except plasma.PlasmaObjectExists:
                # Presumably another writer already stored this element.
                # Writing stops at the first element that already exists.
                self.msg(f"PlasmaObjectExists {oid}")

    @property
    def partial_hash(self):
        """Cached blake2b state over the base id and shape; copied for each element id."""
        if self._partial_hash is None:
            # Seed with the base object id so that same-shaped arrays from
            # different (data_path, object_id) pairs get distinct element ids.
            # Assumes get_object_id returns a plasma.ObjectID (it is passed to
            # client.contains), which exposes binary().
            digest = hashlib.blake2b(self.object_id.binary(), digest_size=20)
            for dim in self.full_shape:
                # int() accepts numpy integer dims, which lack to_bytes().
                digest.update(int(dim).to_bytes(4, byteorder='big'))
            self._partial_hash = digest
        return self._partial_hash

    @property
    def client(self):
        """Plasma client for the store socket at /tmp/plasma, connected on first use."""
        if self._client is None:
            self._client = plasma.connect('/tmp/plasma', num_retries=200)
        return self._client

    def read_one(self, index):
        """Return element ``index``, fetched from the store via ``_get_with_lock``."""
        oid = self._index_to_object_id(index)
        return self._get_with_lock(oid)

    def _index_to_object_id(self, index):
        """Derive the plasma ObjectID of element ``index`` (20-byte blake2b digest)."""
        digest = self.partial_hash.copy()  # copy so the cached prefix state is not mutated
        digest.update(self.int_to_bytes(int(index)))
        return plasma.ObjectID(digest.digest())
Thanks in advance!