How to integrate LSTM into CNN+PPO

1. Severity of the issue: (select one)
High: Completely blocks me.

2. Environment:

  • Ray version: 2.40
  • Python version: 3.9
  • OS: Ubuntu
  • Cloud/Infrastructure:
  • Other libs/tools (if relevant):

I want to add an LSTM to my CNN+PPO model, but I keep running into dimension errors, for example:

ValueError: all input arrays must have the same shape (this one takes actor 0 out of service)

RuntimeError: Tensors must have same number of dimensions: got 2 and 3

Hi @ZanhaPengm, thanks for raising this issue. Could you provide a simple repro for us?


Thanks for your reply. I will attach part of my code along with the error messages for reference.

Code:
from typing import Any, Dict, Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import TensorType

torch, nn = try_import_torch()

# Define CNNModule for feature extraction

class CNNModule(nn.Module):
    def __init__(self, input_channels=1, output_dim=128):
        super(CNNModule, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout = nn.Dropout(p=0.5)
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(128, output_dim)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool1(x)
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.pool2(x)
        x = self.global_pool(x)
        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        x = F.relu(self.fc(x))
        return x

# Modified ParametricActionsRLModule

class ParametricActionsRLModule(TorchRLModule, ValueFunctionAPI):
    def __init__(
        self,
        *,
        observation_space,
        action_space,
        model_config,
        inference_only=False,
        learner_only=False,
        catalog_class=None,
    ):
        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
            model_config=model_config,
            inference_only=inference_only,
            learner_only=learner_only,
            catalog_class=catalog_class,
        )
        self.observation_space = observation_space
        self.action_space = action_space
        self.model_config = model_config
        self.catalog_class = catalog_class
        self.setup()

    @override(TorchRLModule)
    def setup(self):
        """Initialize network structure"""
        # Configuration parameters
        self.xy_max = self.model_config.get("xy_max", 20)
        input_channels = self.observation_space["cargo_obs"].shape[-1]  # Assuming [H, W, C]
        cnn_output_dim = self.model_config.get("cnn_output_dim", 128)
        lstm_hidden_size = self.model_config.get("lstm_hidden_size", 256)

        # CNN module
        self.cnn = CNNModule(input_channels=input_channels, output_dim=cnn_output_dim)

        # LSTM module
        self._lstm = nn.LSTM(cnn_output_dim, lstm_hidden_size, batch_first=True)

        # Combined dimension: LSTM output + one-hot encoded xy
        combined_dim = lstm_hidden_size + 2 * self.xy_max

        # Action logits and value layers
        self._logits = nn.Linear(combined_dim, self.action_space.n)
        self._values = nn.Linear(combined_dim, 1)

    @override(TorchRLModule)
    def get_initial_state(self) -> Dict[str, np.ndarray]:
        """Return initial LSTM state"""
        return {
            "h": np.zeros(shape=(self._lstm.hidden_size,), dtype=np.float32),
            "c": np.zeros(shape=(self._lstm.hidden_size,), dtype=np.float32),
        }

    def _process_sequence(self, obs_dict, state_in):
        """Process sequence data, return embeddings and state_out"""
        cargo_obs = obs_dict["cargo_obs"]  # [B, T, H, W, C]
        B, T, H, W, C = cargo_obs.shape
        cargo_obs = cargo_obs.view(B * T, H, W, C).permute(0, 3, 1, 2)  # [B*T, C, H, W]
        features = self.cnn(cargo_obs).view(B, T, -1)  # [B, T, cnn_output_dim]

        # Adjust hidden state to match current batch size
        h_in = state_in["h"]  # [B_state, hidden_size]
        c_in = state_in["c"]  # [B_state, hidden_size]
        if h_in.size(0) != B:
            h_in = h_in[:B, :] if h_in.size(0) > B else torch.cat([h_in, torch.zeros(B - h_in.size(0), h_in.size(1), device=h_in.device)], dim=0)
            c_in = c_in[:B, :] if c_in.size(0) > B else torch.cat([c_in, torch.zeros(B - c_in.size(0), c_in.size(1), device=c_in.device)], dim=0)

        h_in = h_in.unsqueeze(0)  # [1, B, hidden_size]
        c_in = c_in.unsqueeze(0)  # [1, B, hidden_size]
        lstm_out, (h_out, c_out) = self._lstm(features, (h_in, c_in))  # [B, T, hidden_size]

        xy = obs_dict["xy"]  # [B, T, 2]
        onehot_x = F.one_hot(xy[:, :, 0].long(), num_classes=self.xy_max)  # [B, T, xy_max]
        onehot_y = F.one_hot(xy[:, :, 1].long(), num_classes=self.xy_max)  # [B, T, xy_max]
        onehot_xy = torch.cat([onehot_x, onehot_y], dim=2)  # [B, T, 2*xy_max]

        combined = torch.cat([lstm_out, onehot_xy], dim=2)  # [B, T, hidden_size + 2*xy_max]
        state_out = {"h": h_out.squeeze(0), "c": c_out.squeeze(0)}
        return combined, state_out

    def _process_single_timestep(self, obs_dict, state_in):
        """Process single timestep data, return embeddings and state_out"""
        cargo_obs = obs_dict["cargo_obs"]  # [B, H, W, C]
        if len(cargo_obs.shape) == 3:
            cargo_obs = cargo_obs.unsqueeze(0)
        B = cargo_obs.shape[0]
        cargo_obs = cargo_obs.permute(0, 3, 1, 2)  # [B, C, H, W]
        features = self.cnn(cargo_obs).unsqueeze(1)  # [B, 1, cnn_output_dim]

        # Adjust hidden state to match current batch size
        h_in = state_in["h"]  # [B_state, hidden_size]
        c_in = state_in["c"]  # [B_state, hidden_size]
        if h_in.size(0) != B:
            h_in = h_in[:B, :] if h_in.size(0) > B else torch.cat([h_in, torch.zeros(B - h_in.size(0), h_in.size(1), device=h_in.device)], dim=0)
            c_in = c_in[:B, :] if c_in.size(0) > B else torch.cat([c_in, torch.zeros(B - c_in.size(0), c_in.size(1), device=c_in.device)], dim=0)

        h_in = h_in.unsqueeze(0)  # [1, B, hidden_size]
        c_in = c_in.unsqueeze(0)  # [1, B, hidden_size]
        lstm_out, (h_out, c_out) = self._lstm(features, (h_in, c_in))  # [B, 1, hidden_size]

        xy = obs_dict["xy"]  # [B, 2]
        if len(xy.shape) == 1:
            xy = xy.unsqueeze(0)
        onehot_x = F.one_hot(xy[:, 0].long(), num_classes=self.xy_max)  # [B, xy_max]
        onehot_y = F.one_hot(xy[:, 1].long(), num_classes=self.xy_max)  # [B, xy_max]
        onehot_xy = torch.cat([onehot_x, onehot_y], dim=1).unsqueeze(1)  # [B, 1, 2*xy_max]

        combined = torch.cat([lstm_out, onehot_xy], dim=2)  # [B, 1, hidden_size + 2*xy_max]
        state_out = {"h": h_out.squeeze(0), "c": c_out.squeeze(0)}
        return combined, state_out

    @override(TorchRLModule)
    def _forward(self, batch, **kwargs):
        """General forward pass, process sequence data"""
        obs_dict = batch[Columns.OBS]
        state_in = batch[Columns.STATE_IN]
        combined, state_out = self._process_sequence(obs_dict, state_in)

        logits = self._logits(combined)  # [B, T, num_actions]
        action_mask = obs_dict["action_mask"]  # [B, T, num_actions]
        logits = logits.masked_fill(action_mask == 0, -1e9)

        return {
            Columns.ACTION_DIST_INPUTS: logits,
            Columns.STATE_OUT: state_out,
        }

    @override(TorchRLModule)
    def _forward_train(self, batch, **kwargs):
        obs_dict = batch[Columns.OBS]
        state_in = batch[Columns.STATE_IN]
        combined, state_out = self._process_sequence(obs_dict, state_in)

        logits = self._logits(combined)  # [B, T, num_actions]
        action_mask = obs_dict["action_mask"]  # [B, T, num_actions] (must be 3-D)
        logits = logits.masked_fill(action_mask == 0, -1e9)

        # B, T, A = logits.shape  # A = self.num_actions
        # logits_flat = logits.reshape(B * T, A)  # [B*T, num_actions]

        # # Flatten actions (key fix)
        # actions = batch[Columns.ACTIONS]  # input shape should be [B, T]
        # actions_flat = actions.reshape(B * T)  # flatten to [B*T]

        # embeddings_flat = combined.reshape(B * T, combined.shape[2])  # [B*T, embed_dim]

        return {
            Columns.ACTION_DIST_INPUTS: logits,
            Columns.EMBEDDINGS: combined,
            Columns.STATE_OUT: state_out,
        }

    @override(TorchRLModule)
    def _forward_inference(self, batch, **kwargs):
        """Forward pass for inference, process single timestep"""
        obs_dict = batch[Columns.OBS]
        state_in = batch[Columns.STATE_IN]

        # Handle possible numpy inputs
        for key in ["cargo_obs", "xy", "action_mask"]:
            if isinstance(obs_dict[key], np.ndarray):
                obs_dict[key] = torch.tensor(obs_dict[key])

        combined, state_out = self._process_single_timestep(obs_dict, state_in)
        combined = combined.squeeze(1)  # [B, hidden_size + 2*xy_max]

        logits = self._logits(combined)  # [B, num_actions]
        action_mask = obs_dict["action_mask"]  # [B, num_actions]
        if len(action_mask.shape) == 1:
            action_mask = action_mask.unsqueeze(0)
        logits = logits.masked_fill(action_mask == 0, -1e9)

        actions = torch.argmax(logits, dim=-1)  # [B]
        if actions.shape[0] == 1:
            actions = actions.squeeze(0)

        return {
            Columns.ACTIONS: actions,
            Columns.STATE_OUT: state_out,
        }

    @override(ValueFunctionAPI)
    def compute_values(self, batch, embeddings: Optional[TensorType] = None) -> TensorType:
        """Compute state values, support sequences and single timesteps"""
        if embeddings is None:
            obs_dict = batch[Columns.OBS]
            state_in = batch[Columns.STATE_IN]
            if len(obs_dict["cargo_obs"].shape) == 5:  # [B, T, H, W, C]
                embeddings, _ = self._process_sequence(obs_dict, state_in)
            else:  # [B, H, W, C]
                embeddings, _ = self._process_single_timestep(obs_dict, state_in)
                embeddings = embeddings.squeeze(1)  # [B, hidden_size + 2*xy_max]

        values = self._values(embeddings).squeeze(-1)  # [B, T] or [B]
        return values
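
For reference, this is roughly how I plug the module into PPO (a minimal sketch; env registration is omitted, my_ShuttleGridEnv is my own env, and the RLModuleSpec parameter names, e.g. model_config vs. model_config_dict, may differ slightly between Ray versions):

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.rl_module import RLModuleSpec

config = (
    PPOConfig()
    # Enable the new API stack in case it is not already the default
    # for this Ray version.
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .environment("my_ShuttleGridEnv")
    .rl_module(
        rl_module_spec=RLModuleSpec(
            module_class=ParametricActionsRLModule,
            model_config={
                "xy_max": 20,
                "cnn_output_dim": 128,
                "lstm_hidden_size": 256,
                # Assumption: sequence length the learner connector uses
                # when chopping episodes into [B, T, ...] chunks.
                "max_seq_len": 100,
            },
        )
    )
)
algo = config.build()
print(algo.train())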

Failure # 1 (occurred at 2025-04-17_15-12-11)
ray::PPO.train() (pid=13172, ip=127.0.0.1, actor_id=efe3bc76e838d6be71a0de8601000000, repr=PPO(env=my_ShuttleGridEnv; env-runners=1; learners=1; multi-agent=False))
  File "python\ray\_raylet.pyx", line 1883, in ray._raylet.execute_task
  File "python\ray\_raylet.pyx", line 1824, in ray._raylet.execute_task.function_executor
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\_private\function_manager.py", line 696, in actor_method_executor
    return method(__ray_actor, *args, **kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py", line 463, in _resume_span
    return method(self, *_args, **_kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\tune\trainable\trainable.py", line 331, in train
    raise skipped from exception_cause(skipped)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\tune\trainable\trainable.py", line 328, in train
    result = self.step()
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py", line 463, in _resume_span
    return method(self, *_args, **_kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\algorithms\algorithm.py", line 999, in step
    train_results, train_iter_ctx = self._run_one_training_iteration()
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py", line 463, in _resume_span
    return method(self, *_args, **_kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\algorithms\algorithm.py", line 3350, in _run_one_training_iteration
    training_step_return_value = self.training_step()
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py", line 463, in _resume_span
    return method(self, *_args, **_kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\algorithms\ppo\ppo.py", line 428, in training_step
    learner_results = self.learner_group.update_from_episodes(
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\learner_group.py", line 327, in update_from_episodes
    return self._update(
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\learner_group.py", line 601, in _update
    results = self._get_results(
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\learner_group.py", line 615, in _get_results
    raise result_or_error
ray.exceptions.RayTaskError(ValueError): ray::_WrappedExecutable.apply() (pid=24044, ip=127.0.0.1, actor_id=a8ea6098076b82b63b49a13501000000, repr=<ray.train._internal.worker_group._WrappedExecutable object at 0x000001CADB7195B0>)
  File "python\ray\_raylet.pyx", line 1883, in ray._raylet.execute_task
  File "python\ray\_raylet.pyx", line 1824, in ray._raylet.execute_task.function_executor
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\_private\function_manager.py", line 696, in actor_method_executor
    return method(__ray_actor, *args, **kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py", line 463, in _resume_span
    return method(self, *_args, **_kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\learner.py", line 1641, in apply
    return func(self, *_args, **_kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\learner_group.py", line 385, in _learner_update
    result = _learner.update_from_episodes(
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py", line 463, in _resume_span
    return method(self, *_args, **_kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\learner.py", line 1086, in update_from_episodes
    self._update_from_batch_or_episodes(
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py", line 463, in _resume_span
    return method(self, *_args, **_kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\learner.py", line 1423, in _update_from_batch_or_episodes
    fwd_out, loss_per_module, tensor_metrics = self._update(
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py", line 463, in _resume_span
    return method(self, *_args, **_kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\torch\torch_learner.py", line 497, in _update
    return self._possibly_compiled_update(batch)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py", line 463, in _resume_span
    return method(self, *_args, **_kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\torch\torch_learner.py", line 152, in _uncompiled_update
    loss_per_module = self.compute_losses(fwd_out=fwd_out, batch=batch)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py", line 463, in _resume_span
    return method(self, *_args, **_kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\learner.py", line 924, in compute_losses
    loss = self.compute_loss_for_module(
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py", line 463, in _resume_span
    return method(self, *_args, **_kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\algorithms\ppo\torch\ppo_torch_learner.py", line 75, in compute_loss_for_module
    curr_action_dist.logp(batch[Columns.ACTIONS]) - batch[Columns.ACTION_LOGP]
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\models\torch\torch_distributions.py", line 38, in logp
    return self._dist.log_prob(value, **kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\torch\distributions\categorical.py", line 137, in log_prob
    self._validate_sample(value)
  File "D:\Anaconda\envs\singleray\lib\site-packages\torch\distributions\distribution.py", line 297, in _validate_sample
    raise ValueError(
ValueError: Value is not broadcastable with batch_shape+event_shape: torch.Size([5, 50]) vs torch.Size([5, 100]).

Excuse me, could someone explain why the actions tensor has T=50 while the other tensors have T=100? The Categorical distribution built from logits of shape [5, 100, 5] has batch_shape [5, 100], but the actions passed to logp() have shape [5, 50], which produces:

ValueError: Value is not broadcastable with batch_shape+event_shape: torch.Size([5, 50]) vs torch.Size([5, 100]). (This takes actor 0 out of service.)

(_WrappedExecutable pid=23340) cargo_obs shape: torch.Size([5, 100, 5, 5, 1])
(_WrappedExecutable pid=23340) actions shape: torch.Size([5, 50])
(_WrappedExecutable pid=23340) logits shape: torch.Size([5, 100, 5])
And the code is as follows:

def _forward_train(self, batch, **kwargs):
    obs_dict = batch[Columns.OBS]
    state_in = batch[Columns.STATE_IN]
    combined, state_out = self._process_sequence(obs_dict, state_in)

    logits = self._logits(combined)  # [B, T, num_actions]
    action_mask = obs_dict["action_mask"]  # [B, T, num_actions]
    logits = logits.masked_fill(action_mask == 0, -1e9)
    print(f"logits shape: {logits.shape}")
    print(f"cargo_obs shape: {obs_dict['cargo_obs'].shape}")
    print(f"actions shape: {batch[Columns.ACTIONS].shape}")

    return {
        Columns.ACTION_DIST_INPUTS: logits,  # [B, T, num_actions]
        Columns.EMBEDDINGS: combined,        # [B, T, hidden_size + 2*xy_max]
        Columns.STATE_OUT: state_out,
    }
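
One thing I still plan to check (a guess on my side, not a confirmed fix): in the new API stack, the learner connector chops episodes into sequences of model_config["max_seq_len"] for stateful modules, so comparing batch[Columns.SEQ_LENS] with the time dimensions of the observations and actions should show which tensor is being split or padded differently. A minimal diagnostic sketch for _forward_train (assuming SEQ_LENS may or may not be present in the train batch, hence the .get()):

    # Hypothetical diagnostic prints, added at the top of _forward_train.
    seq_lens = batch.get(Columns.SEQ_LENS)
    print(f"seq_lens: {None if seq_lens is None else seq_lens.shape}")
    print(f"configured max_seq_len: {self.model_config.get('max_seq_len')}")
    print(
        f"obs T: {batch[Columns.OBS]['cargo_obs'].shape[1]}, "
        f"actions T: {batch[Columns.ACTIONS].shape[1]}"
    )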