Ray on Slurm with RLlib

Hello,

I’m trying to figure out how to combine the Ray Slurm tutorial with the RLlib CartPole client/server example. I would like to deploy training via Slurm using RLlib’s supported client-server architecture.

Here’s what I’ve done so far:

I’m using an HPC platform, and for testing I’m only requesting 3 nodes: 1 head and 2 workers (I imagine that this translates into 1 remote server and 2 clients). I’m attempting to modify the batch script so that:

  1. The head node sets up the policy server.
  2. The worker nodes set up policy clients and run training.

My approach may be completely off, so please let me know if it is. Any advice is greatly appreciated!

I’ve tried a few things and run into a dead end.

Here’s my Slurm batch script:

#!/bin/bash
#SBATCH <slurm allocation commands>
#...

# Get the node names
nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
nodes_array=($nodes)
head_node=${nodes_array[0]}
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
port=6379
ip_head=$head_node_ip:$port
export ip_head

# Run the server on the head node
srun --nodes=1 --ntasks=1 -w "$head_node" --output="slurm-%j-HEAD.out" \
  python3 -u ./cartpole_server.py --framework=torch --ip-head $ip_head &

# Run the clients on the other nodes
worker_num=$((SLURM_JOB_NUM_NODES - 1))

for ((i = 1; i <= worker_num; i++)); do
    node_i=${nodes_array[$i]}
    echo "Starting WORKER $i at $node_i"
    srun --nodes=1 --ntasks=1 -w "$node_i" --output="slurm-%j-$node_i.out" \
      python3 -u ./cartpole_client.py --ip-head $ip_head &
    sleep 5
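done

# Wait for the backgrounded job steps; otherwise the batch script exits and Slurm ends the job
wait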

As you can see, I’m attempting to run the server script on the head node and run the client script on the other nodes.

Here is my server script:

import argparse

import ray
from ray.rllib.agents.dqn import DQNTrainer
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.env.policy_server_input import PolicyServerInput
from ray.rllib.examples.custom_metrics_and_callbacks import MyCallbacks
from ray.tune.logger import pretty_print

parser = argparse.ArgumentParser()
parser.add_argument(
    "--run",
    type=str,
    default="DQN"
)
parser.add_argument(
    "--framework",
    choices=["tf", "torch"],
    default="tf",
    help="The DL framework specifier."
)
# I added this arg to give the server and client the ip address of the head node
parser.add_argument( 
    '--ip-head',
    type=str,
    default='localhost:9900',
    help='The ip address and port of the remote server.'
)

if __name__ == "__main__":
    print("From the server")
    args = parser.parse_args()
    server_address = args.ip_head.split(':')[0]
    server_port = int(args.ip_head.split(':')[1])

    # --- Setup ray and policy server --- #
    ray.init()
    env = "CartPole-v0"
    connector_config = {
        # Use the connector server to generate experiences.
        "input": (
            lambda ioctx: PolicyServerInput(ioctx, server_address, server_port)
        ),
        # Use a single worker process to run the server.
        "num_workers": 0,
        # Disable OPE, since the rollouts are coming from online clients.
        "input_evaluation": [],
        "callbacks": MyCallbacks,
    }

    # Setup trainer and train
    if args.run == "DQN":
        # Example of using DQN (supports off-policy actions).
        trainer = DQNTrainer(
            env=env,
            config=dict(
                connector_config, **{
                    "learning_starts": 100,
                    "timesteps_per_iteration": 200,
                    "framework": args.framework,
                }))
    elif args.run == "PPO":
        # Example of using PPO (does NOT support off-policy actions).
        trainer = PPOTrainer(
            env=env,
            config=dict(
                connector_config, **{
                    "rollout_fragment_length": 1000,
                    "train_batch_size": 4000,
                    "framework": args.framework,
                }))
    else:
        raise ValueError("--run must be DQN or PPO")

    # Serving and training loop.
    while True:
        print(pretty_print(trainer.train()))
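
In this script the PolicyServer’s address and port both come from --ip-head, which the Slurm script sets to the Ray head’s host:6379. A minimal sketch of keeping the two apart is to give the server its own port through a separate flag. The --server-port flag and the build_parser and server_endpoint helpers are hypothetical, and the 9900 default follows the port used by RLlib’s cartpole example. This is plain argparse and does not touch Ray.

```python
import argparse

RAY_DEFAULT_PORT = 6379  # port a Ray head node listens on by default


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ip-head", type=str, default="localhost:6379",
        help="Ray head address (host:port), as exported by the Slurm script.")
    # Hypothetical flag: the PolicyServer gets its own port instead of Ray's.
    parser.add_argument(
        "--server-port", type=int, default=9900,
        help="Port the PolicyServerInput listens on; must differ from Ray's port.")
    return parser


def server_endpoint(argv=None) -> tuple:
    """Return (host, port) for PolicyServerInput, rejecting Ray's own port."""
    args = build_parser().parse_args(argv)
    host, sep, ray_port = args.ip_head.rpartition(":")
    if not sep:  # no ":" in --ip-head, so assume Ray's default port
        host, ray_port = args.ip_head, str(RAY_DEFAULT_PORT)
    if args.server_port == int(ray_port):
        raise ValueError(f"--server-port {args.server_port} is already Ray's port")
    return host, args.server_port
```

The server would then call server_address, server_port = server_endpoint() and pass both values to PolicyServerInput in the input lambda as before, while the clients target the same host on the server port rather than on 6379.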

My client script is at the bottom.

Below are some errors from this setup and other setups I’ve tried:

Trial and Error

With the above setup, I get this error:

At trainer = DQNTrainer(env=env,config=dict(connector_config, **{"learning_starts": 100,"timesteps_per_iteration": 200,"framework": args.framework}))
...
OSError: [Errno 98] Address already in use
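
Errno 98 is EADDRINUSE on Linux: the operating system refuses a bind() to a port that another socket already holds. Here the PolicyServer is asked to listen on 6379, which is most likely already taken by the Ray instance that ray.init() started. The standard-library sketch below reproduces the same error without Ray by binding two sockets to one port; try_bind_twice is a hypothetical helper used only for illustration.

```python
import errno
import socket


def try_bind_twice(host: str = "127.0.0.1") -> int:
    """Bind a listening socket to a free port, then try to bind a second
    socket to the same port. Returns the errno of the second bind (0 if it
    unexpectedly succeeded)."""
    first = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    first.bind((host, 0))  # port 0: let the OS pick a free port
    first.listen()
    port = first.getsockname()[1]
    second = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        second.bind((host, port))  # same (host, port) as the listening socket
        return 0
    except OSError as exc:
        return exc.errno
    finally:
        second.close()
        first.close()


if __name__ == "__main__":
    code = try_bind_twice()
    print(code, errno.errorcode.get(code))  # EADDRINUSE on Linux
```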

I get the same error when setting the server address to “localhost”.

If I remove ray.init, I get

At client = PolicyClient(args.ip_head, inference_mode=args.inference_mode)
...
requests.exceptions.InvalidSchema: No connection adapters were found for <ip:port>

If I do ray.init(address=<ip:port>), I get

At ray.init(address=args.ip_head)
...
ConnectionRefusedError: [Errno 111] Connection refused

If I try using ray.util.connect(<ip:port>) instead of ray.init(), I get

At client = PolicyClient(args.ip_head, inference_mode=args.inference_mode)
...
requests.exceptions.InvalidSchema: No connection adapters were found for '192.168.128.10:6379'

I see these errors with both the TensorFlow and PyTorch frameworks.

What I think is happening

It looks like the PolicyServer does not expect an instance of Ray to already be running at that address. I’m not sure how to work around this, because I need to call ray.init().

Any thoughts?

Client script:

"""Example of training with a policy server. Copy this file for your use case.
To try this out, in two separate shells run:
    $ python cartpole_server.py --run=[PPO|DQN]
    $ python cartpole_client.py --inference-mode=local|remote
Local inference mode offloads inference to the client for better performance.
"""

import argparse
import gym

from ray.rllib.env.policy_client import PolicyClient

parser = argparse.ArgumentParser()
parser.add_argument(
    "--no-train", action="store_true", help="Whether to disable training.")
parser.add_argument(
    "--inference-mode", type=str, default="local", choices=["local", "remote"])
parser.add_argument(
    "--off-policy",
    action="store_true",
    help="Whether to take random instead of on-policy actions.")
parser.add_argument(
    "--stop-reward",
    type=int,
    default=9999,
    help="Stop once the specified reward is reached.")
parser.add_argument(
    "--ip-head",
    type=str,
    default='localhost:9900',
    help="The IP address and port to connect to on the server. This should match the ip_head "
        "given to the server node, and the port can be incremented if there are multiple "
        "workers listening on the server."
)

if __name__ == "__main__":
    args = parser.parse_args()
    env = gym.make("CartPole-v0")

    client = PolicyClient(args.ip_head, inference_mode=args.inference_mode)

    eid = client.start_episode(training_enabled=not args.no_train)
    obs = env.reset()
    rewards = 0

    while True:
        if args.off_policy:
            action = env.action_space.sample()
            client.log_action(eid, obs, action)
        else:
            action = client.get_action(eid, obs)
        obs, reward, done, info = env.step(action)
        rewards += reward
        client.log_returns(eid, reward, info=info)
        if done:
            print("Total reward:", rewards)
            if rewards >= args.stop_reward:
                print("Target reward achieved, exiting")
                exit(0)
            rewards = 0
            client.end_episode(eid, obs)
            obs = env.reset()
            eid = client.start_episode(training_enabled=not args.no_train)

Update: I just tried starting the PolicyServer on a different port, namely the ip_head port + 1. This allowed me to call ray.init() and create the PolicyServer. However, my client now gives an error: requests.exceptions.InvalidSchema: No connection adapters were found for '<ip:port>'. I tried both the node’s port and port + 1, with the same error.

Hey @rusu24edward, the PolicyClient only needs to be pointed at one of the ports the PolicyServer is listening on (NOT the Ray port 6379!!). By default, the PolicyServer created in your script listens on localhost:9900, as well as 9901, 9902, etc. in case you set num_workers=2, 3, etc. on the server. However, in your Slurm script you set this port to 6379, which is Ray’s port, and the server cannot listen on it (it’s already taken by Ray).
Let me know if I missed something, but I think that’s the root cause here: confusion between the Ray port and the client/server external env port(s).
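
The port layout described above can be sketched in plain Python. This assumes the offset scheme from RLlib’s cartpole_server.py example (the local worker serves on the base port when num_workers is 0, otherwise remote worker i serves on base port + i - 1). The constants and function names are hypothetical, and the code does not import Ray; it only computes which ports a PolicyClient may target and checks that none of them is Ray’s port.

```python
RAY_PORT = 6379          # Ray head's default port; PolicyClients must NOT use it
SERVER_BASE_PORT = 9900  # assumed base port for PolicyServerInput


def policy_server_port(worker_index: int, num_workers: int,
                       base_port: int = SERVER_BASE_PORT) -> int:
    """Port the PolicyServerInput on rollout worker `worker_index` listens on."""
    if num_workers == 0:
        # Only the local worker (index 0) exists, so it hosts the server.
        return base_port
    if worker_index < 1:
        raise ValueError("with num_workers > 0, only remote workers serve")
    # Remote workers 1..N listen on base_port, base_port + 1, ...
    return base_port + worker_index - 1


def policy_server_ports(num_workers: int,
                        base_port: int = SERVER_BASE_PORT) -> list:
    """All ports a PolicyClient may be pointed at, one per serving worker."""
    indices = [0] if num_workers == 0 else range(1, num_workers + 1)
    ports = [policy_server_port(i, num_workers, base_port) for i in indices]
    if RAY_PORT in ports:
        raise ValueError(f"policy server ports {ports} collide with Ray's port")
    return ports


if __name__ == "__main__":
    print(policy_server_ports(0))  # [9900]
    print(policy_server_ports(2))  # [9900, 9901]
```

With "num_workers": 0, as in the server scripts in this thread, this reduces to a single port, 9900, on the head node.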

Hi @sven1977, thanks for the response. I did try manually setting the port to a different value for both the server and the client, but I haven’t gotten that to work either. The Slurm script is the same as above, but the server script now looks like this:

if __name__ == '__main__':
    args = parser.parse_args()
    server_address = args.ip_head.split(':')[0]
    server_port = 9900
    ray.init()

    env = "CartPole-v0"
    connector_config = {
        # Use the connector server to generate experiences.
        "input": (
            lambda ioctx: PolicyServerInput(ioctx, server_address, server_port)
        ),
    ...

and the client script is:

if __name__ == "__main__":
    args = parser.parse_args()
    env = gym.make("CartPole-v0")

    address, port = args.ip_head.split(':')
    port = 9900
    ip_head = address + ":" + str(port)

    client = PolicyClient(ip_head, inference_mode=args.inference_mode)
    ...

The server sets up correctly, but the client gives me this error:

Traceback (most recent call last):
  File "./cartpole_client.py", line 51, in <module>
    client = PolicyClient(ip_head, inference_mode=args.inference_mode)
  File "/usr/WS1/rusu1/abmarl_scale_test/v_ray_test_tf/lib/python3.7/site-packages/ray/rllib/env/policy_client.py", line 65, in __init__
    self._setup_local_rollout_worker(update_interval)
  File "/usr/WS1/rusu1/abmarl_scale_test/v_ray_test_tf/lib/python3.7/site-packages/ray/rllib/env/policy_client.py", line 229, in _setup_local_rollout_worker
    "command": PolicyClient.GET_WORKER_ARGS,
  File "/usr/WS1/rusu1/abmarl_scale_test/v_ray_test_tf/lib/python3.7/site-packages/ray/rllib/env/policy_client.py", line 216, in _send
    response = requests.post(self.address, data=payload)
  File "/usr/tce/packages/python/python-3.7.2/lib/python3.7/site-packages/requests/api.py", line 116, in post
    return request('post', url, data=data, json=json, **kwargs)
  File "/usr/tce/packages/python/python-3.7.2/lib/python3.7/site-packages/requests/api.py", line 60, in request
    return session.request(method=method, url=url, **kwargs)
  File "/usr/tce/packages/python/python-3.7.2/lib/python3.7/site-packages/requests/sessions.py", line 533, in request
    resp = self.send(prep, **send_kwargs)
  File "/usr/tce/packages/python/python-3.7.2/lib/python3.7/site-packages/requests/sessions.py", line 640, in send
    adapter = self.get_adapter(url=request.url)
  File "/usr/tce/packages/python/python-3.7.2/lib/python3.7/site-packages/requests/sessions.py", line 731, in get_adapter
    raise InvalidSchema("No connection adapters were found for '%s'" % url)
requests.exceptions.InvalidSchema: No connection adapters were found for '<ip_address>:9900'

where <ip_address> is the actual IP address of the server node.

I also used a variation where I set the address to 'localhost'. Same error.

Any help is greatly appreciated!


I added http:// to the beginning of the address for the PolicyClient, and now it works. Thanks for your help!
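
The traceback above shows why the scheme matters: PolicyClient sends its requests with requests.post(self.address, ...), and requests cannot choose an HTTP adapter for a bare "host:port" string, so it raises InvalidSchema. A minimal sketch of normalizing the address before constructing the client follows; policy_server_url is a hypothetical helper, and it only prepends http:// when no scheme is present.

```python
def policy_server_url(address: str, default_scheme: str = "http") -> str:
    """Return `address` with a URL scheme so requests can pick an HTTP adapter.

    A bare "host:port" such as "192.168.128.10:9900" gets "http://" prepended;
    an address that already contains "://" is returned unchanged.
    """
    address = address.strip()
    if "://" in address:
        return address  # already has a scheme, e.g. "http://host:9900"
    return f"{default_scheme}://{address}"
```

In the client script this would be used as client = PolicyClient(policy_server_url(ip_head), inference_mode=args.inference_mode).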


Awesome @rusu24edward! Thanks for posting the solution here. 🙂