Synchronize multiple ray.remote functions in Python

I’m working on a project where multiple nodes (let’s say 4, arranged in a grid) are engaged in parallel computations across a time span from tau = 0 to 2. Each node reads its own and its neighbors’ initial values, performs computations, and updates a global state dictionary.

The challenge lies in the fact that a node needs to wait until all its neighbors have completed the previous round before moving on to the next one. Currently, I’ve implemented this using a while True loop, which does the job but feels a bit suboptimal.

import ray

# Identifiers of the simulated nodes.
nodes = ["node_1", "node_2", "node_3", "node_4"]

# For each node, the nodes whose previous-round values it reads.
# NOTE(review): the map is not symmetric -- node_2 lists node_3 and node_4
# lists node_2, but neither relation is reciprocated. Confirm this is intended.
neighbours = dict(
    node_1=["node_2", "node_3"],
    node_2=["node_1", "node_3"],
    node_3=["node_1", "node_4"],
    node_4=["node_3", "node_2"],
)

# Start Ray with two CPUs available for scheduling.
ray.init(num_cpus=2)


@ray.remote
class GlobalState:
    """Ray actor storing the ``z`` value of every node for every round.

    Values live in a single dict keyed by ``(tau, node)`` tuples.
    """

    def __init__(self, sim):
        """Seed the round-0 value of every node with ``0``.

        Args:
            sim: Iterable of node identifiers to initialise.
        """
        # Initial values for tau = 0. Built from the constructor argument
        # instead of the module-level ``nodes`` global, so the actor only
        # depends on what it was given.
        self.z = {(0, node): 0 for node in sim}

    def set_local_z(self, tau, node, new_val):
        """Store ``new_val`` as the value of ``node`` for round ``tau``."""
        self.z[(tau, node)] = new_val

    def get_local_z(self, tau, node):
        """Return the value of ``node`` for round ``tau``.

        Returns:
            The stored value, or ``None`` when nothing has been stored for
            ``(tau, node)`` yet.
        """
        return self.z.get((tau, node), None)

    def get_global_z(self):
        """Return the whole ``(tau, node) -> value`` mapping."""
        return self.z


def own_computation(neighbours_values):
    """Return the sum of all values in ``neighbours_values``, plus one.

    Args:
        neighbours_values: Mapping whose values are added together; the keys
            are not used.
    """
    # do some more computations here
    total = 0
    for value in neighbours_values.values():
        total = total + value
    return total + 1


@ray.remote
def task(global_state, node):
    """Run rounds ``tau = 1`` and ``tau = 2`` for ``node``.

    In each round the task waits until every node in ``neighbours[node]`` has
    a value for the previous round, feeds those values to
    ``own_computation`` and stores the result under ``(tau, node)``.

    Args:
        global_state: Actor handle exposing ``get_local_z``, ``set_local_z``
            and ``get_global_z``.
        node: Key into the module-level ``neighbours`` mapping.

    Returns:
        The object returned by ``global_state.get_global_z.remote()`` (the
        pending call itself, not the resolved dict).
    """
    for tau in range(1, 3):
        # get z values from neighbours from the previous iteration
        neighbours_values = {}
        for neigh in neighbours[node]:
            # Poll the actor until the neighbour has published its value for
            # the previous round; ``None`` means "not stored yet".
            while True:
                previous_value = ray.get(
                    global_state.get_local_z.remote(tau - 1, neigh)
                )
                if previous_value is not None:
                    break
            # Reuse the value the poll just returned instead of issuing a
            # second remote call for the same (tau - 1, neigh) key.
            neighbours_values[neigh] = previous_value

        # node's own computations
        new_val = own_computation(neighbours_values)

        # update own state; ray.get waits for the write to complete before
        # the next round starts
        ray.get(global_state.set_local_z.remote(tau, node, new_val))

    return global_state.get_global_z.remote()



def create_tasks(global_state, nodes):
    """Call ``task.remote(global_state, node)`` once per entry of ``nodes``.

    Args:
        global_state: Passed through unchanged as the first argument of
            every ``task.remote`` call.
        nodes: Iterable of node identifiers.

    Returns:
        A list with the result of each ``task.remote`` call, in the order of
        ``nodes``.
    """
    pending = []
    for name in nodes:
        pending.append(task.remote(global_state, name))
    return pending


# Create the shared state actor and start one task per node.
global_state = GlobalState.remote(nodes)
pending_tasks = create_tasks(global_state, nodes)

# Wait for every task to return before reading the final state.
state_results = ray.get(pending_tasks)

# Print the final state
final_state_ref = global_state.get_global_z.remote()
result = ray.get(final_state_ref)

for key, value in result.items():
    print(key, value)

ray.shutdown()

I’m curious if there’s a more efficient way to achieve the same result using Ray in Python 3.8.