- High: It blocks me from completing my task.
I am getting different reward distributions when I evaluate a PPO checkpoint with the Python API vs. the rllib CLI. Below is a comparison of the reward histograms produced by each method. What is causing this large discrepancy?
Bin Edges: [ 0. 25. 50. 75. 100. 125. 150. 175. 200. 225. 250. 275. 300. 325. 350. 375. 400. 425. 450. 475. 500. 525.]
CLI Histogram: [ 0 0 0 0 0 0 0 0 1 0 3 6 1 4 3 2 2 2 2 3 71]
Python Histogram: [ 0 0 0 0 0 42 41 10 5 1 0 0 0 0 1 0 0 0 0 0 0]
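In case a visual comparison helps, the two histograms above can be plotted directly from those bin counts (a minimal sketch, assuming matplotlib is available; the arrays are just copied from the output above):

import numpy as np
import matplotlib.pyplot as plt

bin_edges = np.arange(0, 550, 25)  # the 22 edges printed above
cli_hist = np.array([0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 3, 6, 1, 4, 3, 2, 2, 2, 2, 3, 71])
python_hist = np.array([0, 0, 0, 0, 0, 42, 41, 10, 5, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0])

width = 10  # bar width in reward units, purely cosmetic
plt.bar(bin_edges[:-1], cli_hist, width=width, label="CLI")
plt.bar(bin_edges[:-1] + width, python_hist, width=width, label="Python API")
plt.xlabel("Episode reward")
plt.ylabel("Episode count")
plt.legend()
plt.show()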
To train, I'm using:
$ rllib train file cartpole-ppo.yaml
Where cartpole-ppo.yaml is the one provided in the tuned_examples directory:
cartpole-ppo-troubleshoot:
    env: CartPole-v1
    run: PPO
    stop:
        sampler_results/episode_reward_mean: 150
        timesteps_total: 100000
    config:
        # Works for both torch and tf.
        framework: torch
        gamma: 0.99
        lr: 0.0003
        num_workers: 1
        observation_filter: MeanStdFilter
        num_sgd_iter: 6
        vf_loss_coeff: 0.01
        model:
            fcnet_hiddens: [32]
            fcnet_activation: linear
            vf_share_layers: true
        enable_connectors: true
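For reference, the settings carried by the restored checkpoint, in particular the MeanStdFilter and connector options, can be inspected like this (a minimal sketch, assuming a recent Ray 2.x where Algorithm.from_checkpoint restores an AlgorithmConfig; the checkpoint path is a placeholder):

from ray.rllib.algorithms.algorithm import Algorithm

algo = Algorithm.from_checkpoint("/path/to/checkpoint")  # placeholder path
print(algo.config.observation_filter)  # expected: "MeanStdFilter"
print(algo.config.enable_connectors)   # expected: True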
The checkpoint is evaluated using the CLI and episode rewards are printed to stdout:
$ rllib evaluate --algo PPO --episodes 100 --steps 0 [checkpoint_path]
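On my Ray version each episode produces one stdout line, which the scraper further down parses by splitting on the last colon; the exact wording may differ between versions, so treat the sample line as an assumption:

# Hypothetical sample of one per-episode stdout line and how it is parsed
sample_line = "Episode #0: reward: 500.0"
reward = float(sample_line.split(":")[-1])  # -> 500.0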
The same checkpoint is also evaluated using the Python API like this (Note: setting exploration on/off does not resolve the discrepancy):
def evaluate_model_python_api(checkpoint_path, n_episodes=100):
    env_name = "CartPole-v1"
    env = gym.make(env_name)
    algo = Algorithm.from_checkpoint(checkpoint_path)
    episode_rewards = []
    for _ in range(n_episodes):
        episode_reward = 0
        terminated = truncated = False
        obs, info = env.reset()
        while not terminated and not truncated:
            action = algo.compute_single_action(obs, explore=False)
            obs, reward, terminated, truncated, info = env.step(action)
            episode_reward += reward
        episode_rewards.append(episode_reward)
    return episode_rewards
Here is the full script that trains, evaluates, scrapes the CLI output from stdout, and builds the histograms. Repeating the same process with DQN produces matching histograms for both evaluation methods.
import gymnasium as gym
from ray.rllib.algorithms import Algorithm
import os
import subprocess
from glob import glob
import numpy as np


def train(yaml_path):
    cmd_str = "rllib train file {}".format(yaml_path)
    print(cmd_str)
    os.system(cmd_str)


def latest_checkpoint(root_path="~/ray_results"):
    # Find the most recently written checkpoint under ~/ray_results.
    results_dir = os.path.expanduser(root_path)
    checkpoints = glob(os.path.join(results_dir, "**", "checkpoint*"), recursive=True)
    return max(checkpoints, key=os.path.getmtime)


def evaluate_model_cli(checkpoint_path, n_episodes=100):
    # Run `rllib evaluate` and parse the per-episode rewards it prints to stdout.
    cmd = ["rllib", "evaluate", "--algo", "PPO", "--episodes", str(n_episodes), "--steps", "0", checkpoint_path]
    print(" ".join(cmd))
    stdout_str = subprocess.run(cmd, stdout=subprocess.PIPE).stdout.decode("utf-8")
    stdout_rows = stdout_str.split("\n")
    episode_rewards = [float(row.split(":")[-1]) for row in stdout_rows if row != ""]
    return episode_rewards


def evaluate_model_python_api(checkpoint_path, n_episodes=100):
    # Restore the checkpoint and roll out episodes in a hand-built gym loop.
    env_name = "CartPole-v1"
    env = gym.make(env_name)
    algo = Algorithm.from_checkpoint(checkpoint_path)
    episode_rewards = []
    for _ in range(n_episodes):
        episode_reward = 0
        terminated = truncated = False
        obs, info = env.reset()
        while not terminated and not truncated:
            action = algo.compute_single_action(obs, explore=False)
            obs, reward, terminated, truncated, info = env.step(action)
            episode_reward += reward
        episode_rewards.append(episode_reward)
    return episode_rewards


if __name__ == "__main__":
    yaml_dir = os.path.realpath(os.path.dirname(__file__))
    yaml_path = os.path.join(yaml_dir, "cartpole_ppo_example.yaml")
    train(yaml_path)
    checkpoint_path = latest_checkpoint()

    cli_rewards = evaluate_model_cli(checkpoint_path)
    python_rewards = evaluate_model_python_api(checkpoint_path)

    # Bin both reward lists identically so the distributions are directly comparable.
    hist_range = [0, 525]
    bins = 21
    cli_hist, bin_edges = np.histogram(cli_rewards, bins=bins, range=hist_range)
    python_hist, bin_edges = np.histogram(python_rewards, bins=bins, range=hist_range)

    print("\n\n")
    print("Bin Edges: {}".format(bin_edges))
    print("CLI Histogram: {}".format(cli_hist))
    print("Python Histogram: {}".format(python_hist))