Encountering dimensional issues when porting LSTM to CNN+PPO

1. Severity of the issue: (select one)

High: Completely blocks me.

2. Environment:

  • Ray version: 2.40.0
  • Python version: 3.9
  • OS: ubuntu
  • Cloud/Infrastructure:
  • Other libs/tools (if relevant):

3. What happened vs. what you expected:

  • Expected:
  • Actual:

When I encountered dimension issues while porting LSTM to CNN+PPO, the code and error are as follows:
from typing import Any, Dict, Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import TensorType

torch, nn = try_import_torch()

定义 CNNModule 用于特征提取

class CNNModule(nn.Module):
def init(self, input_channels=1, output_dim=128):
super(CNNModule, self).init()
self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, stride=1, padding=1)
self.bn1 = nn.BatchNorm2d(32)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2d(64)
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
self.bn3 = nn.BatchNorm2d(128)
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.dropout = nn.Dropout(p=0.5)
self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(128, output_dim)

def forward(self, x):
    x = F.relu(self.bn1(self.conv1(x)))
    x = F.relu(self.bn2(self.conv2(x)))
    x = self.pool1(x)
    x = F.relu(self.bn3(self.conv3(x)))
    x = self.pool2(x)
    x = self.global_pool(x)
    x = x.view(x.size(0), -1)
    x = self.dropout(x)
    x = F.relu(self.fc(x))
    return x

修改后的 ParametricActionsRLModule

class ParametricActionsRLModule(TorchRLModule, ValueFunctionAPI):
def init(
self,
*,
observation_space,
action_space,
model_config,
inference_only=False,
learner_only=False,
catalog_class=None
):
super().init(
observation_space=observation_space,
action_space=action_space,
model_config=model_config,
inference_only=inference_only,
learner_only=learner_only,
catalog_class=catalog_class,
)
self.observation_space = observation_space
self.action_space = action_space
self.model_config = model_config
self.catalog_class = catalog_class
self.setup()

@override(TorchRLModule)
def setup(self):
    """初始化网络结构"""
    # 配置参数
    self.xy_max = self.model_config.get("xy_max", 20)
    input_channels = self.observation_space["cargo_obs"].shape[-1]  # 假设 [H, W, C]
    cnn_output_dim = self.model_config.get("cnn_output_dim", 128)
    lstm_hidden_size = self.model_config.get("lstm_hidden_size", 256)

    # CNN 模块
    self.cnn = CNNModule(input_channels=input_channels, output_dim=cnn_output_dim)

    # LSTM 模块
    self._lstm = nn.LSTM(cnn_output_dim, lstm_hidden_size, batch_first=True)

    # 组合维度:LSTM 输出 + xy 的 one-hot 编码
    combined_dim = lstm_hidden_size + 2 * self.xy_max

    # 动作 logits 和状态值层
    self._logits = nn.Linear(combined_dim, self.action_space.n)
    self._values = nn.Linear(combined_dim, 1)

@override(TorchRLModule)
def get_initial_state(self) -> Dict[str, np.ndarray]:
    """返回 LSTM 的初始状态"""
    return {
        "h": np.zeros(shape=(self._lstm.hidden_size,), dtype=np.float32),
        "c": np.zeros(shape=(self._lstm.hidden_size,), dtype=np.float32),
    }

def _process_sequence(self, obs_dict, state_in):
    """处理序列数据,返回 embeddings 和 state_out"""
    cargo_obs = obs_dict["cargo_obs"]  # [B, T, H, W, C]
    B, T, H, W, C = cargo_obs.shape
    cargo_obs = cargo_obs.view(B * T, H, W, C).permute(0, 3, 1, 2)  # [B*T, C, H, W]
    features = self.cnn(cargo_obs).view(B, T, -1)  # [B, T, cnn_output_dim]

    h_in = state_in["h"].unsqueeze(0)  # [1, B, hidden_size]
    c_in = state_in["c"].unsqueeze(0)  # [1, B, hidden_size]
    lstm_out, (h_out, c_out) = self._lstm(features, (h_in, c_in))  # [B, T, hidden_size]

    xy = obs_dict["xy"]  # [B, T, 2]
    onehot_x = F.one_hot(xy[:, :, 0].long(), num_classes=self.xy_max)  # [B, T, xy_max]
    onehot_y = F.one_hot(xy[:, :, 1].long(), num_classes=self.xy_max)  # [B, T, xy_max]
    onehot_xy = torch.cat([onehot_x, onehot_y], dim=2)  # [B, T, 2*xy_max]

    combined = torch.cat([lstm_out, onehot_xy], dim=2)  # [B, T, hidden_size + 2*xy_max]
    state_out = {"h": h_out.squeeze(0), "c": c_out.squeeze(0)}
    return combined, state_out

