Unable to replicate original PPO performance

How severe does this issue affect your experience of using Ray?

  • High: It blocks me to complete my task.

I can’t seem to replicate the original PPO performance (paper) when using RLlib. The hyperparameters used is listed below. It follow the hyperparameters discussed here in aims to replicate the results from the PPO paper (without LSTM).

Hyperparameters

# Environment
Max Frames Per Episode = 108000
Frameskip = 4
Max Of Last 2 Frames = True
Max Steps Per Episode = 27000
Framestack = 4

Observation Type = Grayscale
Frame Size = 84 x 84

Max No Operation Actions = 30
Repeat Action Probability = 0.0

Terminal On Life Loss = True
Fire Action on Reset = True
Reward Clip = {-1, 0, 1}
Full Action Space = False

# Algorithm
Neural Network Feature Extractor = Nature CNN
Neural Network Policy Head = Linear Layer with n_actions output features
Neural Network Value Head = Linear Layer with 1 output feature
Shared Feature Extractor = True
Orthogonal Initialization = True
Scale Images to [0, 1] = True
Optimizer = Adam with 1e-5 Epsilon

Learning Rate = 2.5e-4
Decay Learning Rate = True

Number of Environments = 8
Number of Steps = 128
Batch Size = 256
Number of Minibatches = 4
Number of Epochs = 4
Gamma = 0.99
GAE Lambda = 0.95
Clip Range = 0.1
VF Clip Range = 0.1
Normalize Advantage = True
Entropy Coefficient = 0.01
VF Coefficient = 0.5
Max Gradient Normalization = 0.5
Use Target KL = False
Total Timesteps = 10000000
Log Interval = 1
Evaluation Episodes = 100
Deterministic Evaluation = False

Seed = Random
Number of Trials = 5

I have tried these same hyperparameters with the Baselines, Stable Baselines3, and CleanRL implementations of the PPO algorithm and they all achieve the expected results. For example, on the Alien and Pong environment, the agents are able to achieve more than ~1000 and ~20 mean reward respectively. However, the RLlib agent fails to train at all. The trained RLlib agent achieves ~300 and ~-20 reward for the Alien and Pong environment respectively. Am I missing something in my RLlib configuration (shown below) or is there a bug?

RLlib Code
Python version: 3.11
Ray version: 2.20.0
OS: Ubuntu LTS
To run the code below: python file_name.py --env Alien --gpu 0 --trials 1

import argparse
import json
import os
import pathlib
import time
import uuid

import gymnasium as gym
import numpy as np
import pandas as pd
import torch
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.env.wrappers.atari_wrappers import wrap_deepmind
from ray.rllib.models import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.annotations import override
from ray.tune.registry import register_env
from torch import nn
from tqdm.rich import tqdm


def make_atari(env_config):
    env = gym.make(env_config["name"])
    env = wrap_deepmind(env, 84, True, True)
    return env


def linear_schedule(lr, n_iterations, iteration_steps):
    ts_lr = []
    ts = 0
    for iteration in range(1, n_iterations + 1):
        frac = 1.0 - (iteration - 1.0) / n_iterations
        ts_lr.append((ts, frac * lr))
        ts += iteration_steps
    return ts_lr


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer


class Agent(TorchModelV2, nn.Module):
    def __init__(
        self, observation_space, action_space, num_outputs, model_config, name
    ):
        TorchModelV2.__init__(
            self, observation_space, action_space, num_outputs, model_config, name
        )
        nn.Module.__init__(self)

        self.network = nn.Sequential(
            layer_init(nn.Conv2d(4, 32, 8, stride=4)),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, 4, stride=2)),
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, 3, stride=1)),
            nn.ReLU(),
            nn.Flatten(),
            layer_init(nn.Linear(64 * 7 * 7, 512)),
            nn.ReLU(),
        )
        self.actor = layer_init(nn.Linear(512, num_outputs), std=0.01)
        self.critic = layer_init(nn.Linear(512, 1), std=1)
        self.output = None

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        x = input_dict["obs"] / 255.0
        x = torch.reshape(x, (-1, 4, 84, 84))
        self.output = self.network(x)
        return self.actor(self.output), []

    @override(ModelV2)
    def value_function(self):
        assert self.output is not None, "must call forward first!"
        return torch.reshape(self.critic(self.output), [-1])


