KeyError: 'advantages' when training PPO with custom model in RLlib

  • High: It blocks me from completing my task.

I’m encountering an error when training a PPO agent using RLlib with a custom model. The error occurs during the training step and is related to a missing 'advantages' key. I’ve enabled Generalized Advantage Estimation (GAE) in the configuration, but my custom model does not seem to output the value function predictions, which could be the root cause.
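Roughly, enabling GAE in a PPOConfig looks like the sketch below (illustrative only, not my exact script; the env name and the number of env runners are taken from the traceback that follows):

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("my_ShuttleGridEnv")   # env name as it appears in the traceback
    .env_runners(num_env_runners=4)     # matches "env-runners=4" in the traceback
    .training(
        use_gae=True,   # Generalized Advantage Estimation
        lambda_=0.95,   # GAE lambda (illustrative value)
    )
    # The custom RLModule shown further below is plugged in via .rl_module(...).
)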

Here’s the full traceback of the error:

Failure # 1 (occurred at 2025-02-21_22-06-18)
ray::PPO.train() (pid=14196, ip=127.0.0.1, actor_id=b476ab0c55633f76048e9fd401000000, repr=PPO(env=my_ShuttleGridEnv; env-runners=4; learners=1; multi-agent=False))
  File "python\ray\_raylet.pyx", line 1883, in ray._raylet.execute_task
  File "python\ray\_raylet.pyx", line 1824, in ray._raylet.execute_task.function_executor
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\_private\function_manager.py", line 696, in actor_method_executor
    return method(__ray_actor, *args, **kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py", line 463, in _resume_span
    return method(self, *_args, **_kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\tune\trainable\trainable.py", line 331, in train
    raise skipped from exception_cause(skipped)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\tune\trainable\trainable.py", line 328, in train
    result = self.step()
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py", line 463, in _resume_span
    return method(self, *_args, **_kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\algorithms\algorithm.py", line 999, in step
    train_results, train_iter_ctx = self._run_one_training_iteration()
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py", line 463, in _resume_span
    return method(self, *_args, **_kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\algorithms\algorithm.py", line 3350, in _run_one_training_iteration
    training_step_return_value = self.training_step()
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py", line 463, in _resume_span
    return method(self, *_args, **_kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\algorithms\ppo\ppo.py", line 428, in training_step
    learner_results = self.learner_group.update_from_episodes(
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\learner_group.py", line 327, in update_from_episodes
    return self._update(
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\learner_group.py", line 601, in _update
    results = self._get_results(
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\learner_group.py", line 615, in _get_results
    raise result_or_error
ray.exceptions.RayTaskError(KeyError): ray::_WrappedExecutable.apply() (pid=30612, ip=127.0.0.1, actor_id=72df0ba25bb7dde3f627f47301000000, repr=<ray.train._internal.worker_group._WrappedExecutable object at 0x000002C647C6A580>)
  File "python\ray\_raylet.pyx", line 1883, in ray._raylet.execute_task
  File "python\ray\_raylet.pyx", line 1824, in ray._raylet.execute_task.function_executor
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\_private\function_manager.py", line 696, in actor_method_executor
    return method(__ray_actor, *args, **kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py", line 463, in _resume_span
    return method(self, *_args, **_kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\learner.py", line 1641, in apply
    return func(self, *_args, **_kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\learner_group.py", line 385, in _learner_update
    result = _learner.update_from_episodes(
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py", line 463, in _resume_span
    return method(self, *_args, **_kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\learner.py", line 1086, in update_from_episodes
    self._update_from_batch_or_episodes(
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py", line 463, in _resume_span
    return method(self, *_args, **_kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\learner.py", line 1423, in _update_from_batch_or_episodes
    fwd_out, loss_per_module, tensor_metrics = self._update(
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py", line 463, in _resume_span
    return method(self, *_args, **_kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\torch\torch_learner.py", line 497, in _update
    return self._possibly_compiled_update(batch)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py", line 463, in _resume_span
    return method(self, *_args, **_kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\torch\torch_learner.py", line 152, in _uncompiled_update
    loss_per_module = self.compute_losses(fwd_out=fwd_out, batch=batch)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py", line 463, in _resume_span
    return method(self, *_args, **_kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\learner.py", line 924, in compute_losses
    loss = self.compute_loss_for_module(
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py", line 463, in _resume_span
    return method(self, *_args, **_kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\algorithms\ppo\torch\ppo_torch_learner.py", line 89, in compute_loss_for_module
    batch[Postprocessing.ADVANTAGES] * logp_ratio,
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\policy\sample_batch.py", line 973, in __getitem__
    value = dict.__getitem__(self, key)
KeyError: 'advantages'

Could anyone provide guidance on how to fix this or what changes are needed to make sure the 'advantages' key is present?
My code is shown below:
import sys
from pathlib import Path

# Add the project root directory to sys.path
sys.path.append(str(Path(__file__).parent.parent))

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule
from model.dummy_catalog import DummyCatalog
from ray.rllib.core import Columns
from ray.rllib.models.torch.torch_distributions import TorchCategorical

#ray.rllib.core.catalog.Catalog

# A simple convolutional module that extracts local features.

class CNNModule(nn.Module):
    def __init__(self, input_channels=1, output_dim=128):
        super(CNNModule, self).__init__()
        # First conv layer: 32 output channels, 3x3 kernel, spatial size preserved.
        self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        # Second conv layer: 64 output channels.
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        # Downsampling (max pooling) to halve the spatial size.
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Third conv layer: 128 output channels.
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        # Second downsampling layer.
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Dropout to reduce overfitting.
        self.dropout = nn.Dropout(p=0.5)
        # Adaptive average pooling shrinks the spatial size to 1x1.
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        # Fully connected layer: 128 input features, output_dim output features.
        self.fc = nn.Linear(128, output_dim)

    def forward(self, x):
        # x: [BATCH, input_channels, H, W]
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool1(x)  # halve the spatial size
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.pool2(x)  # downsample again
        x = self.global_pool(x)  # output [BATCH, 128, 1, 1]
        x = x.view(x.size(0), -1)  # flatten to [BATCH, 128]
        x = self.dropout(x)
        x = F.relu(self.fc(x))  # final output [BATCH, output_dim]
        return x

# RLModule implemented with the new API stack.

class ParametricActionsRLModule(TorchRLModule):
    def __init__(
        self,
        *,
        observation_space,  # the env's observation space
        action_space,       # the env's action space
        model_config,
        inference_only=False,
        learner_only=False,
        catalog_class=None,
    ):
        if catalog_class is None:
            catalog_class = DummyCatalog

        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
            model_config=model_config,
            inference_only=inference_only,
            learner_only=learner_only,
            catalog_class=catalog_class,
        )
        self.observation_space = observation_space
        self.action_space = action_space
        self.model_config = model_config

    def setup(self):
        """Build all layers and sub-components of the model here."""
        self.xy_max = self.model_config.get("xy_max", 20)
        self.action_embed_size = self.model_config.get("action_embed_size", 3)

        # Derive the input dimension from the observation space
        # (assumed to be a single-channel image).
        input_channels = self.observation_space["real_obs1"].shape[-1]

        # Hidden-layer size from model_config (default 256).
        hidden_dim = self.model_config.get("fcnet_hiddens", [256])[0]

        # Derive the output dimension from the action space
        # (for a discrete action space this is the number of actions).
        output_dim = self.action_space.n

        # Two CNN modules that process the local observations from the env.
        self.cnn_obs1 = CNNModule(input_channels=input_channels, output_dim=128)
        self.cnn_obs2 = CNNModule(input_channels=input_channels, output_dim=128)

        # Concatenate all features and map them into the action-embedding space.
        self.combined_fc = nn.Linear(128 + 128 + 2 * self.xy_max, self.action_embed_size)

        # Policy head on top of the action embedding.
        self._pi_head = nn.Sequential(
            nn.Linear(self.action_embed_size, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

        # Value-function head on top of the action embedding.
        self._vf_head = nn.Sequential(
            nn.Linear(self.action_embed_size, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

        self.action_dist_cls = TorchCategorical

    def _forward(self, batch, **kwargs):
        """Run the forward pass and compute per-action logits."""
        obs = batch["obs"]

        # Convert with PyTorch directly to keep dtype and device consistent.
        obs1 = torch.as_tensor(obs["real_obs1"], dtype=torch.float32)
        obs2 = torch.as_tensor(obs["real_obs2"], dtype=torch.float32)
        xy = torch.as_tensor(obs["xy"], dtype=torch.float32)
        action_mask = torch.as_tensor(obs["action_mask"], dtype=torch.float32)

        # HWC -> CHW if needed.
        if obs1.ndim == 4 and obs1.shape[-1] in [1, 3]:
            obs1 = obs1.permute(0, 3, 1, 2)
        if obs2.ndim == 4 and obs2.shape[-1] in [1, 3]:
            obs2 = obs2.permute(0, 3, 1, 2)

        # Extract local-map features.
        feat1 = self.cnn_obs1(obs1)
        feat2 = self.cnn_obs2(obs2)

        # One-hot encode the target coordinates xy.
        x = xy[:, 0].long()
        y = xy[:, 1].long()
        x_onehot = F.one_hot(x, num_classes=self.xy_max)
        y_onehot = F.one_hot(y, num_classes=self.xy_max)
        xy_onehot = torch.cat([x_onehot, y_onehot], dim=1).float()

        # Concatenate the local-map features with the encoded target coordinates.
        combined = torch.cat([feat1, feat2, xy_onehot], dim=1)
        intent_vector = F.relu(self.combined_fc(combined))

        # Apply the action mask to the policy logits.
        # assert action_mask.min() >= 0 and action_mask.max() <= 1, "mask values must be in [0, 1]"
        action_logits = self._pi_head(intent_vector)
        inf_mask = torch.log(action_mask + 1e-10)  # avoid log(0)
        final_logits = action_logits + inf_mask

        # Build the categorical action distribution.
        action_dist = self.action_dist_cls(logits=final_logits)
        # Pick the highest-probability action (deterministic).
        action = torch.argmax(action_dist.logits, dim=-1)
        action_logp = action_dist.logp(action)

        # Value-function prediction, using the same features as the policy head.
        vf_preds = self._vf_head(intent_vector).squeeze(-1)  # shape [BATCH]
        # Also tried returning a dummy tensor (zeros) to prevent the KeyError:
        # advantages = torch.zeros_like(vf_preds)  # zero tensor with the same shape as vf_preds
        # Keep the latest value predictions around for other methods.
        self._last_vf_preds = vf_preds

        # Sanity checks.
        assert isinstance(vf_preds, torch.Tensor), "value predictions must be a Tensor"
        assert vf_preds.ndim == 1, f"value predictions should have shape (BATCH,), got {vf_preds.shape}"

        output = {
            Columns.ACTIONS: action,  # chosen action indices
            Columns.ACTION_DIST_INPUTS: final_logits,  # logits as action_dist_inputs
            Columns.VF_PREDS: vf_preds,
            Columns.ACTION_LOGP: action_logp,  # log-probability of the chosen action
            # Columns.ADVANTAGES: advantages,  # also tried returning advantages here
        }
        return output

    def value_function(self):
        """Required method for PPO to access value function predictions."""
        return self._last_vf_preds

    def get_initial_state(self) -> dict:
        return {}

Hello, I have the same problem. Did you solve this issue? I tried to deal with it by defining a custom post-processing function, but it might be a bit messy.

    return {
        Columns.ACTION_LOGP: dist_a1.logp(action),
        Columns.ACTION_DIST_INPUTS: prior_out,
        Columns.ACTIONS: action,
        Columns.VF_PREDS: self._value_net(obs).squeeze(-1),
        Columns.ADVANTAGES: self._value_net(obs).squeeze(-1),
        Columns.VALUE_TARGETS: self._value_net(obs).squeeze(-1),
    }

You should return more keys to make it correct.
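For context on why copying the value predictions into Columns.ADVANTAGES is not a real fix: under GAE, advantages are computed over whole trajectories from rewards and bootstrapped value predictions, which is why RLlib derives them on the learner side rather than in the module's forward pass. A small illustrative sketch of that computation (not RLlib's internal code), assuming a single finished episode:

import numpy as np

def gae_advantages(rewards, vf_preds, gamma=0.99, lam=0.95, bootstrap_value=0.0):
    # vf_preds holds V(s_t) for t = 0..T-1; bootstrap_value stands in for V(s_T).
    values = np.append(np.asarray(vf_preds, dtype=np.float32), bootstrap_value)
    advantages = np.zeros(len(rewards), dtype=np.float32)
    gae = 0.0
    # Walk the episode backwards, accumulating discounted TD errors.
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        gae = delta + gamma * lam * gae
        advantages[t] = gae
    return advantages

# Each step's advantage depends on future rewards and values, so it cannot be
# produced by a single forward pass over an isolated batch.
print(gae_advantages(rewards=[1.0, 0.0, 1.0], vf_preds=[0.5, 0.4, 0.3]))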

@Amir_Tahmasbi @ZanhaPeng For PPO, you should follow the instructions for implementing custom modules and make sure you have followed them correctly. Having

    @override(ValueFunctionAPI)
    def compute_values(...)

on your RLModule is one of these requirements. There is an example for LSTMs in the RLlib repo.
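A minimal sketch of what that requirement could look like for the module above (assuming a recent Ray version where ValueFunctionAPI lives in ray.rllib.core.rl_module.apis and compute_values takes an optional embeddings argument; _compute_intent_vector is a hypothetical helper wrapping the shared feature extraction already done in _forward):

from ray.rllib.core.rl_module.apis import ValueFunctionAPI
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule
from ray.rllib.utils.annotations import override


class ParametricActionsRLModule(TorchRLModule, ValueFunctionAPI):
    # ... __init__, setup(), and _forward() as in the original post ...

    @override(ValueFunctionAPI)
    def compute_values(self, batch, embeddings=None):
        # PPO's GAE connector calls this on the learner side to get V(s_t);
        # advantages and value targets are then derived from these values.
        if embeddings is None:
            # Hypothetical helper: recompute the shared features (CNN outputs +
            # one-hot xy -> combined_fc) exactly as _forward does.
            embeddings = self._compute_intent_vector(batch)
        # Return a tensor of shape [B] (no trailing singleton dimension).
        return self._vf_head(embeddings).squeeze(-1)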
I would also appreciate it if someone could explain why, in the LSTM example, compute_values needs to handle both the case where embeddings is None and the case where it is not. What is the flow of method calls? This seems completely different from the old stack.

Personally, I don’t understand why we need _forward_train. Why isn’t _forward sufficient?