def _process_single_timestep(self, obs_dict, state_in):
    """处理单时间步数据,返回 embeddings 和 state_out"""
    cargo_obs = obs_dict["cargo_obs"]  # [B, H, W, C]
    if len(cargo_obs.shape) == 3:
        cargo_obs = cargo_obs.unsqueeze(0)
    B = cargo_obs.shape[0]
    cargo_obs = cargo_obs.permute(0, 3, 1, 2)  # [B, C, H, W]
    features = self.cnn(cargo_obs).unsqueeze(1)  # [B, 1, cnn_output_dim]

    h_in = state_in["h"].unsqueeze(0)  # [1, B, hidden_size]
    c_in = state_in["c"].unsqueeze(0)  # [1, B, hidden_size]
    lstm_out, (h_out, c_out) = self._lstm(features, (h_in, c_in))  # [B, 1, hidden_size]

    xy = obs_dict["xy"]  # [B, 2]
    if len(xy.shape) == 1:
        xy = xy.unsqueeze(0)
    onehot_x = F.one_hot(xy[:, 0].long(), num_classes=self.xy_max)  # [B, xy_max]
    onehot_y = F.one_hot(xy[:, 1].long(), num_classes=self.xy_max)  # [B, xy_max]
    onehot_xy = torch.cat([onehot_x, onehot_y], dim=1).unsqueeze(1)  # [B, 1, 2*xy_max]

    combined = torch.cat([lstm_out, onehot_xy], dim=2)  # [B, 1, hidden_size + 2*xy_max]
    state_out = {"h": h_out.squeeze(0), "c": c_out.squeeze(0)}
    return combined, state_out

@override(TorchRLModule)
def _forward(self, batch, **kwargs):
    """通用前向传播,处理序列数据"""
    obs_dict = batch[Columns.OBS]
    state_in = batch[Columns.STATE_IN]
    combined, state_out = self._process_sequence(obs_dict, state_in)
    
    logits = self._logits(combined)  # [B, T, num_actions]
    action_mask = obs_dict["action_mask"]  # [B, T, num_actions]
    logits = logits.masked_fill(action_mask == 0, -1e9)
    
    return {
        Columns.ACTION_DIST_INPUTS: logits,
        Columns.STATE_OUT: state_out,
    }

@override(TorchRLModule)
def _forward_train(self, batch, **kwargs):
    """训练时的前向传播,返回 embeddings 和 logits"""
    obs_dict = batch[Columns.OBS]
    state_in = batch[Columns.STATE_IN]
    combined, state_out = self._process_sequence(obs_dict, state_in)
    
    logits = self._logits(combined)  # [B, T, num_actions]
    action_mask = obs_dict["action_mask"]  # [B, T, num_actions]
    logits = logits.masked_fill(action_mask == 0, -1e9)
    
    return {
        Columns.ACTION_DIST_INPUTS: logits,
        Columns.EMBEDDINGS: combined,
        Columns.STATE_OUT: state_out,
    }

@override(TorchRLModule)
def _forward_inference(self, batch, **kwargs):
    """推理时的前向传播,处理单时间步"""
    obs_dict = batch[Columns.OBS]
    state_in = batch[Columns.STATE_IN]
    
    # 处理可能的 numpy 输入
    for key in ["cargo_obs", "xy", "action_mask"]:
        if isinstance(obs_dict[key], np.ndarray):
            obs_dict[key] = torch.tensor(obs_dict[key])
    
    combined, state_out = self._process_single_timestep(obs_dict, state_in)
    combined = combined.squeeze(1)  # [B, hidden_size + 2*xy_max]
    
    logits = self._logits(combined)  # [B, num_actions]
    action_mask = obs_dict["action_mask"]  # [B, num_actions]
    if len(action_mask.shape) == 1:
        action_mask = action_mask.unsqueeze(0)
    logits = logits.masked_fill(action_mask == 0, -1e9)
    
    actions = torch.argmax(logits, dim=-1)  # [B]
    if actions.shape[0] == 1:
        actions = actions.squeeze(0)
    
    return {
        Columns.ACTIONS: actions,
        Columns.STATE_OUT: state_out,
    }