def train_atari(args):
    total_timesteps = int(10e6)
    lr = 2.5e-4
    n_envs = 8
    n_steps = 128
    n_iterations = total_timesteps // (n_envs * n_steps)
    lr_schedule = linear_schedule(lr, n_iterations, n_steps * n_envs)

    ModelCatalog.register_custom_model("Agent", Agent)
    register_env(f"{args.env}", make_atari)

    ppo = (
        PPOConfig()
        .training(
            gamma=0.99,
            grad_clip_by="global_norm",
            train_batch_size=128 * 8,
            model={"custom_model": "Agent"},
            optimizer={"eps": 1e-5},
            lr_schedule=lr_schedule,
            use_critic=True,
            use_gae=True,
            lambda_=0.95,
            use_kl_loss=False,
            kl_coeff=0.0,  # not used
            kl_target=0.01,  # not used
            sgd_minibatch_size=256,
            num_sgd_iter=4,
            shuffle_sequences=True,
            vf_loss_coeff=0.5,
            entropy_coeff=0.01,
            entropy_coeff_schedule=None,
            clip_param=0.1,
            vf_clip_param=0.1,
            grad_clip=0.5,
        )
        .environment(
            env=f"{args.env}",
            env_config={"name": f"{args.env}NoFrameskip-v4"},
            render_env=False,
            clip_rewards=True,
            normalize_actions=False,
            clip_actions=False,
            is_atari=True,
        )
        .env_runners(
            num_env_runners=8,
            num_envs_per_env_runner=1,
            rollout_fragment_length=128,
            batch_mode="truncate_episodes",
            explore=True,
            exploration_config={"type": "StochasticSampling"},
            create_env_on_local_worker=False,
            preprocessor_pref=None,
            observation_filter="NoFilter",
        )
        .framework(framework="torch")
        .evaluation(
            evaluation_interval=None,
            evaluation_duration=100,
            evaluation_duration_unit="episodes",
            evaluation_config={
                "explore": True,
                "exploration_config": {"type": "StochasticSampling"},
            },
            evaluation_num_env_runners=1,
            always_attach_evaluation_results=True,
        )
        .debugging(seed=args.seed)
        .resources(
            num_gpus=0.3,
            num_cpus_per_worker=1,
            num_gpus_per_worker=0,
            num_cpus_for_local_worker=1,
        )
        .reporting(
            metrics_num_episodes_for_smoothing=100,
            min_train_timesteps_per_iteration=128 * 8,
            min_sample_timesteps_per_iteration=128 * 8,
        )
        .experimental(_disable_preprocessor_api=True, _enable_new_api_stack=False)
        .build()
    )

    # train
    start_time = time.time()
    progress_data = {"global_step": [], "mean_reward": []}
    for iteration in tqdm(range(1, n_iterations + 1)):
        result = ppo.train()
        rewards = result["env_runner_results"]["hist_stats"]["episode_reward"]
        global_step = result["timesteps_total"]
        if len(rewards) > 100:
            rewards = rewards[-100:]
        mean_reward = np.nan if len(rewards) == 0 else float(np.mean(rewards))
        progress_data["global_step"].append(global_step)
        progress_data["mean_reward"].append(mean_reward)
    train_end_time = time.time()
    progress_df = pd.DataFrame(progress_data)
    progress_df.to_csv(os.path.join(args.path, "progress.csv"), index=False)

    # eval
    # there seems to be an issue where rllib does not follow specified evaluation episodes of 100
    # here we run evaluation until 100 episodes, store the final eval results as well as initial eval results
    initial_results = []
    results = []
    count = 0
    while len(results) < 100:
        result = ppo.evaluate()
        assert result["env_runner_results"]["episodes_this_iter"] == len(
            result["env_runner_results"]["hist_stats"]["episode_reward"]
        )
        results += result["env_runner_results"]["hist_stats"]["episode_reward"]
        if count == 0:
            initial_results += result["env_runner_results"]["hist_stats"][
                "episode_reward"
            ]
        count += 1
    eval_end_time = time.time()
    args.training_time_h = ((train_end_time - start_time) / 60) / 60
    args.total_time_h = ((eval_end_time - start_time) / 60) / 60
    args.eval_mean_reward = float(np.mean(results[:100]))
    args.initial_eval_mean_reward = float(np.mean(initial_results))
    args.initial_eval_episodes = len(initial_results)

    ppo.save(args.path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-g",
        "--gpu",
        type=int,
        help="Specify GPU index",
        default=0,
    )
    parser.add_argument(
        "-e",
        "--env",
        type=str,
        help="Specify Atari environment w/o version",
        default="Pong",
    )
    parser.add_argument(
        "-t",
        "--trials",
        type=int,
        help="Specify number of trials",
        default=5,
    )
    args = parser.parse_args()
    for _ in range(args.trials):
        args.id = uuid.uuid4().hex
        args.path = os.path.join("trials", "ppo", args.env, args.id)
        args.seed = int(time.time())

        # create dir
        pathlib.Path(args.path).mkdir(parents=True, exist_ok=True)

        # set gpu
        os.environ["CUDA_VISIBLE_DEVICES"] = f"{args.gpu}"

        train_atari(args)

        # save trial info
        with open(os.path.join(args.path, "info.json"), "w") as f:
            json.dump(vars(args), f, indent=4)

