[RLlib] GPU Memory Leak? Tune + PPO, Policy Server + Client

How severe does this issue affect your experience of using Ray?

  • High: It blocks me to complete my task.

Hi all,

I am using RLlib with Tune + PPO to try to train an agent, but it seems like, as epochs complete, the GPU RAM does not get emptied out and slowly builds up.

I have recorded this here: Ray GPU Memory Leak? - YouTube

Relevant timestamps:
04:19 - First epoch + GPU RAM spike
07:12 - Second epoch + GPU RAM spike
10:25 - Third epoch + NO spike
13:41 - Fourth epoch + GPU RAM spike
17:26 - Fifth epoch + GPU RAM spike
19:54 - Sixth epoch + GPU RAM spike

On subsequent epochs nothing happens, so I stop recording; around the ~12th epoch it spikes again and there is a ton of memory thrashing as the epoch fails to complete.
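
For anyone who wants to quantify this without watching the video, something like the snippet below can read the allocated/reserved CUDA memory (just a diagnostic sketch using PyTorch's public CUDA stats, not part of my training script; it only reports the process it runs in):

import torch

def log_gpu_memory(tag: str) -> None:
    # Print the CUDA memory currently allocated/reserved by this process, in GiB.
    if not torch.cuda.is_available():
        return
    allocated = torch.cuda.memory_allocated() / 1024 ** 3
    reserved = torch.cuda.memory_reserved() / 1024 ** 3
    print(f"[{tag}] allocated={allocated:.2f} GiB, reserved={reserved:.2f} GiB")

# e.g. call log_gpu_memory("after epoch 3") once per training iteration
# to see whether the numbers keep climbing.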

Sample server code:

import ray
from ray.rllib.env import PolicyServerInput
from ray.rllib.algorithms.ppo import PPOConfig

import numpy as np
import argparse
from gymnasium.spaces import MultiDiscrete, Box

ray.init(object_store_memory=40 * (10 ** 9), num_cpus=6, num_gpus=1, log_to_driver=False )

ppo_config = PPOConfig()

def _input(ioctx):
    # We are remote worker, or we are local worker with num_workers=0:
    # Create a PolicyServerInput.
    if ioctx.worker_index > 0 or ioctx.worker.num_workers == 0:
        return PolicyServerInput(
            ioctx,
            'localhost',
            55556 + ioctx.worker_index - (1 if ioctx.worker_index > 0 else 0),
        )
    # No InputReader (PolicyServerInput) needed.
    else:
        return None


x = 320
y = 240

ppo_config.clip_param = 0.175
ppo_config.gamma = 0.996  # default 0.99 -> how far into the future to care for rewards

ppo_config.lambda_ = 0.99
ppo_config.kl_target = 0.01  # default 0.01
ppo_config.rollout_fragment_length = 64
ppo_config.train_batch_size = 4500
ppo_config.sgd_minibatch_size = 512
ppo_config.num_sgd_iter = 1
ppo_config.lr = 9e-5

ppo_config.model = {
    # Share layers for value function. If you set this to True, it's
    # important to tune vf_loss_coeff.
    "vf_share_layers": True,

    'use_attention': True,
    "max_seq_len": 50,
    "attention_num_transformer_units": 1,
    "attention_dim": 256,
    "attention_memory_inference": 50,
    "attention_memory_training": 50,
    "attention_num_heads": 8,
    "attention_head_dim": 32,
    "attention_position_wise_mlp_dim": 128,
    "attention_init_gru_gate_bias": 2.0,

    "conv_filters": [
        [32, [12, 16], [7, 9]],
        [128, [6, 6], 4],
        [256, [9, 9], 1]
    ],
    "conv_activation": "relu"
}
ppo_config.batch_mode = "complete_episodes"
ppo_config.simple_optimizer = False


ppo_config.env = None
ppo_config.observation_space = Box(low=0, high=1, shape=(y, x, 1), dtype=np.float32)
ppo_config.action_space = MultiDiscrete(
    [
        2,  # W
        2,  # A
        2,  # S
        2,  # D
        2,  # Space
        2,  # H
        2,  # J
        2,  # K
        2  # L
    ]
)
ppo_config.env_config = {
    "sleep": True,
    'replayOn': False
}

ppo_config.rollouts(num_rollout_workers=2, enable_connectors=False)
ppo_config.offline_data(input_=_input)

ppo_config.framework_str = 'torch'
ppo_config.log_sys_usage = False
ppo_config.compress_observations = True
ppo_config.shuffle_sequences = False

ppo_config.num_gpus = 0.5
ppo_config.num_cpus_for_local_worker = 4

ppo_config.num_cpus_per_worker = 1

tempyy = ppo_config.to_dict()

print(tempyy)

from ray import tune

name = "Checkpoint1"
print(f"Starting: {name}")

tune.run("PPO",
         resume='AUTO',
         config=tempyy,
         name=name,
         keep_checkpoints_num=20, checkpoint_score_attr="episode_reward_mean", mode='max',
         checkpoint_freq=1,
         metric="episode_reward_mean",
         max_failures=10,
         # resume=True,
         # restore="C:\\Users\\denys\\ray_results\\lstmV1_jump0005_batch15360_minibatch_1024_lr9e-5_sgd5_prevAction-False\\PPO_None_3866f_00000_0_2023-04-17_08-48-58\\checkpoint_000055",
         checkpoint_at_end=True)

Please let me know if there is something I'm configuring wrong or misunderstanding, or whether this is a genuine bug!

@Denys_Ashikhin We need more info on this to be able to help. Policy Server / Client is somewhat in “keep the lights on” mode in terms of priorities for the Ray team. Can you reproduce the GPU leak with a common environment like Atari, or with something else the community can help you with?

Also, in your code snippet I see a lot of red flags in how you set up the configuration (not saying that those are causing this particular issue, but more in terms of what is considered good practice). You should NEVER set the config attributes directly; instead, use the public APIs that are provided on the AlgorithmConfig object. For example, look at how we set the configuration up in this example:

And here is the complete API doc for AlgorithmConfig

https://docs.ray.io/en/master/rllib/package_ref/doc/ray.rllib.algorithms.algorithm_config.AlgorithmConfig.html#ray.rllib.algorithms.algorithm_config.AlgorithmConfig
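
For illustration, a minimal sketch of the builder-style pattern (placeholder values, nothing tuned for your problem):

from ray.rllib.algorithms.ppo import PPOConfig

# Recommended: go through the public builder methods...
config = (
    PPOConfig()
    .training(lr=9e-5, train_batch_size=4500)
    .framework("torch")
    .rollouts(num_rollout_workers=2)
    .resources(num_gpus=1)
)

# ...rather than mutating attributes directly, e.g.:
# config = PPOConfig()
# config.lr = 9e-5                 # discouraged: bypasses the public API
# config.train_batch_size = 4500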

Thanks for that, I'll switch over to the other API. As for a more common environment, I'll try to recreate it using the built-in VizDoom env. Is that acceptable? The reason being that it should be the same dimensions as my env and roughly the same conv filters.

Edit:

# TODO (Kourosh): Enable when LSTMs are supported.
_enable_learner_api=False,

I couldn't find a good explanation of the above key - what does it do?

@kourosh
Following up, I wasn't able to get VizDoom set up since it requires external installs. Instead, I have a repro using the built-in random env.

Policy Server:

import ray
from ray.rllib.env import PolicyServerInput
from ray.rllib.algorithms.ppo import PPOConfig
import numpy as np
import argparse
from gymnasium.spaces import MultiDiscrete, Box

ray.init(object_store_memory=40 * (10 ** 9), num_cpus=6, num_gpus=1, log_to_driver=False,
         # configure_logging=True,
         # logging_level=3
         )

ppo_config = PPOConfig()

parser = argparse.ArgumentParser(description='Optional app description')
parser.add_argument('-ip', type=str, help='IP of this device')

parser.add_argument('-checkpoint', type=str, help='location of checkpoint to restore from')

args = parser.parse_args()


def _input(ioctx):
    # We are remote worker, or we are local worker with num_workers=0:
    # Create a PolicyServerInput.
    if ioctx.worker_index > 0 or ioctx.worker.num_workers == 0:
        return PolicyServerInput(
            ioctx,
            args.ip,
            55556 + ioctx.worker_index - (1 if ioctx.worker_index > 0 else 0),
        )
    # No InputReader (PolicyServerInput) needed.
    else:
        return None


x = 320
y = 240

proper_config = (
    PPOConfig()
    .environment(
        None,
        action_space=MultiDiscrete(
            [
                2,  # W
                2,  # A
                2,  # S
                2,  # D
                2,  # Space
                2,  # H
                2,  # J
                2,  # K
                2  # L
            ]
        ),
        observation_space=Box(low=0, high=1, shape=(y, x, 1), dtype=np.float32)
    )
    .training(
        clip_param=0.175,
        gamma=0.996,
        lambda_=0.99,
        kl_target=0.1,
        train_batch_size=4500,
        sgd_minibatch_size=512,
        num_sgd_iter=1,
        lr=9e-5,
        shuffle_sequences=False,
        model={
            "vf_share_layers": True,
            'use_attention': True,
            "max_seq_len": 50,
            "attention_num_transformer_units": 1,
            "attention_dim": 256,
            "attention_memory_inference": 50,
            "attention_memory_training": 50,
            "attention_num_heads": 8,
            "attention_head_dim": 32,
            "attention_position_wise_mlp_dim": 128,
            # "attention_use_n_prev_actions": 2,
            # "attention_use_n_prev_rewards": 64,
            "attention_init_gru_gate_bias": 2.0,

            "conv_filters": [
                # 240 X 320
                [32, [12, 16], [7, 9]],
                [128, [6, 6], 4],
                [256, [9, 9], 1]
            ],
            "conv_activation": "relu",
        }
    )
    .exploration(
        explore=True
    )
    .framework(
        framework='torch'
    )
    .offline_data(
        input_=_input
    )
    .resources(
        num_gpus=0.65,
        num_cpus_per_worker=1,
        num_cpus_per_learner_worker=4
    )
    .rollouts(
        rollout_fragment_length=64,
        num_rollout_workers=2,
        enable_connectors=False,
        batch_mode='complete_episodes',
        compress_observations=True
    )
)

from ray import tune

name = "" + args.checkpoint
print(f"Starting: {name}")

tune.run("PPO",
         resume='AUTO',
         config=proper_config.to_dict(),
         name=name,
         keep_checkpoints_num=20, checkpoint_score_attr="episode_reward_mean", mode='max',
         checkpoint_freq=1,
         metric="episode_reward_mean",
         max_failures=10,
         checkpoint_at_end=True)

Policy Client:

from ray.rllib.env.policy_client import PolicyClient
from ray.rllib.examples.env.random_env import RandomEnv
from gymnasium.spaces import MultiDiscrete, Box
import numpy as np
import time

if __name__ == "__main__":

    env = RandomEnv(
        config={
            "action_space": MultiDiscrete(
                [
                    2,  # W
                    2,  # A
                    2,  # S
                    2,  # D
                    2,  # Space
                    2,  # H
                    2,  # J
                    2,  # K
                    2  # L
                ], ),
            "observation_space": Box(low=0, high=1, shape=(240, 320, 1), dtype=np.float32),
            "p_terminated": 1/350,
            "max_episode_len": 650,
            "check_action_bounds": True
        }
    )
    client = PolicyClient(
        f"http://localhost:{55556}", inference_mode='remote'
    )

    # Start a new episode.
    obs, info = env.reset()
    eid = client.start_episode()

    rewards = 0.0

    actions = 0
    time_start = time.time()
    while True:

        action = client.get_action(eid, obs)
        actions += 1
        if (time.time() - time_start) > 1:
            # Report the average actions/sec roughly once per second.
            print(f"average actions: {actions / (time.time() - time_start)}")
            time_start = time.time()
            actions = 0

        # Perform a step in the external simulator (env).
        obs, reward, terminated, truncated, info = env.step(action)
        rewards += reward

        # Log next-obs, rewards, and infos.
        client.log_returns(eid, reward, info=info)

        # Reset the episode if done.
        if terminated or truncated:
            print("Total reward:", rewards)

            actions = 0
            time_start = time.time()

            rewards = 0.0

            # End the old episode.
            client.end_episode(eid, obs)

            # Start a new episode.
            obs, info = env.reset()
            eid = client.start_episode()

I can confirm that the same thing happens: my GPU VRAM keeps getting maxed out in increments after each epoch.
Moreover, the amount it increases by is DIRECTLY tied to train_batch_size; setting it to a bigger or smaller number increases or decreases the jumps linearly.
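
As a rough sanity check on that scaling (back-of-the-envelope only, ignoring model activations and attention state), the raw observations alone in one train batch are already on the order of a gigabyte:

import numpy as np

# One float32 observation of shape (240, 320, 1):
obs_bytes = 240 * 320 * 1 * np.dtype(np.float32).itemsize   # 307,200 bytes
train_batch_size = 4500
batch_gib = obs_bytes * train_batch_size / 1024 ** 3
print(f"~{batch_gib:.2f} GiB of raw observations per train batch")  # ~1.29 GiB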

Please let me know if you are able to replicate.

Edit:
I’ve also tested with:
num_workers:0,
compress_observations:false,
batch_mode: default

What happens if we don’t use the PolicyServer / Client APIs (by just passing the env directly to the algorithm)?

Also, I am wondering if this would solve the issue?

I actually have this implemented on my local machine, so unfortunately that isn't it. However, I'm not sure what you mean by passing it directly - could you please provide an example or just the code tweak necessary?

@kourosh
I think you mean something like this?

import ray
from ray.rllib.env import PolicyServerInput
from ray.rllib.algorithms.ppo import PPOConfig
import numpy as np
import argparse
from gymnasium.spaces import MultiDiscrete, Box

ray.init(object_store_memory=40 * (10 ** 9), num_cpus=6, num_gpus=1, log_to_driver=False,
         # configure_logging=True,
         # logging_level=3
         )

ppo_config = PPOConfig()

parser = argparse.ArgumentParser(description='Optional app description')
parser.add_argument('-ip', type=str, help='IP of this device')

parser.add_argument('-checkpoint', type=str, help='location of checkpoint to restore from')

args = parser.parse_args()


def _input(ioctx):
    # We are remote worker, or we are local worker with num_workers=0:
    # Create a PolicyServerInput.
    if ioctx.worker_index > 0 or ioctx.worker.num_workers == 0:
        return PolicyServerInput(
            ioctx,
            args.ip,
            55556 + ioctx.worker_index - (1 if ioctx.worker_index > 0 else 0),
        )
    # No InputReader (PolicyServerInput) needed.
    else:
        return None


x = 320
y = 240

proper_config = (
    PPOConfig()
    .environment(
        "ray.rllib.examples.env.random_env.RandomEnv",
        env_config={
            "action_space": MultiDiscrete(
                [
                    2,  # W
                    2,  # A
                    2,  # S
                    2,  # D
                    2,  # Space
                    2,  # H
                    2,  # J
                    2,  # K
                    2  # L
                ], ),
            "observation_space": Box(low=0, high=1, shape=(240, 320, 1), dtype=np.float32),
            "p_terminated": 1/350,
            "max_episode_len": 650,
            "check_action_bounds": True
        }
    )
    .training(
        clip_param=0.175,
        gamma=0.996,
        lambda_=0.99,
        kl_target=0.1,
        train_batch_size=4500,
        sgd_minibatch_size=512,
        num_sgd_iter=1,
        lr=9e-5,
        shuffle_sequences=False,
        model={
            "vf_share_layers": True,
            'use_attention': True,
            "max_seq_len": 50,
            "attention_num_transformer_units": 1,
            "attention_dim": 256,
            "attention_memory_inference": 50,
            "attention_memory_training": 50,
            "attention_num_heads": 8,
            "attention_head_dim": 32,
            "attention_position_wise_mlp_dim": 128,
            # "attention_use_n_prev_actions": 2,
            # "attention_use_n_prev_rewards": 64,
            "attention_init_gru_gate_bias": 2.0,

            "conv_filters": [
                # 240 X 320
                [32, [12, 16], [7, 9]],
                [128, [6, 6], 4],
                [256, [9, 9], 1]
            ],
            "conv_activation": "relu",
        }
    )
    .exploration(
        explore=True
    )
    .framework(
        framework='torch'
    )
    .resources(
        num_gpus=0.65,
        num_cpus_per_worker=1,
        num_cpus_per_learner_worker=4
    )
    .rollouts(
        rollout_fragment_length=64,
        num_rollout_workers=2,
        enable_connectors=False,
        batch_mode='complete_episodes',
        compress_observations=True
    )
    .rl_module(_enable_rl_module_api=False)
)

from ray import tune

name = "" + args.checkpoint
print(f"Starting: {name}")

tune.run("PPO",
         resume='AUTO',
         config=proper_config.to_dict(),
         name=name,
         keep_checkpoints_num=20, checkpoint_score_attr="episode_reward_mean", mode='max',
         checkpoint_freq=1,
         metric="episode_reward_mean",
         max_failures=10,
         checkpoint_at_end=True)

Except it’s not working for me and instead throwing:

ray::PPO.__init__() (pid=27964, ip=127.0.0.1, repr=PPO)
  File "python\ray\_raylet.pyx", line 868, in ray._raylet.execute_task
  File "python\ray\_raylet.pyx", line 919, in ray._raylet.execute_task
  File "python\ray\_raylet.pyx", line 875, in ray._raylet.execute_task
  File "python\ray\_raylet.pyx", line 879, in ray._raylet.execute_task
  File "python\ray\_raylet.pyx", line 819, in ray._raylet.execute_task.function_executor
  File "C:\personal\ai\ray_venv\lib\site-packages\ray\_private\function_manager.py", line 674, in actor_method_executor
    return method(__ray_actor, *args, **kwargs)
  File "C:\personal\ai\ray_venv\lib\site-packages\ray\util\tracing\tracing_helper.py", line 460, in _resume_span
    return method(self, *_args, **_kwargs)
  File "C:\personal\ai\ray_venv\lib\site-packages\ray\rllib\algorithms\algorithm.py", line 466, in __init__
    super().__init__(
  File "C:\personal\ai\ray_venv\lib\site-packages\ray\tune\trainable\trainable.py", line 169, in __init__
    self.setup(copy.deepcopy(self.config))
  File "C:\personal\ai\ray_venv\lib\site-packages\ray\util\tracing\tracing_helper.py", line 460, in _resume_span
    return method(self, *_args, **_kwargs)
  File "C:\personal\ai\ray_venv\lib\site-packages\ray\rllib\algorithms\algorithm.py", line 592, in setup
    self.workers = WorkerSet(
  File "C:\personal\ai\ray_venv\lib\site-packages\ray\rllib\evaluation\worker_set.py", line 194, in __init__
    raise e.args[0].args[2]
TypeError: RandomEnv.__init__() got an unexpected keyword argument 'observation_space'


Process finished with exit code -1

Hi @Denys_Ashikhin and @kourosh,

Here are two posts that might be relevant here.

Hi @mannyv ,

Long time no see! I took a look at both of the posts. As for checking the queue size, I already have that PR manually applied for a different memory leak in the policy server, so here is the relevant code that checks the queue size:

    @override(InputReader)
    def next(self):
        # Blocking wait until there is something in the deque.
        while len(self.samples_queue) == 0:
            time.sleep(0.1)

        print(f"Size of samples queue is: {len(self.samples_queue)} " )
        # Utilize last items first in order to remain as closely as possible
        # to operating on-policy.
        return self.samples_queue.pop()
        # return self.samples_queue.get()

However, it always prints 1, so I don't think the server is unable to keep up. Moreover, if you look at the video demonstrating the issue, the GPU RAM spikes during training (batches are loaded onto the GPU, which makes sense) but it never decreases. I am 99% sure it's something with GC after each epoch not cleaning up, or something holding onto a reference somewhere.
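
One thing I might try to test that theory (just an idea, nothing I have verified yet) is forcing a collection between training iterations, inside the process that owns the GPU tensors, and seeing whether the reserved memory actually drops:

import gc
import torch

def force_release_and_report() -> None:
    # Drop unreachable Python objects that may still hold tensors...
    gc.collect()
    # ...and return PyTorch's cached-but-unused CUDA blocks to the driver.
    torch.cuda.empty_cache()
    allocated = torch.cuda.memory_allocated() / 1024 ** 3
    reserved = torch.cuda.memory_reserved() / 1024 ** 3
    print(f"allocated={allocated:.2f} GiB, reserved={reserved:.2f} GiB")

# If 'allocated' keeps climbing even after this, something is still holding
# live references to tensors - which would match the leak hypothesis.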

As for the GitHub issue, the person mentioned that tf worked when tf2 didn't. I tried both tf and tf2 and got the following error:

(PPO pid=18168) AssertionError: Expect all shape elements to be an integer, actual type: (<class 'tensorflow.python.framework.tensor_shape.Dimension'>,)
(RolloutWorker pid=12956) 2023-05-24 23:48:43,716	ERROR worker.py:825 -- Exception raised in creation task: The actor died because of an error raised in its creation task, ray::RolloutWorker.__init__() (pid=12956, ip=127.0.0.1, repr=<ray.rllib.evaluation.rollout_worker.RolloutWorker object at 0x000001EC1CFAF2B0>)
(RolloutWorker pid=12956)   File "python\ray\_raylet.pyx", line 875, in ray._raylet.execute_task
(RolloutWorker pid=12956)   File "python\ray\_raylet.pyx", line 879, in ray._raylet.execute_task
(RolloutWorker pid=12956)   File "python\ray\_raylet.pyx", line 819, in ray._raylet.execute_task.function_executor
(RolloutWorker pid=12956)   File "C:\personal\ai\ray_venv\lib\site-packages\ray\_private\function_manager.py", line 674, in actor_method_executor
(RolloutWorker pid=12956)     return method(__ray_actor, *args, **kwargs)
(RolloutWorker pid=12956)   File "C:\personal\ai\ray_venv\lib\site-packages\ray\util\tracing\tracing_helper.py", line 460, in _resume_span
(RolloutWorker pid=12956)     return method(self, *_args, **_kwargs)
(RolloutWorker pid=12956)   File "C:\personal\ai\ray_venv\lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 738, in __init__
(RolloutWorker pid=12956)     self._update_policy_map(policy_dict=self.policy_dict)
(RolloutWorker pid=12956)   File "C:\personal\ai\ray_venv\lib\site-packages\ray\util\tracing\tracing_helper.py", line 460, in _resume_span
(RolloutWorker pid=12956)     return method(self, *_args, **_kwargs)
(RolloutWorker pid=12956)   File "C:\personal\ai\ray_venv\lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 1985, in _update_policy_map
(RolloutWorker pid=12956)     self._build_policy_map(
(RolloutWorker pid=12956)   File "C:\personal\ai\ray_venv\lib\site-packages\ray\util\tracing\tracing_helper.py", line 460, in _resume_span
(RolloutWorker pid=12956)     return method(self, *_args, **_kwargs)
(RolloutWorker pid=12956)   File "C:\personal\ai\ray_venv\lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 2097, in _build_policy_map
(RolloutWorker pid=12956)     new_policy = create_policy_for_framework(
(RolloutWorker pid=12956)   File "C:\personal\ai\ray_venv\lib\site-packages\ray\rllib\utils\policy.py", line 139, in create_policy_for_framework
(RolloutWorker pid=12956)     return policy_class(observation_space, action_space, merged_config)
(RolloutWorker pid=12956)   File "C:\personal\ai\ray_venv\lib\site-packages\ray\rllib\algorithms\ppo\ppo_tf_policy.py", line 81, in __init__
(RolloutWorker pid=12956)     base.__init__(
(RolloutWorker pid=12956)   File "C:\personal\ai\ray_venv\lib\site-packages\ray\rllib\policy\eager_tf_policy_v2.py", line 116, in __init__
(RolloutWorker pid=12956)     self.model = self.make_model()
(RolloutWorker pid=12956)   File "C:\personal\ai\ray_venv\lib\site-packages\ray\rllib\policy\eager_tf_policy_v2.py", line 235, in make_model
(RolloutWorker pid=12956)     return ModelCatalog.get_model_v2(
(RolloutWorker pid=12956)   File "C:\personal\ai\ray_venv\lib\site-packages\ray\rllib\models\catalog.py", line 685, in get_model_v2
(RolloutWorker pid=12956)     return wrapper(
(RolloutWorker pid=12956)   File "C:\personal\ai\ray_venv\lib\site-packages\ray\rllib\models\tf\attention_net.py", line 416, in __init__
(RolloutWorker pid=12956)     in_space = gym.spaces.Box(
(RolloutWorker pid=12956)   File "C:\personal\ai\ray_venv\lib\site-packages\gymnasium\spaces\box.py", line 90, in __init__
(RolloutWorker pid=12956)     assert all(
(RolloutWorker pid=12956) AssertionError: Expect all shape elements to be an integer, actual type: (<class 'tensorflow.python.framework.tensor_shape.Dimension'>,)
(PPO pid=18168) 
(RolloutWorker pid=12956) 
(RolloutWorker pid=22800) 

Process finished with exit code 1

So again, I'm not sure what's going on there.
I still suspect it's something to do with the training loop not releasing resources - any chance you can replicate it and take a look? Or at least give me an idea of where in the training loop to investigate?

@mannyv @kourosh
I want to rule out that it's anything not related to Tune or the underlying API - so could you point me in the right direction to set up this random environment to train without the policy client/server?

I have a sneaking suspicion there might be something going on deeper down due to the large observation sizes.

Yes, you should pass in the env directly. Something like this:

from ray.tune import registry
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.examples.env.random_env import RandomEnv
from gymnasium.spaces import MultiDiscrete, Box
import numpy as np

registry.register_env("random_env", RandomEnv)

config = (
    PPOConfig()
    .environment(
        "random_env",
        env_config={
            "action_space": MultiDiscrete(
                [
                    2,  # W
                    2,  # A
                    2,  # S
                    2,  # D
                    2,  # Space
                    2,  # H
                    2,  # J
                    2,  # K
                    2  # L
                ], ),
            "observation_space": Box(low=0, high=1, shape=(240, 320, 1), dtype=np.float32),
            "p_terminated": 1/350,
            "max_episode_len": 650,
            "check_action_bounds": True
        }
    )
    .training(
        clip_param=0.175,
        gamma=0.996,
        lambda_=0.99,
        kl_target=0.1,
        train_batch_size=4500,
        sgd_minibatch_size=512,
        num_sgd_iter=1,
        lr=9e-5,
        shuffle_sequences=False,
        model={
            "vf_share_layers": True,
            'use_attention': True,
            "max_seq_len": 50,
            "attention_num_transformer_units": 1,
            "attention_dim": 256,
            "attention_memory_inference": 50,
            "attention_memory_training": 50,
            "attention_num_heads": 8,
            "attention_head_dim": 32,
            "attention_position_wise_mlp_dim": 128,
            # "attention_use_n_prev_actions": 2,
            # "attention_use_n_prev_rewards": 64,
            "attention_init_gru_gate_bias": 2.0,

            "conv_filters": [
                # 240 X 320
                [32, [12, 16], [7, 9]],
                [128, [6, 6], 4],
                [256, [9, 9], 1]
            ],
            "conv_activation": "relu",
        }
    )
    .exploration(
        explore=True
    )
    .framework(
        framework='torch'
    )
    .resources(
        num_gpus=0.65,
        num_cpus_per_worker=1,
        num_cpus_per_learner_worker=4
    )
    .rollouts(
        rollout_fragment_length=64,
        num_rollout_workers=2,
        enable_connectors=False,
        batch_mode='complete_episodes',
        compress_observations=True
    )
)

@kourosh
Thanks for that, I tried it as follows:

from ray import tune
from ray.tune import registry

from ray.rllib.examples.env.random_env import RandomEnv
registry.register_env("random_env", RandomEnv)


proper_config = (
    PPOConfig()
    .environment(
        "random_env",
        env_config={
            "action_space": MultiDiscrete(
                [
                    2,  # W
                    2,  # A
                    2,  # S
                    2,  # D
                    2,  # Space
                    2,  # H
                    2,  # J
                    2,  # K
                    2  # L
                ], ),
            "observation_space": Box(low=0, high=1, shape=(240, 320, 1), dtype=np.float32),
            "p_terminated": 1 / 350,
            "max_episode_len": 650,
            "check_action_bounds": True
        }
    )
    .training(
        clip_param=0.175,
        gamma=0.996,
        lambda_=0.99,
        kl_target=0.1,
        train_batch_size=4500,
        sgd_minibatch_size=512,
        num_sgd_iter=1,
        lr=9e-5,
        shuffle_sequences=False,
        model={
            "vf_share_layers": True,
            'use_attention': True,
            "max_seq_len": 50,
            "attention_num_transformer_units": 1,
            "attention_dim": 256,
            "attention_memory_inference": 50,
            "attention_memory_training": 50,
            "attention_num_heads": 8,
            "attention_head_dim": 32,
            "attention_position_wise_mlp_dim": 128,
            # "attention_use_n_prev_actions": 2,
            # "attention_use_n_prev_rewards": 64,
            "attention_init_gru_gate_bias": 2.0,

            "conv_filters": [
                # 240 X 320
                [32, [12, 16], [7, 9]],
                [128, [6, 6], 4],
                [256, [9, 9], 1]
            ],
            "conv_activation": "relu",
        }
    )
    .exploration(
        explore=True
    )
    .framework(
        framework='torch'
    )
    .resources(
        num_gpus=0.65,
        num_cpus_per_worker=1,
        num_cpus_per_learner_worker=4
    )
    .rollouts(
        rollout_fragment_length=64,
        num_rollout_workers=2,
        enable_connectors=False,
        batch_mode='complete_episodes',
        compress_observations=True
    )
    # .rl_module(_enable_rl_module_api=False)
)


tune.run("PPO",
         resume='AUTO',
         config=proper_config.to_dict(),
         name=name,
         keep_checkpoints_num=20, checkpoint_score_attr="episode_reward_mean", mode='max',
         checkpoint_freq=1,
         metric="episode_reward_mean",
         max_failures=10,
         checkpoint_at_end=True)

And I can confirm the same leak occurs - so the good news is that it's not policy client/server related; the bad news is that it's deeper in.

Can you confirm whether the same is happening on your end?
@mannyv I would greatly appreciate it if you could also try the above script and let me know if you see the same GPU RAM spikes after every (or every other) epoch.

OK, so I would recommend trimming down the configs even further to understand the cause of the leak. There are two possibilities here:

  1. There is a fundamental bug within RLlib that has not been discovered so far.
  2. It could be observation-space related, which could either be incorrect usage of something or, again, a bug for a certain class of spaces.

If it's a bug, you can share the trimmed-down repro code in a GitHub issue and we will triage it and respond as soon as possible. There is a chance that in the process you may find the solution yourself, and in that case your contribution is of course welcome.

We recently also shared a high-level debugging guide. I hope that helps in narrowing down the issue.

Also, you should do this without Tune.

To do that, use something like the following:

from pprint import pprint

max_iter = 100  # number of training iterations (example value)
freq = 10       # how often to print results (example value)

algo = config.build()

for n_iter in range(max_iter):
    results = algo.train()
    if n_iter % freq == 0:
        pprint(results)

@kourosh
I tried it without Tune as you showed and can confirm the same issue is still happening.
I will try to trim down the settings to see if I can find the root cause of this, and I will report back ASAP if I find anything.

@kourosh
I was able to narrow it down to a specific observation-space case.
I tried all the filters from: ray/utils.py at e92e554db729be888325f7bb8c3caada182a1159 · ray-project/ray · GitHub

and all of them worked perfectly fine, except for the VizDoom variants. In particular, those use CNN filters that don't have the same X and Y dimensions.
In the case of 240 x 320, it would lead to those GPU memory spikes.
In the case of 320 x 640, it took a long time to complete an epoch and jumped my GPU memory usage from 4 GB to 19 GB after one epoch (using an RTX 4090 + 96 GB of RAM). Additionally, it capped out my RAM and thrashed my OS SSD (pagefile) the entire time until finally finishing.

Should I close this issue and open a new one that specifically targets these findings?

Yes. I think the issue is now isolated enough to open a new thread. I honestly don’t know what the cause is at this point. But I would set some breakpoints to see what is causing the residual accumulation of GPU memory.

I'll close this topic and open a different one on the forums if the GitHub issue ([RLlib] PPO Memory Leak on Uneven CNN (conv) filters · Issue #35866 · ray-project/ray · GitHub) gets stale.

One final point: setting simple_optimizer: True seems to prevent this issue (though it makes training epochs much slower, roughly 16x).