I’ve got a Slurm client/server setup that works with “CartPole-v0”, but not with a custom gym environment. I register the environment on both the server and the client and access it through the registered string. Has anyone had success running client/server with a custom gym environment?
The client gives me this error:
2021-11-16 09:00:47,463 WARNING deprecation.py:34 -- DeprecationWarning: `SampleBatch['is_training']` has been deprecated. Use `SampleBatch.is_training` instead. This will raise an error in the future!
episode 1346675972 (env-idx=f7836c8da76843fc94599c404c200884) started.
Traceback (most recent call last):
File "./corridor_client.py", line 69, in <module>
action = client.get_action(eid, obs)
File "/usr/WS1/rusu1/abmarl_scale_test/v_ray_test_tf/lib/python3.7/site-packages/ray/rllib/env/policy_client.py", line 121, in get_action
return self.env.get_action(episode_id, observation)
File "/usr/WS1/rusu1/abmarl_scale_test/v_ray_test_tf/lib/python3.7/site-packages/ray/rllib/env/external_env.py", line 128, in get_action
return episode.wait_for_action(observation)
File "/usr/WS1/rusu1/abmarl_scale_test/v_ray_test_tf/lib/python3.7/site-packages/ray/rllib/env/external_env.py", line 241, in wait_for_action
return self.action_queue.get(True, timeout=360.0)
File "/usr/tce/packages/python/python-3.7.2/lib/python3.7/queue.py", line 178, in get
raise Empty
_queue.Empty
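For reference, sim.simple_corridor holds the custom environment. A minimal sketch of what it looks like, assuming it mirrors the SimpleCorridor example from the RLlib docs (the exact corridor length and reward shaping in my class may differ):

# sim/simple_corridor.py -- minimal sketch, assuming it mirrors RLlib's
# SimpleCorridor example: the agent starts at position 0 and must walk
# right to reach the end of the corridor.
import gym
from gym.spaces import Box, Discrete
import numpy as np

class SimpleCorridor(gym.Env):
    def __init__(self, config=None):
        config = config or {}
        self.end_pos = config.get("corridor_length", 10)
        self.cur_pos = 0
        self.action_space = Discrete(2)  # 0 = walk left, 1 = walk right
        self.observation_space = Box(
            0.0, float(self.end_pos), shape=(1,), dtype=np.float32)

    def reset(self):
        self.cur_pos = 0
        return np.array([self.cur_pos], dtype=np.float32)

    def step(self, action):
        if action == 0 and self.cur_pos > 0:
            self.cur_pos -= 1
        elif action == 1:
            self.cur_pos += 1
        done = self.cur_pos >= self.end_pos
        reward = 1.0 if done else -0.1
        return np.array([self.cur_pos], dtype=np.float32), reward, done, {}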
Here’s my bash script:
#!/bin/bash
#SBATCH ... # some sbatch options requesting 2 compute nodes, one for server and one for client
# Run with sbatch client_server.sh
source virtual_env/bin/activate
# Getting the node names
nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
nodes_array=($nodes)
head_node=${nodes_array[0]}
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
# if we detect a space character in the head node IP, we'll
# convert it to an ipv4 address. This step is optional.
if [[ "$head_node_ip" == *" "* ]]; then
IFS=' ' read -ra ADDR <<<"$head_node_ip"
if [[ ${#ADDR[0]} -gt 16 ]]; then
head_node_ip=${ADDR[1]}
else
head_node_ip=${ADDR[0]}
fi
echo "IPV6 address detected. We split the IPV4 address as $head_node_ip"
fi
port=6379
ip_head=$head_node_ip:$port
export ip_head
echo "IP Head: $ip_head"
echo "Starting HEAD at $head_node"
srun --nodes=1 --ntasks=1 -w "$head_node" --output="slurm-%j-HEAD.out" \
    python3 -u ./corridor_server.py --framework=tf --ip-head $ip_head &
# Nodes take a long time to launch on my machine, so I have a 5-minute wait.
sleep 300
# number of nodes other than the head node
echo "SLURM JOB NUM NODES " $SLURM_JOB_NUM_NODES
worker_num=$((SLURM_JOB_NUM_NODES - 1))
for ((i = 1; i <= worker_num; i++)); do
    node_i=${nodes_array[$i]}
    echo "Starting WORKER $i at $node_i"
    srun --nodes=1 --ntasks=1 -w "$node_i" --output="slurm-%j-$node_i.out" \
        python3 -u ./corridor_client.py --ip-head $ip_head &
    sleep 5
done
wait
Here’s my server script:
import argparse
import os
import ray
from ray.rllib.agents.dqn import DQNTrainer
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.env.policy_server_input import PolicyServerInput
from ray.rllib.examples.custom_metrics_and_callbacks import MyCallbacks
from ray.tune.logger import pretty_print
from ray.tune.registry import register_env
from sim.simple_corridor import SimpleCorridor # Custom environment and creator func
def env_creator(env_config):
    env = SimpleCorridor()
    return env  # return an env instance
parser = argparse.ArgumentParser()
parser.add_argument(
    "--run",
    type=str,
    default="DQN"
)
parser.add_argument(
    "--framework",
    choices=["tf", "torch"],
    default="tf",
    help="The DL framework specifier."
)
parser.add_argument(
    '--ip-head',
    type=str,
    default='localhost:9900',
    help='The ip address and port of the remote server.'
)
if __name__ == "__main__":
    print("From the server")
    register_env('SimpleCorridor', env_creator)
    args = parser.parse_args()
    server_address = args.ip_head.split(':')[0]
    server_port = 9900
    print(f'server: {server_address}:{server_port}')
    ray.init()
    connector_config = {
        # Use the connector server to generate experiences.
        "input": (
            lambda ioctx: PolicyServerInput(ioctx, server_address, server_port)
        ),
        # Use a single worker process to run the server.
        "num_workers": 0,
        # Disable OPE, since the rollouts are coming from online clients.
        "input_evaluation": [],
        "callbacks": MyCallbacks,
    }
    if args.run == "DQN":
        # Example of using DQN (supports off-policy actions).
        trainer = DQNTrainer(
            env='SimpleCorridor',  # Using the registered name
            config=dict(
                connector_config, **{
                    "learning_starts": 100,
                    "timesteps_per_iteration": 200,
                    "framework": args.framework,
                }))
    elif args.run == "PPO":
        # Example of using PPO (does NOT support off-policy actions).
        trainer = PPOTrainer(
            env='SimpleCorridor',  # Using the registered name
            config=dict(
                connector_config, **{
                    "rollout_fragment_length": 1000,
                    "train_batch_size": 4000,
                    "framework": args.framework,
                }))
    else:
        raise ValueError("--run must be DQN or PPO")
    print('All done')
    # Serving and training loop.
    while True:
        print(pretty_print(trainer.train()))
And here’s my client script:
import argparse
import gym
from gym.spaces import Discrete, Box
import numpy as np
from ray.rllib.env.policy_client import PolicyClient
from sim.simple_corridor import SimpleCorridor # Same env and creator func as server
def env_creator():
    env = SimpleCorridor()
    return env  # return an env instance
parser = argparse.ArgumentParser()
parser.add_argument(
    "--no-train", action="store_true", help="Whether to disable training.")
parser.add_argument(
    "--inference-mode", type=str, default="local", choices=["local", "remote"])
parser.add_argument(
    "--off-policy",
    action="store_true",
    help="Whether to take random instead of on-policy actions.")
parser.add_argument(
    "--stop-reward",
    type=int,
    default=9999,
    help="Stop once the specified reward is reached.")
parser.add_argument(
    "--ip-head",
    type=str,
    default='localhost:9900',
    help="The ip address and port to connect to on the server. This should match the ip_head "
         "given to the server node, and the port can be incremented if there are multiple "
         "workers listening on the server."
)
if __name__ == "__main__":
    args = parser.parse_args()
    env = env_creator()
    address, port = args.ip_head.split(':')
    port = 9900
    ip_head = 'http://' + address + ":" + str(port)
    client = PolicyClient(ip_head, inference_mode=args.inference_mode)
    eid = client.start_episode(training_enabled=not args.no_train)
    obs = env.reset()
    rewards = 0
    while True:
        if args.off_policy:
            action = env.action_space.sample()
            client.log_action(eid, obs, action)
        else:
            action = client.get_action(eid, obs)
        obs, reward, done, info = env.step(action)
        rewards += reward
        client.log_returns(eid, reward, info=info)
        if done:
            print("Total reward:", rewards)
            if rewards >= args.stop_reward:
                print("Target reward achieved, exiting")
                exit(0)
            rewards = 0
            client.end_episode(eid, obs)
            obs = env.reset()
            eid = client.start_episode(training_enabled=not args.no_train)
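One sanity check that's easy to run on the client node before launching the client (a standalone sketch, not part of my scripts): confirm that a TCP connection to the server's port actually succeeds.

# check_port.py -- standalone sketch, not part of my scripts: verify that
# the client node can open a TCP connection to the policy server, e.g.
# `python3 check_port.py 192.0.2.10:9900`.
import socket
import sys

def port_open(host, port, timeout=5.0):
    # True if a TCP connection to host:port succeeds within `timeout` seconds.
    try:
        with socket.create_connection((host, int(port)), timeout=timeout):
            return True
    except OSError:
        return False

if __name__ == "__main__":
    host, port = sys.argv[1].split(":")
    print("reachable" if port_open(host, port) else "unreachable")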
Note: To ensure the client waits long enough for an action, I changed the timeouts in rllib's ExternalEnv (ray.rllib.env.external_env) to 360 seconds.
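Concretely, that change is just the timeout value on the queue reads in external_env.py; the one visible in the traceback above is:

# ray/rllib/env/external_env.py, inside wait_for_action:
return self.action_queue.get(True, timeout=360.0)  # raised from the stock value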
I’m using ray 1.8.0.