RLlib PPO Pong Mean Reward Graph

RLlib PPO Alien Mean Reward Graph

Other Issues Found
On a side note, I also found 2 other issues.

Firstly, setting preprocessor_pref="deepmind" in .env_runners does not seem to work at all and thus I have to manually configure the environment via a custom function make_atari and attach the wrappers there. To test this, use the following code below. The modifications made with respect to the above code is that now I am not using a custom function with the required wrappers and instead using preprocessor_pref="deepmind".

RLlib Code
Python version: 3.11
Ray version: 2.20.0
OS: Ubuntu LTS
To run the code below: python file_name.py --env Alien --gpu 0 --trials 1

import argparse
import json
import os
import pathlib
import time
import uuid

import gymnasium as gym
import numpy as np
import pandas as pd
import torch
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.env.wrappers.atari_wrappers import wrap_deepmind
from ray.rllib.models import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.annotations import override
from ray.tune.registry import register_env
from torch import nn
from tqdm.rich import tqdm


def make_atari(env_config):
    env = gym.make(env_config["name"])
    env = wrap_deepmind(env, 84, True, True)
    return env


def linear_schedule(lr, n_iterations, iteration_steps):
    ts_lr = []
    ts = 0
    for iteration in range(1, n_iterations + 1):
        frac = 1.0 - (iteration - 1.0) / n_iterations
        ts_lr.append((ts, frac * lr))
        ts += iteration_steps
    return ts_lr


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer


class Agent(TorchModelV2, nn.Module):
    def __init__(
        self, observation_space, action_space, num_outputs, model_config, name
    ):
        TorchModelV2.__init__(
            self, observation_space, action_space, num_outputs, model_config, name
        )
        nn.Module.__init__(self)

        self.network = nn.Sequential(
            layer_init(nn.Conv2d(4, 32, 8, stride=4)),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, 4, stride=2)),
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, 3, stride=1)),
            nn.ReLU(),
            nn.Flatten(),
            layer_init(nn.Linear(64 * 7 * 7, 512)),
            nn.ReLU(),
        )
        self.actor = layer_init(nn.Linear(512, num_outputs), std=0.01)
        self.critic = layer_init(nn.Linear(512, 1), std=1)
        self.output = None

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        x = input_dict["obs"] / 255.0
        x = torch.reshape(x, (-1, 4, 84, 84))
        self.output = self.network(x)
        return self.actor(self.output), []

    @override(ModelV2)
    def value_function(self):
        assert self.output is not None, "must call forward first!"
        return torch.reshape(self.critic(self.output), [-1])


