I’ve tried a few things and run into a dead end.
Here’s my Slurm bash script:
#SBATCH <Slurm allocation options>
#...
# Get the node names
nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
nodes_array=($nodes)
head_node=${nodes_array[0]}
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
port=6379
ip_head=$head_node_ip:$port
export ip_head
# Run the server on the head node
srun --nodes=1 --ntasks=1 -w "$head_node" --output="slurm-%j-HEAD.out" \
    python3 -u ./cartpole_server.py --framework=torch --ip-head $ip_head &
# Run the clients on the other nodes
worker_num=$((SLURM_JOB_NUM_NODES - 1))
for ((i = 1; i <= worker_num; i++)); do
    node_i=${nodes_array[$i]}
    echo "Starting WORKER $i at $node_i"
    srun --nodes=1 --ntasks=1 -w "$node_i" --output="slurm-%j-$node_i.out" \
        python3 -u ./cartpole_client.py --ip-head $ip_head &
    sleep 5
done
As you can see, I’m attempting to run the server script on the head node and run the client script on the other nodes.
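To make the wiring concrete, with the settings above ip_head ends up as something like 192.168.128.10:6379 (that's the value showing up in the errors below), and both sides just pass that one string around:

# Example value of ip_head from my runs (head node IP plus the port I chose above):
ip_head = "192.168.128.10:6379"

# The server (cartpole_server.py below) splits it and binds the policy server
# to exactly this address and port:
server_address, server_port = ip_head.split(":")[0], int(ip_head.split(":")[1])

# The clients on the other nodes receive the very same string via --ip-head.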
Here is my server script:
import argparse
import os

import ray
from ray.rllib.agents.dqn import DQNTrainer
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.env.policy_server_input import PolicyServerInput
from ray.rllib.examples.custom_metrics_and_callbacks import MyCallbacks
from ray.tune.logger import pretty_print

parser = argparse.ArgumentParser()
parser.add_argument(
    "--run",
    type=str,
    default="DQN"
)
parser.add_argument(
    "--framework",
    choices=["tf", "torch"],
    default="tf",
    help="The DL framework specifier."
)
# I added this arg to give the server and client the ip address of the head node
parser.add_argument(
    '--ip-head',
    type=str,
    default='localhost:9900',
    help='The ip address and port of the remote server.'
)

if __name__ == "__main__":
    print("From the server")
    args = parser.parse_args()
    server_address = args.ip_head.split(':')[0]
    server_port = int(args.ip_head.split(':')[1])

    # --- Setup ray and policy server --- #
    ray.init()

    env = "CartPole-v0"
    connector_config = {
        # Use the connector server to generate experiences.
        "input": (
            lambda ioctx: PolicyServerInput(ioctx, server_address, server_port)
        ),
        # Use a single worker process to run the server.
        "num_workers": 0,
        # Disable OPE, since the rollouts are coming from online clients.
        "input_evaluation": [],
        "callbacks": MyCallbacks,
    }

    # Setup trainer and train.
    if args.run == "DQN":
        # Example of using DQN (supports off-policy actions).
        trainer = DQNTrainer(
            env=env,
            config=dict(
                connector_config, **{
                    "learning_starts": 100,
                    "timesteps_per_iteration": 200,
                    "framework": args.framework,
                }))
    elif args.run == "PPO":
        # Example of using PPO (does NOT support off-policy actions).
        trainer = PPOTrainer(
            env=env,
            config=dict(
                connector_config, **{
                    "rollout_fragment_length": 1000,
                    "train_batch_size": 4000,
                    "framework": args.framework,
                }))
    else:
        raise ValueError("--run must be DQN or PPO")

    # Serving and training loop.
    while True:
        print(pretty_print(trainer.train()))
My client script is at the bottom.
Below are some errors from this setup and other setups I’ve tried:
Trial and Error
With the above setup, I get this error:
At trainer = DQNTrainer(env=env,config=dict(connector_config, **{"learning_starts": 100,"timesteps_per_iteration": 200,"framework": args.framework}))
...
OSError: [Errno 98] Address already in use
I get the same error setting the server address to “localhost”.
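A quick sanity check that would confirm the port is already taken before the trainer is even built (this reuses the server_address / server_port variables from my server script above, and is only a probe, not a fix):

import socket

# Try to bind the exact address/port that PolicyServerInput will use.
# If this raises "[Errno 98] Address already in use", something on the head
# node is already listening there before the policy server starts.
probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
    probe.bind((server_address, server_port))
    print(f"{server_address}:{server_port} is free")
except OSError as e:
    print(f"{server_address}:{server_port} is taken: {e}")
finally:
    probe.close()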
If I remove ray.init(), I get:
At client = PolicyClient(args.ip_head, inference_mode=args.inference_mode)
...
requests.exceptions.InvalidSchema: No connection adapters were found for <ip:port>
If I do ray.init(address=<ip:port>), I get:
At ray.init(address=args.ip_head)
...
ConnectionRefusedError: [Errno 111] Connection refused
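As far as I understand (and I may be wrong about this), ray.init(address=...) only attaches to a Ray cluster that is already running at that address, while plain ray.init() starts a fresh local instance, so the refused connection would just mean nothing is listening at that ip:port yet:

import ray

# Plain init: starts a new local Ray instance on whichever node this runs on.
ray.init()

# Address form: my understanding is that this only *attaches* to a cluster that
# was already brought up (e.g. with `ray start --head` on the head node
# beforehand); if nothing is listening at ip:port, the connection is refused.
# ray.init(address="192.168.128.10:6379")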
If I try using ray.util.connect(<ip:port>) instead of ray.init(), I get:
At client = PolicyClient(args.ip_head, inference_mode=args.inference_mode)
...
requests.exceptions.InvalidSchema: No connection adapters were found for '192.168.128.10:6379'
I see these errors with both the TensorFlow and PyTorch frameworks.
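The requests error makes me suspect PolicyClient wants a full URL with a scheme rather than a bare ip:port (that's an assumption on my part, I haven't confirmed it's the intended format), i.e. something like this on the client side:

# In cartpole_client.py, prefix the address with a scheme instead of passing
# the bare "ip:port" string straight through:
client = PolicyClient("http://" + args.ip_head,
                      inference_mode=args.inference_mode)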
What I think is happening
It looks like the PolicyServer does not expect an instance of Ray to already be running at that address. I'm not sure how to work around this, because I need to call ray.init().
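If that's right, one workaround I can think of is to keep 6379 for Ray and give the policy server its own port, roughly like this (untested; 9900 is just the default from the example scripts, any free port should do):

# In cartpole_server.py: keep ray.init() as-is, but bind the PolicyServerInput
# to a port that Ray is not using, instead of reusing the 6379 from ip_head.
SERVER_PORT = 9900  # placeholder, must not collide with Ray's own ports

connector_config = {
    "input": (
        lambda ioctx: PolicyServerInput(ioctx, server_address, SERVER_PORT)
    ),
    "num_workers": 0,
    "input_evaluation": [],
    "callbacks": MyCallbacks,
}

The clients on the other nodes would then connect to http://<head node ip>:9900 rather than to the 6379 address.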
Any thoughts?
Client script:
"""Example of training with a policy server. Copy this file for your use case.
To try this out, in two separate shells run:
$ python cartpole_server.py --run=[PPO|DQN]
$ python cartpole_client.py --inference-mode=local|remote
Local inference mode offloads inference to the client for better performance.
"""
import argparse
import gym
from ray.rllib.env.policy_client import PolicyClient
parser = argparse.ArgumentParser()
parser.add_argument(
"--no-train", action="store_true", help="Whether to disable training.")
parser.add_argument(
"--inference-mode", type=str, default="local", choices=["local", "remote"])
parser.add_argument(
"--off-policy",
action="store_true",
help="Whether to take random instead of on-policy actions.")
parser.add_argument(
"--stop-reward",
type=int,
default=9999,
help="Stop once the specified reward is reached.")
parser.add_argument(
"--ip-head",
type=str,
default='localhost:9900',
help="The ip address and port to connect to on the server. This should match the ip_head " \
"given to the server node, and the port can be incremented if there are multiple " \
"workers listening on the server."
)
if __name__ == "__main__":
args = parser.parse_args()
env = gym.make("CartPole-v0")
client = PolicyClient(args.ip_head, inference_mode=args.inference_mode)
eid = client.start_episode(training_enabled=not args.no_train)
obs = env.reset()
rewards = 0
while True:
if args.off_policy:
action = env.action_space.sample()
client.log_action(eid, obs, action)
else:
action = client.get_action(eid, obs)
obs, reward, done, info = env.step(action)
rewards += reward
client.log_returns(eid, reward, info=info)
if done:
print("Total reward:", rewards)
if rewards >= args.stop_reward:
print("Target reward achieved, exiting")
exit(0)
rewards = 0
client.end_episode(eid, obs)
obs = env.reset()
eid = client.start_episode(training_enabled=not args.no_train)