I’m working on a project where multiple nodes (let’s say 4, arranged in a grid) are engaged in parallel computations across a time span from tau = 0 to 2. Each node reads its own and its neighbors’ initial values, performs computations, and updates a global state dictionary.
The challenge lies in the fact that a node needs to wait until all its neighbors have completed the previous round before moving on to the next one. Currently, I’ve implemented this using a while True loop, which does the job but feels a bit suboptimal.
import ray
ray.init(num_cpus=2)
nodes = ["node_1", "node_2", "node_3", "node_4"]
neighbours = {
"node_1": ["node_2", "node_3"],
"node_2": ["node_1", "node_3"],
"node_3": ["node_1", "node_4"],
"node_4": ["node_3", "node_2"]
}
@ray.remote
class GlobalState:
def __init__(self, sim):
self.z = {}
# initial values for tau = 0 and
for node in nodes:
self.z[(0, node)] = 0
def set_local_z(self, tau, node, new_val):
self.z[(tau, node)] = new_val
def get_local_z(self, tau, node):
return self.z.get((tau, node), None)
def get_global_z(self):
return self.z
def own_computation(neighbours_values):
# do some more computations here
return sum(neighbours_values.values()) + 1
@ray.remote
def task(global_state, node):
for tau in range(1, 3):
# get z values from neighbours from the previous iteration
neighbours_values = {}
for neigh in neighbours[node]:
while True:
previous_value = ray.get(global_state.get_local_z.remote(tau - 1, neigh))
if previous_value != None:
break
neighbours_values[neigh] = ray.get(global_state.get_local_z.remote(tau - 1, neigh))
# node's own computations
new_val = own_computation(neighbours_values)
# update own state
ray.get(global_state.set_local_z.remote(tau, node, new_val))
return global_state.get_global_z.remote()
def create_tasks(global_state, nodes):
tasks = [task.remote(global_state, node) for node in nodes]
return tasks
global_state = GlobalState.remote(nodes)
state_results = ray.get(create_tasks(global_state, nodes))
# Print the final state
result = ray.get(global_state.get_global_z.remote())
for i,e in result.items():
print(i,e)
ray.shutdown()
I’m curious if there’s a more efficient way to achieve the same result using Ray in Python 3.8.