Client-server with custom env

I’ve got a Slurm client-server setup that works with “CartPole-v0”, but it does not work with a custom gym environment. I’m registering the environment on the server and the client and accessing it through the registered string. Has anyone had success with client-server and a custom gym environment?
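
For reference, the registration looks like this (it’s the same call you’ll see in the server script further down):

from ray.tune.registry import register_env
from sim.simple_corridor import SimpleCorridor

# Register the custom env under a string name so it can be referenced by that name.
register_env('SimpleCorridor', lambda env_config: SimpleCorridor())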

The client gives me this error (the ExternalEnv action queue on the client times out after waiting 360 seconds for an action):

2021-11-16 09:00:47,463	WARNING deprecation.py:34 -- DeprecationWarning: `SampleBatch['is_training']` has been deprecated. Use `SampleBatch.is_training` instead. This will raise an error in the future!
episode 1346675972 (env-idx=f7836c8da76843fc94599c404c200884) started.
Traceback (most recent call last):
  File "./corridor_client.py", line 69, in <module>
    action = client.get_action(eid, obs)
  File "/usr/WS1/rusu1/abmarl_scale_test/v_ray_test_tf/lib/python3.7/site-packages/ray/rllib/env/policy_client.py", line 121, in get_action
    return self.env.get_action(episode_id, observation)
  File "/usr/WS1/rusu1/abmarl_scale_test/v_ray_test_tf/lib/python3.7/site-packages/ray/rllib/env/external_env.py", line 128, in get_action
    return episode.wait_for_action(observation)
  File "/usr/WS1/rusu1/abmarl_scale_test/v_ray_test_tf/lib/python3.7/site-packages/ray/rllib/env/external_env.py", line 241, in wait_for_action
    return self.action_queue.get(True, timeout=360.0)
  File "/usr/tce/packages/python/python-3.7.2/lib/python3.7/queue.py", line 178, in get
    raise Empty
_queue.Empty

Here’s my bash script:

#!/bin/bash
#SBATCH ... # some sbatch options requesting 2 compute nodes, one for server and one for client


# Run with sbatch client_server.sh
source virtual_env/bin/activate

# Getting the node names
nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
nodes_array=($nodes)

head_node=${nodes_array[0]}
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)

# if we detect a space character in the head node IP, we'll
# convert it to an ipv4 address. This step is optional.
if [[ "$head_node_ip" == *" "* ]]; then
IFS=' ' read -ra ADDR <<<"$head_node_ip"
if [[ ${#ADDR[0]} -gt 16 ]]; then
  head_node_ip=${ADDR[1]}
else
  head_node_ip=${ADDR[0]}
fi
echo "IPV6 address detected. We split the IPV4 address as $head_node_ip"
fi

port=6379
ip_head=$head_node_ip:$port
export ip_head
echo "IP Head: $ip_head"

echo "Starting HEAD at $head_node"
srun --nodes=1 --ntasks=1 -w "$head_node" --output="slurm-%j-HEAD.out" \
  python3 -u ./corridor_server.py --framework=tf --ip-head $ip_head &

# Nodes take a long time to launch on my machine, so I have a 5 minute wait time.
sleep 300

# number of nodes other than the head node
echo "SLURM JOB NUM NODES " $SLURM_JOB_NUM_NODES
worker_num=$((SLURM_JOB_NUM_NODES - 1))

for ((i = 1; i <= worker_num; i++)); do
    node_i=${nodes_array[$i]}
    echo "Starting WORKER $i at $node_i"
    srun --nodes=1 --ntasks=1 -w "$node_i" --output="slurm-%j-$node_i.out" \
      python3 -u ./corridor_client.py --ip-head $ip_head &
    sleep 5
done

wait

Here’s my server script:


import argparse
import os

import ray
from ray.rllib.agents.dqn import DQNTrainer
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.env.policy_server_input import PolicyServerInput
from ray.rllib.examples.custom_metrics_and_callbacks import MyCallbacks
from ray.tune.logger import pretty_print
from ray.tune.registry import register_env

from sim.simple_corridor import SimpleCorridor # Custom environment and creator func
def env_creator(env_config):
    env = SimpleCorridor()
    return env  # return an env instance


parser = argparse.ArgumentParser()
parser.add_argument(
    "--run",
    type=str,
    default="DQN"
)
parser.add_argument(
    "--framework",
    choices=["tf", "torch"],
    default="tf",
    help="The DL framework specifier."
)
parser.add_argument(
    '--ip-head',
    type=str,
    default='localhost:9900',
    help='The ip address and port of the remote server.'
)

if __name__ == "__main__":
    print("From the server")
    register_env('SimpleCorridor', env_creator)
    args = parser.parse_args()

    server_address = args.ip_head.split(':')[0]
    server_port = 9900
    print(f'server: {server_address}:{server_port}')

    ray.init()
    connector_config = {
        # Use the connector server to generate experiences.
        "input": (
            lambda ioctx: PolicyServerInput(ioctx, server_address, server_port)
        ),
        # Use a single worker process to run the server.
        "num_workers": 0,
        # Disable OPE, since the rollouts are coming from online clients.
        "input_evaluation": [],
        "callbacks": MyCallbacks,
    }

    if args.run == "DQN":
        # Example of using DQN (supports off-policy actions).
        trainer = DQNTrainer(
            env='SimpleCorridor', # Using the registered name
            config=dict(
                connector_config, **{
                    "learning_starts": 100,
                    "timesteps_per_iteration": 200,
                    "framework": args.framework,
                }))
    elif args.run == "PPO":
        # Example of using PPO (does NOT support off-policy actions).
        trainer = PPOTrainer(
            env='SimpleCorridor', # Using the registered name
            config=dict(
                connector_config, **{
                    "rollout_fragment_length": 1000,
                    "train_batch_size": 4000,
                    "framework": args.framework,
                }))
    else:
        raise ValueError("--run must be DQN or PPO")
    
    print('All done')

    # Serving and training loop.
    while True:
        print(pretty_print(trainer.train()))

And here’s my client script:

import argparse
import gym
from gym.spaces import Discrete, Box
import numpy as np

from ray.rllib.env.policy_client import PolicyClient

from sim.simple_corridor import SimpleCorridor # Same env and creator func as the server
def env_creator():
    env = SimpleCorridor()
    return env  # return an env instance

parser = argparse.ArgumentParser()
parser.add_argument(
    "--no-train", action="store_true", help="Whether to disable training.")
parser.add_argument(
    "--inference-mode", type=str, default="local", choices=["local", "remote"])
parser.add_argument(
    "--off-policy",
    action="store_true",
    help="Whether to take random instead of on-policy actions.")
parser.add_argument(
    "--stop-reward",
    type=int,
    default=9999,
    help="Stop once the specified reward is reached.")
parser.add_argument(
    "--ip-head",
    type=str,
    default='localhost:9900',
    help="The ip address and port to connect to on the server. This should match the ip_head " \
        "given to the server node, and the port can be incremented if there are multiple " \
        "workers listening on the server."
)

if __name__ == "__main__":
    args = parser.parse_args()
    env = env_creator()

    address, port = args.ip_head.split(':')
    port = 9900
    ip_head = 'http://' + address + ":" + str(port)

    client = PolicyClient(ip_head, inference_mode=args.inference_mode)

    eid = client.start_episode(training_enabled=not args.no_train)
    obs = env.reset()
    rewards = 0

    while True:
        if args.off_policy:
            action = env.action_space.sample()
            client.log_action(eid, obs, action)
        else:
            action = client.get_action(eid, obs)
        obs, reward, done, info = env.step(action)
        rewards += reward
        client.log_returns(eid, reward, info=info)
        if done:
            print("Total reward:", rewards)
            if rewards >= args.stop_reward:
                print("Target reward achieved, exiting")
                exit(0)
            rewards = 0
            client.end_episode(eid, obs)
            obs = env.reset()
            eid = client.start_episode(training_enabled=not args.no_train)

Note: To ensure that the environment waits long enough, I changed the timeouts in rllib.env.ExternalEnv to 360 seconds.
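
The edited line (the one at external_env.py line 241 in the traceback above, inside wait_for_action) now reads as follows; only the timeout value was changed from the stock, much shorter default:

# ray/rllib/env/external_env.py, inside wait_for_action -- only the timeout was changed:
return self.action_queue.get(True, timeout=360.0)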

I’m using Ray 1.8.0.

I have a working setup; however, it has only been used on Windows machines. Moreover, I don’t use strings, but actual imports for the environments.

I personally ended up jury-rigging my own way of making it work, which looks pretty janky but works for my use case. If you want a reference, you can look at my stuff over at: GitHub - DenysAshikhin/Underlords at attentionnet
The code is inside the /code folder (policy_client.py and policy_server.py).

I ended up rewriting my client and server scripts to match the new examples in ray 1.8.0. After a bit of debugging, it all works now.

Hey @rusu24edward , thanks for posting back here that it’s working now for you.

Just to understand, would we need to change our client/server examples in RLlib to make these cases (custom gym env) work? The only difference I see is that you provide a registered env creator instead of a gym env descriptor (string). :thinking:

You also don’t need any env setup on the server side anymore; just provide your observation and action spaces in your config (config['observation_space'] = ..., config['action_space'] = ...) and RLlib automatically produces a RandomEnv under the hood and uses that (there is no need for the server to build actual env instances).
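
To make that concrete, here is a minimal sketch of what I mean; the Box/Discrete spaces below are just placeholders and would have to match whatever SimpleCorridor actually exposes on the client side:

from gym.spaces import Box, Discrete
from ray.rllib.agents.dqn import DQNTrainer
from ray.rllib.env.policy_server_input import PolicyServerInput

server_address, server_port = "localhost", 9900  # placeholders; the real script parses these from --ip-head

config = {
    "input": lambda ioctx: PolicyServerInput(ioctx, server_address, server_port),
    "num_workers": 0,
    "input_evaluation": [],
    # Spaces of the client-side env; the shapes/sizes here are hypothetical.
    "observation_space": Box(low=0.0, high=1.0, shape=(1,)),
    "action_space": Discrete(2),
    "framework": "tf",
}

# No env (registered or otherwise) is needed on the server anymore;
# RLlib fills in a RandomEnv with the given spaces.
trainer = DQNTrainer(env=None, config=config)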

I don’t think the examples need to change at all. After rewriting my scripts to match the new examples and running a couple of trials, everything worked out. Point taken about providing the spaces; that’s a nice feature and means I no longer need to use a registered environment.

I’m not sure I’ve seen an example showcasing how to do a multi-agent setup, but that would be nice to see, especially with regard to how to specify the observation and action spaces of each agent on the server side.
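
I imagine the per-agent spaces would end up in the usual multiagent policies dict, something along these lines (purely a guess on my part; the spaces and policy ids below are made up):

from gym.spaces import Box, Discrete

config = {
    # ... the PolicyServerInput config from above, plus:
    "multiagent": {
        # (policy_cls or None, obs_space, act_space, extra_config) per policy
        "policies": {
            "agent_0": (None, Box(low=0.0, high=1.0, shape=(1,)), Discrete(2), {}),
            "agent_1": (None, Box(low=0.0, high=1.0, shape=(1,)), Discrete(2), {}),
        },
        "policy_mapping_fn": lambda agent_id, episode=None, **kwargs: agent_id,
    },
}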

I did notice that when I ran the scripts locally, the connection was immediate. However, depending on the Slurm system, it can take a few minutes for a node to fully launch. I ran a bunch of trials on my setup, and I had to sleep the bash script for 3 minutes before launching the clients to ensure that the server had enough time to get up and running. It may be worth adding some notes in the documentation so users are aware of the need to delay.

I wonder if there’s a way to simply remove the connection timeout altogether. I don’t see the benefit of it: if the server never responds, the job fails whether the client times out or not, so we might as well let the client keep trying for the duration of the job. Thoughts?
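
In the meantime, a rough client-side workaround I’m considering is to just retry the initial connection in a loop instead of relying on a fixed sleep in the Slurm script. This is not a built-in RLlib option as far as I know, and whether PolicyClient raises at construction time or only at the first request may depend on the inference mode, so treat this as a sketch:

import time
from ray.rllib.env.policy_client import PolicyClient

ip_head = "http://localhost:9900"  # same address the client script builds above

client = None
while client is None:
    try:
        client = PolicyClient(ip_head, inference_mode="local")
    except Exception:
        # Server (or its node) is not up yet; wait and retry instead of giving up.
        print("Server not reachable yet, retrying in 30 seconds...")
        time.sleep(30)

This keeps retrying for the duration of the Slurm job, which is what the fixed sleep in the bash script is really approximating.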