Client-Server and Tune

I’ve been using the Tune API to train reinforcement learning agents. I’m interested in working across multiple nodes, and I want to lean on RLlib’s client-server architecture. In the example, the server uses the RLlib API directly: the trainer is instantiated explicitly and trained with trainer.train() (i.e. without Tune and tune.run), and the client explicitly runs the data-generation loop.
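On the server side, the example boils down to calling train() in a loop. tune.run drives essentially the same loop, adding stopping criteria and checkpointing on top. Below is a simplified, stdlib-only model of that loop for a single trial. StubTrainer and run_like_tune are hypothetical names: the stub stands in for the real RLlib trainer and pretends each iteration collected a fixed number of episodes. The sketch only approximates what Tune does with stop, checkpoint_freq and checkpoint_at_end.

```python
from dataclasses import dataclass, field


@dataclass
class StubTrainer:
    """Hypothetical stand-in for an RLlib trainer (not the real API)."""
    episodes_per_iter: int
    episodes_total: int = 0
    iteration: int = 0
    checkpoints: list = field(default_factory=list)

    def train(self) -> dict:
        # One training iteration; a real trainer would pull samples
        # (e.g. from PolicyServerInput) here and update the policy.
        self.iteration += 1
        self.episodes_total += self.episodes_per_iter
        return {"training_iteration": self.iteration,
                "episodes_total": self.episodes_total}

    def save(self) -> int:
        # Record the iteration at which a checkpoint was taken.
        self.checkpoints.append(self.iteration)
        return self.iteration


def run_like_tune(trainer, stop, checkpoint_freq, checkpoint_at_end):
    """Simplified model of the loop tune.run drives for one trial."""
    while True:
        result = trainer.train()
        if checkpoint_freq and result["training_iteration"] % checkpoint_freq == 0:
            trainer.save()
        # A dict-based stop condition ends the trial once any metric
        # reaches its threshold.
        if any(result[key] >= value for key, value in stop.items()):
            break
    # Final checkpoint, unless one was just taken at this iteration.
    if checkpoint_at_end and trainer.checkpoints[-1:] != [trainer.iteration]:
        trainer.save()
    return result


trainer = StubTrainer(episodes_per_iter=30)
result = run_like_tune(trainer, {"episodes_total": 2000},
                       checkpoint_freq=50, checkpoint_at_end=True)
print(result, trainer.checkpoints)
# {'training_iteration': 67, 'episodes_total': 2010} [50, 67]
```

In both styles the server only trains on whatever experience the clients send. Tune changes who drives the loop, not where the data comes from.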

What I’ve done so far

I’ve tried this with limited success. I replaced the trainer instantiation with a configuration that I pass to Tune:

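```python
import ray
from ray import tune
from ray.rllib.env.policy_server_input import PolicyServerInput

if __name__ == "__main__":
    # Get the server address and set the port to 9900
    # ...
    # (args, MyCallbacks, server_address and server_port are defined
    # elsewhere in the script, as in the RLlib client-server example)
    ray_tune = {
        'run_or_experiment': 'PPO',
        'checkpoint_freq': 50,
        'checkpoint_at_end': True,
        'stop': {
            'episodes_total': 2000,
        },
        'verbose': 2,
        'config': {
            'env': "CartPole-v0",
            'horizon': 200,
            "num_workers": 0,
            "input_evaluation": [],
            "callbacks": MyCallbacks,
            "rollout_fragment_length": 1000,
            "train_batch_size": 4000,
            "framework": args.framework,
            # Experience comes from clients via the policy server.
            "input": lambda ioctx: PolicyServerInput(ioctx, server_address, server_port),
        }
    }

    ray.init()
    tune.run(**ray_tune)
    ray.shutdown()
```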
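With this config, PPO roughly finishes a training iteration (and Tune reports a result) only after the server has received about train_batch_size = 4000 environment steps from clients, gathered in rollout_fragment_length = 1000-step fragments. A back-of-the-envelope calculation of what that means for a single client, assuming episodes stay near the 26 steps seen in the client output (they get longer as the policy improves, capped by horizon = 200). The helper names are hypothetical.

```python
import math


def fragments_per_iteration(train_batch_size: int, rollout_fragment_length: int) -> int:
    # Number of rollout fragments that make up one training batch.
    return math.ceil(train_batch_size / rollout_fragment_length)


def episodes_per_iteration(train_batch_size: int, mean_episode_len: float) -> int:
    # Rough number of client episodes needed to fill one training batch.
    return math.ceil(train_batch_size / mean_episode_len)


print(fragments_per_iteration(4000, 1000))  # 4
print(episodes_per_iteration(4000, 26))     # 154 (early, short episodes)
print(episodes_per_iteration(4000, 200))    # 20 (episodes at the horizon)
```

So with one serial client, the server may need well over a hundred episodes before Tune shows its first result.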

The client script is the same as the client example: it only runs the data-generation loop, and I’ve added some print statements for debugging.

```python
import sys

import gym
from ray.rllib.env.policy_client import PolicyClient

if __name__ == "__main__":
    # Process args and get ip address of server node
    # ...
    env = gym.make("CartPole-v0")
    client = PolicyClient(ip_head, inference_mode=args.inference_mode)
    print("made client")

    eid = client.start_episode(training_enabled=not args.no_train)
    obs = env.reset()
    rewards = 0

    while True:
        print("Start Step")
        if args.off_policy:
            # Act randomly and report the action to the server.
            action = env.action_space.sample()
            client.log_action(eid, obs, action)
        else:
            # Ask the policy for an action.
            action = client.get_action(eid, obs)
        print(f"action {action}")
        obs, reward, done, info = env.step(action)
        rewards += reward
        client.log_returns(eid, reward, info=info)
        if done:
            print("Total reward:", rewards)
            if rewards >= args.stop_reward:
                print("Target reward achieved, exiting")
                sys.exit(0)
            rewards = 0
            print("Trying to end episode")
            client.end_episode(eid, obs)
            print("Episode ended")
            obs = env.reset()
            print("env reset")
            eid = client.start_episode(training_enabled=not args.no_train)
            print("client started next episode")
```

The server output looks normal, albeit incomplete: the log ends after two POST requests from the client.

From the server:

```
2021-08-05 17:04:52,954	INFO services.py:1274 -- View the Ray dashboard at http://127.0.0.1:8265
== Status ==
Memory usage on this node: 13.6/251.5 GiB
Using FIFO scheduling algorithm.
Resources requested: 0/72 CPUs, 0/2 GPUs, 0.0/165.7 GiB heap, 0.0/75.0 GiB objects (0.0/1.0 accelerator_type:T)
Result logdir: /g/g13/rusu1/ray_results/PPO
Number of trials: 1/1 (1 PENDING)


(pid=6783) 2021-08-05 17:05:06,571	INFO trainer.py:671 -- Tip: set framework=tfe or the --eager flag to enable TensorFlow eager execution
(pid=6783) 2021-08-05 17:05:06,571	INFO trainer.py:698 -- Current log_level is WARN. For more information, set 'log_level': 'INFO' / 'DEBUG' or use the -v and -vv flags.
(pid=6783) 2021-08-05 17:05:10,015	WARNING util.py:53 -- Install gputil for GPU system monitoring.
(pid=6783) <ip_address> - - [05/Aug/2021 17:05:47] "POST / HTTP/1.1" 200 -
(pid=6783) <ip_address> - - [05/Aug/2021 17:05:48] "POST / HTTP/1.1" 200 -
```

The client output shows that it gets through the first episode but runs into an issue when starting the next one:

```
made client
Start Step
episode 1253559585 (env-idx=e31b59f9719c4700b03baed06bc4cb4f) started.
action 1
Start Step
action 1
...
action 1
Total reward: 26.0
Trying to end episode
Episode ended
env reset
client started next episode
Start Steppostprocessed 26 steps

episode 1253559585 (env-idx=e31b59f9719c4700b03baed06bc4cb4f) ended with length 26 and pole angles 0.11107769568480462
Traceback (most recent call last):
  File "./cartpole_client.py", line 64, in <module>
    action = client.get_action(eid, obs)
  File "/usr/WS1/rusu1/abmarl_scale_test/v_ray_test_tf/lib/python3.7/site-packages/ray/rllib/env/policy_client.py", line 121, in get_action
    return self.env.get_action(episode_id, observation)
  File "/usr/WS1/rusu1/abmarl_scale_test/v_ray_test_tf/lib/python3.7/site-packages/ray/rllib/env/external_env.py", line 128, in get_action
    return episode.wait_for_action(observation)
  File "/usr/WS1/rusu1/abmarl_scale_test/v_ray_test_tf/lib/python3.7/site-packages/ray/rllib/env/external_env.py", line 241, in wait_for_action
    return self.action_queue.get(True, timeout=60.0)
  File "/usr/tce/packages/python/python-3.7.2/lib/python3.7/queue.py", line 178, in get
    raise Empty
_queue.Empty
```

After the first episode, the client can’t seem to make it to the next one: client.get_action blocks and, after 60 seconds, raises _queue.Empty.
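The traceback shows where the client blocks. PolicyClient.get_action goes through self.env.get_action, i.e. the client’s local ExternalEnv. There, wait_for_action waits on a per-episode action queue with action_queue.get(True, timeout=60.0). If nothing puts an action on that queue within 60 seconds, the standard library raises queue.Empty (printed as _queue.Empty). The sketch reproduces only that mechanism with the standard library. The Episode class and its timeout parameter are hypothetical simplifications, not RLlib’s actual class.

```python
import queue
import threading


class Episode:
    """Hypothetical, simplified model of an external-env episode."""

    def __init__(self, timeout: float):
        self.timeout = timeout
        self.action_queue: "queue.Queue[int]" = queue.Queue()

    def wait_for_action(self):
        # Blocks until the policy side puts an action, or raises queue.Empty
        # after `timeout` seconds (60 s in the traceback above).
        return self.action_queue.get(True, timeout=self.timeout)


# Case 1: a producer thread answers in time, so the call returns the action.
ok = Episode(timeout=1.0)
threading.Timer(0.1, ok.action_queue.put, args=(1,)).start()
print(ok.wait_for_action())  # 1

# Case 2: nothing ever produces an action, so get() times out.
stuck = Episode(timeout=0.2)
try:
    stuck.wait_for_action()
except queue.Empty:
    print("queue.Empty: no action arrived before the timeout")
```

So the error means "no action arrived in time", not a fault in the queue itself. The real question is why the sampling side stopped producing actions after the first episode, for example a dropped connection or an observation that does not match the observation space.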

Questions

Is it possible to use the tune API for client-server training?

Also, as I study this more, I realize that I’m not generating the data in parallel. In the past I have used RLlib’s num_workers setting (passed through Tune) for parallel data generation. But now the data is generated on the client node by the client script, which in this example is a single serial loop. Any thoughts on how to parallelize data generation for the simple CartPole example?
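One way to get parallel data generation in this setup is to run several copies of the client loop, each with its own environment and its own PolicyClient. The sketch shows only the fan-out pattern with the standard library. run_client is a hypothetical placeholder that, in a real script, would contain the client loop above (gym.make, PolicyClient(...), start_episode, ...). Here it just returns a record so the sketch runs on its own. If the server also uses num_workers > 0, each rollout worker would need its own PolicyServerInput port. The round-robin mapping in client_port (base_port + index modulo the number of server ports) is an assumption for illustration, not documented RLlib behavior.

```python
from concurrent.futures import ThreadPoolExecutor


def client_port(client_index: int, base_port: int, num_server_ports: int) -> int:
    # Hypothetical round-robin assignment of clients to server ports.
    return base_port + client_index % num_server_ports


def run_client(client_index: int, address: str, port: int) -> dict:
    # Hypothetical placeholder for the real client loop, which would create
    # gym.make("CartPole-v0") and PolicyClient(f"http://{address}:{port}")
    # and step the environment episode after episode.
    return {"client": client_index, "port": port}


def run_clients(num_clients: int, address: str, base_port: int, num_server_ports: int):
    # One thread per client; each client is fully independent.
    with ThreadPoolExecutor(max_workers=num_clients) as pool:
        futures = [
            pool.submit(run_client, i, address,
                        client_port(i, base_port, num_server_ports))
            for i in range(num_clients)
        ]
        return [f.result() for f in futures]


results = run_clients(num_clients=4, address="localhost",
                      base_port=9900, num_server_ports=2)
print([r["port"] for r in results])  # [9900, 9901, 9900, 9901]
```

Threads keep the sketch simple. Separate processes, or separate jobs on other nodes, avoid contention on Python’s GIL when clients run inference locally.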


This post and this post have similar titles, but they cover different topics.

Not sure if you are still around, but I ran into the _queue.Empty error myself and was stumped by it for months. The culprit is one of two things:

  1. The connection gets dropped for some reason (outside of RLlib), so the client can’t communicate with the server and dies.
  2. The observation you are passing in is not the right shape (this took me ages to figure out). If you run into the queue issue, I would recommend adding a check like this before calling client.get_action:

```python
print(env.observation_space.contains(obs))
```

If you are lucky, it will print False right before the _queue.Empty error, and then you know it’s an observation problem.
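gym’s Box.contains checks an observation’s shape and bounds (and, depending on the gym version, its dtype), which is why it catches a wrongly shaped obs. Below is a dependency-free sketch of the same idea for a flat, 1-D box. obs_matches is a hypothetical helper, and the bounds are illustrative, roughly CartPole-like values. In a real client you would simply call env.observation_space.contains(obs), or assert it, right before client.get_action.

```python
import math
from typing import Sequence


def obs_matches(obs: Sequence[float], low: Sequence[float], high: Sequence[float]) -> bool:
    """Hypothetical 1-D stand-in for gym's Box.contains: shape and bounds."""
    if len(obs) != len(low):
        # Wrong shape, e.g. a missing or extra feature.
        return False
    return all(lo <= x <= hi for x, lo, hi in zip(obs, low, high))


# Illustrative 4-dimensional box, roughly CartPole-like.
low = [-4.8, -math.inf, -0.42, -math.inf]
high = [4.8, math.inf, 0.42, math.inf]

print(obs_matches([0.01, 0.2, 0.05, -0.3], low, high))  # True
print(obs_matches([0.01, 0.2, 0.05], low, high))        # False (wrong shape)
print(obs_matches([0.01, 0.2, 0.9, -0.3], low, high))   # False (out of bounds)
```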