Pulling ObjectRefs from multiple ObjectRefGenerator - Nonblocking

How severe does this issue affect your experience of using Ray?

  • None: Just asking a question out of curiosity

Hi there, sorry, I’m still new to Ray and trying to learn. My question is about non-blocking retrieval of ObjectRefs from multiple ObjectRefGenerators.

I’ve attempted multiple things with the closest so far being:

generators = [worker.self_play.remote() for worker in sp_workers]
while True:
    tasks = [next(gen) for gen in generators]
    results = ray.get(tasks)
    for r in results:
        print(f"Counter: {r[0]} | {r[1]} - {r[2].shape}")

self_play is a bit of a dummy function at the moment: all it does is generate a number of flat tensors and stack them. The problem I’m having with the above is that it blocks until the slowest generator yields.

The scenario is this: each self-play actor vectorises an internal model to play multiple games at the same time. Because games end at different points, the plan is to yield completed games as they finish and let the other games continue playing (with the yielded games being reset). Thus we should have a never-ending stream of games that we’re attempting to feed into a replay buffer. The replay buffer can then be used by a trainer to update the model’s weights. To improve efficiency, we init multiple self-play workers…

Since each actor’s method yields its results, that means we have N live ObjectRefGenerators at a time. Ideally, when one yields, I would like to immediately pull that result and place it into the buffer.

I’ve looked over the patterns in the docs, including using generators to reduce heap memory usage, and I’ve also checked out asyncio. However, I’m equally new to asyncio, so there is a possibility I’ve read the answer and didn’t realise it.

Any guidance would be greatly appreciated.

Thank you.

Edit: Further experimentation

I’ve been working on this problem and am working through the example given in AsyncIO / Concurrency for Actors — Ray 2.38.0.

The new idea is to implement the replay buffer with the data generation as an async task, and then pull directly from the replay buffer.

However, the documentation above suggests wrapping an async call in a function:

# async ray.get
async def async_get():
    await actor.run_concurrent.remote()
asyncio.run(async_get())

but when I try running the code in Jupyter,

# regular ray.get
ray.get([actor.run_concurrent.remote() for _ in range(4)])

appears to run fine… what am I missing?
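My current understanding (in case it helps anyone else): plain `ray.get` is just a blocking call, so it works anywhere; the `asyncio.run` wrapper is only needed when you want concurrency in the driver, and the Jupyter wrinkle is that a notebook already runs its own event loop. A tiny Ray-free check:

```python
import asyncio

async def work():
    await asyncio.sleep(0)
    return "done"

# In a plain script there is no running event loop, so asyncio.run() is needed:
print(asyncio.run(work()))  # -> done

# Inside Jupyter/IPython an event loop is already running, so asyncio.run()
# raises "RuntimeError: asyncio.run() cannot be called from a running event
# loop"; there you can simply `await work()` at the top level of a cell.
```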

I think I may have cracked it with the following. If someone could provide feedback, that’d be very much appreciated.

import threading
from pathlib import Path
from typing import Dict, Tuple, Union

import numpy as np
import ray
import torch as t
from torch import Tensor

@ray.remote(num_gpus=0.25)
class Policy_Server:
    def __init__(self):
        self.device = t.device("cuda" if t.cuda.is_available() else "cpu")
        self.model = net().to(self.device)  # net is my model class, defined elsewhere
        self.model.eval()
        self.lock = threading.Lock()  # guards weight updates against concurrent inference

    def update_model_parameters(self, state_dict: Dict[str, Tensor]) -> str:
        with self.lock:
            self.model.load_state_dict(state_dict=state_dict)
            return "Model parameters updated..."

    async def inference(
        self, observation: Union[Tensor, np.ndarray]
    ) -> Tuple[np.ndarray, np.ndarray]:
        batch = t.tensor(observation).float().to(self.device)
        with t.no_grad():
            p, v = self.model(batch)
            return p.detach().cpu().numpy(), v.detach().cpu().numpy().flatten()

    def model_init(self, path: Path = Path("./Checkpoints/best_model.pth")):
        self.update_model_parameters(t.load(path, weights_only=True))

This is just a remote neural network that the self-play workers call for inference.

import time
from collections import deque

@ray.remote
class ReplayBuffer:
    def __init__(
        self,
        policy_server,
        capacity: int = 10000,
    ):
        self.buffer = deque(maxlen=capacity)
        self.size = capacity

        self.policy_server = policy_server
        self.running = True

    async def add(self, experience, index):
        self.buffer.extend(experience)
        print(f"index={index} | buffer size: {len(self.buffer)}", flush=True)

    async def self_play(self, index):
        """Self-play logic goes here; below is example code."""
        while self.running:
            data = []
            n_moves = np.random.randint(5, 20)
            for _ in range(n_moves):
                time.sleep(1)  # simulated per-move work; NB: a blocking sleep stalls the actor's event loop
                s = np.random.randint(0, 2, size=(111, 8, 8))
                # ObjectRefs are directly awaitable inside an async actor
                p, v = await self.policy_server.inference.remote(s)
                data.append(np.hstack((index, s.flatten(), p.flatten(), v.flatten())))

            result = np.vstack(data)
            print(f"Index={index} | shape: {result.shape}", flush=True)
            await self.add(result, index)

    def stop_play(self):
        self.running = False

    def start_play(self):
        self.running = True

    async def length(self) -> int:
        return len(self.buffer)

The above is a replay buffer class, with a self-play task that feeds the buffer.

capacity = 10000
ps = Policy_Server.remote()
replay_buffer = ReplayBuffer.options(max_concurrency=6).remote(
    policy_server=ps, capacity=capacity
)

async def async_get(index):
    await replay_buffer.self_play.remote(index)

async def main():
    tasks = [async_get(i) for i in range(5)]
    await asyncio.gather(*tasks)

await main()

This is giving me output that looks like:

2024-11-03 19:07:04,728	INFO worker.py:1777 -- Started a local Ray instance. View the dashboard at http://x.x.x.x:8265 
(ReplayBuffer pid=955374) Index=3 | shape: (9, 11778)
(ReplayBuffer pid=955374) index=3 | 9
(ReplayBuffer pid=955374) Index=0 | shape: (11, 11778)
(ReplayBuffer pid=955374) index=0 | 20
(ReplayBuffer pid=955374) Index=4 | shape: (12, 11778)
(ReplayBuffer pid=955374) index=4 | 32
(ReplayBuffer pid=955374) Index=2 | shape: (17, 11778)
(ReplayBuffer pid=955374) index=2 | 49
(ReplayBuffer pid=955374) Index=1 | shape: (18, 11778)
(ReplayBuffer pid=955374) index=1 | 67
(ReplayBuffer pid=955374) Index=4 | shape: (12, 11778)
(ReplayBuffer pid=955374) index=4 | 79
(ReplayBuffer pid=955374) Index=3 | shape: (16, 11778)
(ReplayBuffer pid=955374) index=3 | 95
(ReplayBuffer pid=955374) Index=0 | shape: (17, 11778)
(ReplayBuffer pid=955374) index=0 | 112
(ReplayBuffer pid=955374) Index=2 | shape: (11, 11778)

One thing I am noticing: it’s not the fastest. I would expect it to be quicker; 645 observations in 10 minutes doesn’t sound right.
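A hunch about the slowness (just a guess): blocking calls inside an async actor method, such as `time.sleep` or a long synchronous forward pass, stall the actor’s single event loop, so coroutines that should overlap end up running one after another. A Ray-free illustration of the effect:

```python
import asyncio
import time

async def blocking_worker():
    time.sleep(0.2)           # blocks the whole event loop

async def yielding_worker():
    await asyncio.sleep(0.2)  # yields control so other coroutines can run

async def timed(worker):
    """Run 5 copies of `worker` concurrently and return the wall time."""
    start = time.perf_counter()
    await asyncio.gather(*(worker() for _ in range(5)))
    return time.perf_counter() - start

blocked = asyncio.run(timed(blocking_worker))     # ~1.0 s (serialized)
overlapped = asyncio.run(timed(yielding_worker))  # ~0.2 s (concurrent)
print(f"{blocked:.2f}s vs {overlapped:.2f}s")
```

If that’s the cause, the dummy `time.sleep(1)` in `self_play` alone would cap throughput at roughly one observation per second across all workers combined.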