@override(ValueFunctionAPI)
def compute_values(self, batch, embeddings: Optional[TensorType] = None) -> TensorType:
    """计算状态值,支持序列和单时间步"""
    if embeddings is None:
        obs_dict = batch[Columns.OBS]
        state_in = batch[Columns.STATE_IN]
        if len(obs_dict["cargo_obs"].shape) == 5:  # [B, T, H, W, C]
            embeddings, _ = self._process_sequence(obs_dict, state_in)
        else:  # [B, H, W, C]
            embeddings, _ = self._process_single_timestep(obs_dict, state_in)
            embeddings = embeddings.squeeze(1)  # [B, hidden_size + 2*xy_max]
    
    values = self._values(embeddings).squeeze(-1)  # [B, T] 或 [B]
    return values

error:Failure # 1 (occurred at 2025-04-14_14-40-54)
e[36mray::PPO.train()e[39m (pid=25180, ip=127.0.0.1, actor_id=739b80bead9d0dc2fc30c13101000000, repr=PPO(env=my_ShuttleGridEnv; env-runners=4; learners=1; multi-agent=False))
File “python\ray_raylet.pyx”, line 1883, in ray._raylet.execute_task
File “python\ray_raylet.pyx”, line 1824, in ray._raylet.execute_task.function_executor
File “D:\Anaconda\envs\singleray\lib\site-packages\ray_private\function_manager.py”, line 696, in actor_method_executor
return method(__ray_actor, *args, **kwargs)
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py”, line 463, in _resume_span
return method(self, *_args, **_kwargs)
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\tune\trainable\trainable.py”, line 331, in train
raise skipped from exception_cause(skipped)
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\tune\trainable\trainable.py”, line 328, in train
result = self.step()
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py”, line 463, in _resume_span
return method(self, *_args, **_kwargs)
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\algorithms\algorithm.py”, line 999, in step
train_results, train_iter_ctx = self._run_one_training_iteration()
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py”, line 463, in _resume_span
return method(self, *_args, **_kwargs)
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\algorithms\algorithm.py”, line 3350, in _run_one_training_iteration
training_step_return_value = self.training_step()
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py”, line 463, in _resume_span
return method(self, *_args, **_kwargs)
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\algorithms\ppo\ppo.py”, line 428, in training_step
learner_results = self.learner_group.update_from_episodes(
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\learner_group.py”, line 327, in update_from_episodes
return self._update(
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\learner_group.py”, line 601, in _update
results = self._get_results(
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\learner_group.py”, line 615, in _get_results
raise result_or_error
ray.exceptions.RayTaskError(ValueError): e[36mray::_WrappedExecutable.apply()e[39m (pid=20588, ip=127.0.0.1, actor_id=b9d647c42386e582359a42c201000000, repr=<ray.train._internal.worker_group._WrappedExecutable object at 0x00000235BFEA95B0>)
File “python\ray_raylet.pyx”, line 1883, in ray._raylet.execute_task
File “python\ray_raylet.pyx”, line 1824, in ray._raylet.execute_task.function_executor
File “D:\Anaconda\envs\singleray\lib\site-packages\ray_private\function_manager.py”, line 696, in actor_method_executor
return method(__ray_actor, *args, **kwargs)
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py”, line 463, in _resume_span
return method(self, *_args, **_kwargs)
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\learner.py”, line 1641, in apply
return func(self, *_args, **_kwargs)
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\learner_group.py”, line 385, in _learner_update
result = _learner.update_from_episodes(
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py”, line 463, in _resume_span
return method(self, *_args, **_kwargs)
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\learner.py”, line 1086, in update_from_episodes
self._update_from_batch_or_episodes(
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py”, line 463, in _resume_span
return method(self, *_args, **_kwargs)
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\learner.py”, line 1351, in update_from_batch_or_episodes
batch = self.learner_connector(
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\connectors\learner\learner_connector_pipeline.py”, line 38, in call
ret = super().call(
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\connectors\connector_pipeline_v2.py”, line 111, in call
batch = connector(
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\connectors\common\batch_individual_items.py”, line 180, in call
else batch_fn(
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\utils\spaces\space_utils.py”, line 373, in batch
ret = tree.map_structure(lambda *s: np_func(s, axis=0), *list_of_structs)
File "D:\Anaconda\envs\singleray\lib\site-packages\tree_init
.py", line 435, in map_structure
[func(*args) for args in zip(*map(flatten, structures))])
File "D:\Anaconda\envs\singleray\lib\site-packages\tree_init
.py", line 435, in
[func(*args) for args in zip(*map(flatten, structures))])
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\utils\spaces\space_utils.py”, line 373, in
ret = tree.map_structure(lambda *s: np_func(s, axis=0), *list_of_structs)
File “D:\Anaconda\envs\singleray\lib\site-packages\numpy\core\shape_base.py”, line 449, in stack
raise ValueError(‘all input arrays must have the same shape’)
ValueError: all input arrays must have the same shape
I sincerely hope to receive assistance, and I would be eternally grateful.