def train_atari(args):
    total_timesteps = int(10e6)
    lr = 2.5e-4
    n_envs = 8
    n_steps = 128
    n_iterations = total_timesteps // (n_envs * n_steps)
    lr_schedule = linear_schedule(lr, n_iterations, n_steps * n_envs)

    ModelCatalog.register_custom_model("Agent", Agent)
    # register_env(f"{args.env}", make_atari)

    ppo = (
        PPOConfig()
        .training(
            gamma=0.99,
            grad_clip_by="global_norm",
            train_batch_size=128 * 8,
            model={"custom_model": "Agent"},
            optimizer={"eps": 1e-5},
            lr_schedule=lr_schedule,
            use_critic=True,
            use_gae=True,
            lambda_=0.95,
            use_kl_loss=False,
            kl_coeff=0.0,  # not used
            kl_target=0.01,  # not used
            sgd_minibatch_size=256,
            num_sgd_iter=4,
            shuffle_sequences=True,
            vf_loss_coeff=0.5,
            entropy_coeff=0.01,
            entropy_coeff_schedule=None,
            clip_param=0.1,
            vf_clip_param=0.1,
            grad_clip=0.5,
        )
        .environment(
            env=f"{args.env}NoFrameskip-v4"},
            render_env=False,
            clip_rewards=True,
            normalize_actions=False,
            clip_actions=False,
            is_atari=True,
        )
        .env_runners(
            num_env_runners=8,
            num_envs_per_env_runner=1,
            rollout_fragment_length=128,
            batch_mode="truncate_episodes",
            explore=True,
            exploration_config={"type": "StochasticSampling"},
            create_env_on_local_worker=False,
            preprocessor_pref="deepmind",
            observation_filter="NoFilter",
        )
        .framework(framework="torch")
        .evaluation(
            evaluation_interval=None,
            evaluation_duration=100,
            evaluation_duration_unit="episodes",
            evaluation_config={
                "explore": True,
                "exploration_config": {"type": "StochasticSampling"},
            },
            evaluation_num_env_runners=1,
            always_attach_evaluation_results=True,
        )
        .debugging(seed=args.seed)
        .resources(
            num_gpus=0.3,
            num_cpus_per_worker=1,
            num_gpus_per_worker=0,
            num_cpus_for_local_worker=1,
        )
        .reporting(
            metrics_num_episodes_for_smoothing=100,
            min_train_timesteps_per_iteration=128 * 8,
            min_sample_timesteps_per_iteration=128 * 8,
        )
        .experimental(_disable_preprocessor_api=False, _enable_new_api_stack=False)
        .build()
    )

    # train
    start_time = time.time()
    progress_data = {"global_step": [], "mean_reward": []}
    for iteration in tqdm(range(1, n_iterations + 1)):
        result = ppo.train()
        rewards = result["env_runner_results"]["hist_stats"]["episode_reward"]
        global_step = result["timesteps_total"]
        if len(rewards) > 100:
            rewards = rewards[-100:]
        mean_reward = np.nan if len(rewards) == 0 else float(np.mean(rewards))
        progress_data["global_step"].append(global_step)
        progress_data["mean_reward"].append(mean_reward)
    train_end_time = time.time()
    progress_df = pd.DataFrame(progress_data)
    progress_df.to_csv(os.path.join(args.path, "progress.csv"), index=False)

    # eval
    # there seems to be an issue where rllib does not follow specified evaluation episodes of 100
    # here we run evaluation until 100 episodes, store the final eval results as well as initial eval results
    initial_results = []
    results = []
    count = 0
    while len(results) < 100:
        result = ppo.evaluate()
        assert result["env_runner_results"]["episodes_this_iter"] == len(
            result["env_runner_results"]["hist_stats"]["episode_reward"]
        )
        results += result["env_runner_results"]["hist_stats"]["episode_reward"]
        if count == 0:
            initial_results += result["env_runner_results"]["hist_stats"][
                "episode_reward"
            ]
        count += 1
    eval_end_time = time.time()
    args.training_time_h = ((train_end_time - start_time) / 60) / 60
    args.total_time_h = ((eval_end_time - start_time) / 60) / 60
    args.eval_mean_reward = float(np.mean(results[:100]))
    args.initial_eval_mean_reward = float(np.mean(initial_results))
    args.initial_eval_episodes = len(initial_results)

    ppo.save(args.path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-g",
        "--gpu",
        type=int,
        help="Specify GPU index",
        default=0,
    )
    parser.add_argument(
        "-e",
        "--env",
        type=str,
        help="Specify Atari environment w/o version",
        default="Pong",
    )
    parser.add_argument(
        "-t",
        "--trials",
        type=int,
        help="Specify number of trials",
        default=5,
    )
    args = parser.parse_args()
    for _ in range(args.trials):
        args.id = uuid.uuid4().hex
        args.path = os.path.join("trials", "ppo", args.env, args.id)
        args.seed = int(time.time())

        # create dir
        pathlib.Path(args.path).mkdir(parents=True, exist_ok=True)

        # set gpu
        os.environ["CUDA_VISIBLE_DEVICES"] = f"{args.gpu}"

        train_atari(args)

        # save trial info
        with open(os.path.join(args.path, "info.json"), "w") as f:
            json.dump(vars(args), f, indent=4)

Secondly, while evaluating, specifying the number of episodes in .evaluation seems to have no effect at all. The evaluation function seems to return evaluation over a different number of episodes than what was initially specified in .evaluation. This is why there is a while loop towards the end of the codes given above, in order to accumulate the correct number of episodes. You can easily observe this in the info.json file generated after running the file and looking at the initial_eval_episodes field.