Broadcast issue when using custom env with custom policy

Hi, I'm new to Ray and RLlib.

I'm trying to prototype a workflow where I use PPO to train a transformer-based embeddings model. I keep running into the same issue: the build method complains that some output is not broadcastable. Nothing I've tried to resolve it has worked.

Note that the mismatch always seems to be half of the actual model dimension. For instance, if I use a model with an embedding dim of 768, it complains about a mismatch between torch.Size([32, 768]) and torch.Size([32, 384]).

Here's the stack trace:

    algo = config.build()
  File "/usr/local/lib/python3.8/dist-packages/ray/rllib/algorithms/algorithm_config.py", line 1071, in build
    return algo_class(
  File "/usr/local/lib/python3.8/dist-packages/ray/rllib/algorithms/algorithm.py", line 475, in __init__
    super().__init__(
  File "/usr/local/lib/python3.8/dist-packages/ray/tune/trainable/trainable.py", line 170, in __init__
    self.setup(copy.deepcopy(self.config))
  File "/usr/local/lib/python3.8/dist-packages/ray/rllib/algorithms/algorithm.py", line 601, in setup
    self.workers = WorkerSet(
  File "/usr/local/lib/python3.8/dist-packages/ray/rllib/evaluation/worker_set.py", line 172, in __init__
    self._setup(
  File "/usr/local/lib/python3.8/dist-packages/ray/rllib/evaluation/worker_set.py", line 262, in _setup
    self._local_worker = self._make_worker(
  File "/usr/local/lib/python3.8/dist-packages/ray/rllib/evaluation/worker_set.py", line 967, in _make_worker
    worker = cls(
  File "/usr/local/lib/python3.8/dist-packages/ray/rllib/evaluation/rollout_worker.py", line 738, in __init__
    self._update_policy_map(policy_dict=self.policy_dict)
  File "/usr/local/lib/python3.8/dist-packages/ray/rllib/evaluation/rollout_worker.py", line 1985, in _update_policy_map
    self._build_policy_map(
  File "/usr/local/lib/python3.8/dist-packages/ray/rllib/evaluation/rollout_worker.py", line 2097, in _build_policy_map
    new_policy = create_policy_for_framework(
  File "/usr/local/lib/python3.8/dist-packages/ray/rllib/utils/policy.py", line 142, in create_policy_for_framework
    return policy_class(observation_space, action_space, merged_config)
  File "/usr/local/lib/python3.8/dist-packages/ray/rllib/algorithms/ppo/ppo_torch_policy.py", line 64, in __init__
    self._initialize_loss_from_dummy_batch()
  File "/usr/local/lib/python3.8/dist-packages/ray/rllib/policy/policy.py", line 1489, in _initialize_loss_from_dummy_batch
    self.loss(self.model, self.dist_class, train_batch)
  File "/usr/local/lib/python3.8/dist-packages/ray/rllib/algorithms/ppo/ppo_torch_policy.py", line 112, in loss
    curr_action_dist.logp(train_batch[SampleBatch.ACTIONS])
  File "/usr/local/lib/python3.8/dist-packages/ray/rllib/models/torch/torch_action_dist.py", line 268, in logp
    return super().logp(actions).sum(-1)
  File "/usr/local/lib/python3.8/dist-packages/ray/rllib/models/torch/torch_action_dist.py", line 37, in logp
    return self.dist.log_prob(actions)
  File "/usr/local/lib/python3.8/dist-packages/torch/distributions/normal.py", line 79, in log_prob
    self._validate_sample(value)
  File "/usr/local/lib/python3.8/dist-packages/torch/distributions/distribution.py", line 288, in _validate_sample
    raise ValueError('Value is not broadcastable with batch_shape+event_shape: {} vs {}.'.
ValueError: Value is not broadcastable with batch_shape+event_shape: torch.Size([32, 384]) vs torch.Size([32, 192]).

and here's the dummy prototype code that produces it:

import gymnasium as gym
from gymnasium import spaces
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
from scipy.spatial.distance import cosine

import ray
from ray.rllib.algorithms.ppo import PPO
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.tune.registry import get_trainable_cls
from ray.rllib.utils.annotations import override
import os
from ray.rllib.utils.framework import try_import_torch
import pdb
torch, nn = try_import_torch()


embedding_model_name  = "sentence-transformers/all-MiniLM-L6-v2"
embedding_dim = 384


class CustomEnv(gym.Env):
    def __init__(self, env_config):
        self.tokenizer = AutoTokenizer.from_pretrained(embedding_model_name)
        self.model = AutoModel.from_pretrained(embedding_model_name)
        self.vocab_size = self.tokenizer.vocab_size
        self.max_length = 512
        self.observation_space = spaces.Box(low=0, high=self.vocab_size-1, shape=(self.max_length,), dtype=np.int32)
        self.action_space = spaces.Box(low=-1, high=1, shape=(embedding_dim,), dtype=np.float32) # actions are embeddings of size embedding_dim (384 here)
        self.index = 0

        self.setup_batch_and_database()
        self.reset()


    def setup_batch_and_database(self):
        # Dummy batch and embeddings database
        batch = ["Hello world", "How are you doing today?", "I am doing well", "I am doing terrible", "I am doing great", "I am a fish", "I am a man", "How now brown cow"]
        dummy_embeddings_database = torch.randn((len(batch), embedding_dim))
        self.batch = batch
        self.embeddings_database = dummy_embeddings_database
            

    def reset(self,*, seed=None, options=None):
        if self.index >= len(self.batch):
            self.index = 0  # reset index if all strings have been used
        current_string = self.batch[self.index]
        tokenized_string = self.tokenizer(current_string, padding="max_length", truncation=True, max_length=self.max_length, return_tensors="pt")["input_ids"]
        self.index += 1
        return (tokenized_string.numpy().reshape(-1), {})

    def step(self, action):
        # action is the embedding vector produced by your policy
        top_doc_embedding = self.search_database(action)  # replace with logic to search your database using the action embedding
        reward = self.calculate_reward(action, top_doc_embedding)  # replace with logic to compute reward
        done = truncated = True
        dummy_observation = np.ones(self.max_length, dtype=np.int32).reshape(-1)
        return dummy_observation, reward, done, truncated, {}

    def search_database(self, action):
        # Find the closest embedding in the database to the action embedding
        # let's just do a dot product  with the database for now
        action = torch.tensor(action)
        logits = torch.matmul(action, self.embeddings_database.T)
        top_doc_idx = torch.argmax(logits)
        top_doc_embedding = self.embeddings_database[top_doc_idx]
        return top_doc_embedding


    def calculate_reward(self, action, top_doc_embedding):
        #dummy reward function
        reward = -1 * cosine(action, top_doc_embedding)
        return reward


class TransformerPolicy(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)  # you should also initialize the nn.Module
        self.transformer = AutoModel.from_pretrained(embedding_model_name)
        self.value_layer = nn.Linear(embedding_dim, 1)

    def forward(self, input_dict, state, seq_lens):
        obs = input_dict["obs"].long()
        print(f">>>>>> Shape of obs is {obs.shape}")
        transformer_outputs = self.transformer(obs)
        action_out = transformer_outputs.pooler_output
        self._value_out = self.value_layer(action_out)

        print(f">>>>>> Shape of action_out is {action_out.shape}")
        print(f">>>>>> Shape of value_out is {self._value_out.shape}")      
        # assert all elements in action_out are in the range [-1,1]
        assert torch.all(action_out <= 1) and torch.all(action_out >= -1)  
        return action_out, []
    
    def value_function(self):
        values = self._value_out.squeeze(1)
        return values
        

if __name__ == "__main__":
    ray.init()

    ModelCatalog.register_custom_model("transformer_policy", TransformerPolicy)

    config = (
        get_trainable_cls("PPO")
        .get_default_config()
        .environment(CustomEnv, env_config={"foo": 5})
        .framework("torch")
        .rollouts(num_rollout_workers=1)
        .training(
            model={
                "custom_model": "transformer_policy",
            }
        )
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
        .rollouts(num_rollout_workers=0)
    )
    algo = config.build()

    ray.shutdown()

Hi @badwolf,

You are using PPO on a continuous action space. If the model outputs a vector of size d, half of it is used as the mean and the other half as the log-std of the Normal (DiagGaussian) distribution over your actions. That's why the resulting distribution is only 192-dimensional while your action space is 384-dimensional, hence the shape-mismatch error. To fix this, your custom model should output 2 * action_dim logits.
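For example, here's a minimal (untested) sketch of the change, reusing the names from your snippet: add an action head that projects the pooled transformer output to `num_outputs`, which RLlib should already set to 2 * action_dim for a Box action space with the default DiagGaussian distribution.

class TransformerPolicy(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)
        self.transformer = AutoModel.from_pretrained(embedding_model_name)
        # num_outputs == 2 * embedding_dim here: first half -> mean, second half -> log-std
        self.action_head = nn.Linear(embedding_dim, num_outputs)
        self.value_layer = nn.Linear(embedding_dim, 1)

    def forward(self, input_dict, state, seq_lens):
        obs = input_dict["obs"].long()
        pooled = self.transformer(obs).pooler_output   # [B, embedding_dim]
        self._value_out = self.value_layer(pooled)     # [B, 1]
        return self.action_head(pooled), []            # [B, 2 * embedding_dim]

    def value_function(self):
        return self._value_out.squeeze(1)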


Thanks, that was it!

I'm encountering a similar issue, although my actions are discrete. How should I solve this? Below I have attached the error message and my code.
Error:

Failure # 1 (occurred at 2025-04-14_19-00-12)
ray::PPO.train() (pid=11520, ip=127.0.0.1, actor_id=897c1c17d372e441bb571eff01000000, repr=PPO(env=my_ShuttleGridEnv; env-runners=1; learners=1; multi-agent=False))
  File "python\ray\_raylet.pyx", line 1883, in ray._raylet.execute_task
  File "python\ray\_raylet.pyx", line 1824, in ray._raylet.execute_task.function_executor
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\_private\function_manager.py", line 696, in actor_method_executor
    return method(__ray_actor, *args, **kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py", line 463, in _resume_span
    return method(self, *_args, **_kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\tune\trainable\trainable.py", line 331, in train
    raise skipped from exception_cause(skipped)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\tune\trainable\trainable.py", line 328, in train
    result = self.step()
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py", line 463, in _resume_span
    return method(self, *_args, **_kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\algorithms\algorithm.py", line 999, in step
    train_results, train_iter_ctx = self._run_one_training_iteration()
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py", line 463, in _resume_span
    return method(self, *_args, **_kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\algorithms\algorithm.py", line 3350, in _run_one_training_iteration
    training_step_return_value = self.training_step()
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py", line 463, in _resume_span
    return method(self, *_args, **_kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\algorithms\ppo\ppo.py", line 428, in training_step
    learner_results = self.learner_group.update_from_episodes(
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\learner_group.py", line 327, in update_from_episodes
    return self._update(
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\learner_group.py", line 601, in _update
    results = self._get_results(
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\learner_group.py", line 615, in _get_results
    raise result_or_error
ray.exceptions.RayTaskError(ValueError): ray::_WrappedExecutable.apply() (pid=2220, ip=127.0.0.1, actor_id=82648f7a478df06d4818e24901000000, repr=<ray.train._internal.worker_group._WrappedExecutable object at 0x000001DBEBEE9610>)
  File "python\ray\_raylet.pyx", line 1883, in ray._raylet.execute_task
  File "python\ray\_raylet.pyx", line 1824, in ray._raylet.execute_task.function_executor
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\_private\function_manager.py", line 696, in actor_method_executor
    return method(__ray_actor, *args, **kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py", line 463, in _resume_span
    return method(self, *_args, **_kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\learner.py", line 1641, in apply
    return func(self, *_args, **_kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\learner_group.py", line 385, in _learner_update
    result = _learner.update_from_episodes(
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py", line 463, in _resume_span
    return method(self, *_args, **_kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\learner.py", line 1086, in update_from_episodes
    self._update_from_batch_or_episodes(
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py", line 463, in _resume_span
    return method(self, *_args, **_kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\learner.py", line 1423, in _update_from_batch_or_episodes
    fwd_out, loss_per_module, tensor_metrics = self._update(
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py", line 463, in _resume_span
    return method(self, *_args, **_kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\torch\torch_learner.py", line 497, in _update
    return self._possibly_compiled_update(batch)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py", line 463, in _resume_span
    return method(self, *_args, **_kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\torch\torch_learner.py", line 152, in _uncompiled_update
    loss_per_module = self.compute_losses(fwd_out=fwd_out, batch=batch)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py", line 463, in _resume_span
    return method(self, *_args, **_kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\learner.py", line 924, in compute_losses
    loss = self.compute_loss_for_module(
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py", line 463, in _resume_span
    return method(self, *_args, **_kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\algorithms\ppo\torch\ppo_torch_learner.py", line 75, in compute_loss_for_module
    curr_action_dist.logp(batch[Columns.ACTIONS]) - batch[Columns.ACTION_LOGP]
  File "D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\models\torch\torch_distributions.py", line 38, in logp
    return self._dist.log_prob(value, **kwargs)
  File "D:\Anaconda\envs\singleray\lib\site-packages\torch\distributions\categorical.py", line 137, in log_prob
    self._validate_sample(value)
  File "D:\Anaconda\envs\singleray\lib\site-packages\torch\distributions\distribution.py", line 297, in _validate_sample
    raise ValueError(
ValueError: Value is not broadcastable with batch_shape+event_shape: torch.Size([5, 50]) vs torch.Size([5, 100]).
Code:
from typing import Any, Dict, Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import TensorType

torch, nn = try_import_torch()

# Define CNNModule for feature extraction

class CNNModule(nn.Module):
    def __init__(self, input_channels=1, output_dim=128):
        super(CNNModule, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout = nn.Dropout(p=0.5)
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(128, output_dim)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool1(x)
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.pool2(x)
        x = self.global_pool(x)
        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        x = F.relu(self.fc(x))
        return x

# Modified ParametricActionsRLModule

class ParametricActionsRLModule(TorchRLModule, ValueFunctionAPI):
    def __init__(
        self,
        *,
        observation_space,
        action_space,
        model_config,
        inference_only=False,
        learner_only=False,
        catalog_class=None
    ):
        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
            model_config=model_config,
            inference_only=inference_only,
            learner_only=learner_only,
            catalog_class=catalog_class,
        )
        self.observation_space = observation_space
        self.action_space = action_space
        self.model_config = model_config
        self.catalog_class = catalog_class
        self.setup()

    @override(TorchRLModule)
    def setup(self):
        """Initialize network structure"""
        # Configuration parameters
        self.xy_max = self.model_config.get("xy_max", 20)
        input_channels = self.observation_space["cargo_obs"].shape[-1]  # Assuming [H, W, C]
        cnn_output_dim = self.model_config.get("cnn_output_dim", 128)
        lstm_hidden_size = self.model_config.get("lstm_hidden_size", 256)

        # CNN module
        self.cnn = CNNModule(input_channels=input_channels, output_dim=cnn_output_dim)

        # LSTM module
        self._lstm = nn.LSTM(cnn_output_dim, lstm_hidden_size, batch_first=True)

        # Combined dimension: LSTM output + one-hot encoded xy
        combined_dim = lstm_hidden_size + 2 * self.xy_max

        # Action logits and value layers
        self._logits = nn.Linear(combined_dim, self.action_space.n)
        self._values = nn.Linear(combined_dim, 1)

    @override(TorchRLModule)
    def get_initial_state(self) -> Dict[str, np.ndarray]:
        """Return initial LSTM state"""
        return {
            "h": np.zeros(shape=(self._lstm.hidden_size,), dtype=np.float32),
            "c": np.zeros(shape=(self._lstm.hidden_size,), dtype=np.float32),
        }

    def _process_sequence(self, obs_dict, state_in):
        """Process sequence data, return embeddings and state_out"""
        cargo_obs = obs_dict["cargo_obs"]  # [B, T, H, W, C]
        B, T, H, W, C = cargo_obs.shape
        cargo_obs = cargo_obs.view(B * T, H, W, C).permute(0, 3, 1, 2)  # [B*T, C, H, W]
        features = self.cnn(cargo_obs).view(B, T, -1)  # [B, T, cnn_output_dim]

        # Adjust hidden state to match current batch size
        h_in = state_in["h"]  # [B_state, hidden_size]
        c_in = state_in["c"]  # [B_state, hidden_size]
        if h_in.size(0) != B:
            h_in = h_in[:B, :] if h_in.size(0) > B else torch.cat([h_in, torch.zeros(B - h_in.size(0), h_in.size(1), device=h_in.device)], dim=0)
            c_in = c_in[:B, :] if c_in.size(0) > B else torch.cat([c_in, torch.zeros(B - c_in.size(0), c_in.size(1), device=c_in.device)], dim=0)

        h_in = h_in.unsqueeze(0)  # [1, B, hidden_size]
        c_in = c_in.unsqueeze(0)  # [1, B, hidden_size]
        lstm_out, (h_out, c_out) = self._lstm(features, (h_in, c_in))  # [B, T, hidden_size]

        xy = obs_dict["xy"]  # [B, T, 2]
        onehot_x = F.one_hot(xy[:, :, 0].long(), num_classes=self.xy_max)  # [B, T, xy_max]
        onehot_y = F.one_hot(xy[:, :, 1].long(), num_classes=self.xy_max)  # [B, T, xy_max]
        onehot_xy = torch.cat([onehot_x, onehot_y], dim=2)  # [B, T, 2*xy_max]

        combined = torch.cat([lstm_out, onehot_xy], dim=2)  # [B, T, hidden_size + 2*xy_max]
        state_out = {"h": h_out.squeeze(0), "c": c_out.squeeze(0)}
        return combined, state_out

    def _process_single_timestep(self, obs_dict, state_in):
        """Process single timestep data, return embeddings and state_out"""
        cargo_obs = obs_dict["cargo_obs"]  # [B, H, W, C]
        if len(cargo_obs.shape) == 3:
            cargo_obs = cargo_obs.unsqueeze(0)
        B = cargo_obs.shape[0]
        cargo_obs = cargo_obs.permute(0, 3, 1, 2)  # [B, C, H, W]
        features = self.cnn(cargo_obs).unsqueeze(1)  # [B, 1, cnn_output_dim]

        # Adjust hidden state to match current batch size
        h_in = state_in["h"]  # [B_state, hidden_size]
        c_in = state_in["c"]  # [B_state, hidden_size]
        if h_in.size(0) != B:
            h_in = h_in[:B, :] if h_in.size(0) > B else torch.cat([h_in, torch.zeros(B - h_in.size(0), h_in.size(1), device=h_in.device)], dim=0)
            c_in = c_in[:B, :] if c_in.size(0) > B else torch.cat([c_in, torch.zeros(B - c_in.size(0), c_in.size(1), device=c_in.device)], dim=0)

        h_in = h_in.unsqueeze(0)  # [1, B, hidden_size]
        c_in = c_in.unsqueeze(0)  # [1, B, hidden_size]
        lstm_out, (h_out, c_out) = self._lstm(features, (h_in, c_in))  # [B, 1, hidden_size]

        xy = obs_dict["xy"]  # [B, 2]
        if len(xy.shape) == 1:
            xy = xy.unsqueeze(0)
        onehot_x = F.one_hot(xy[:, 0].long(), num_classes=self.xy_max)  # [B, xy_max]
        onehot_y = F.one_hot(xy[:, 1].long(), num_classes=self.xy_max)  # [B, xy_max]
        onehot_xy = torch.cat([onehot_x, onehot_y], dim=1).unsqueeze(1)  # [B, 1, 2*xy_max]

        combined = torch.cat([lstm_out, onehot_xy], dim=2)  # [B, 1, hidden_size + 2*xy_max]
        state_out = {"h": h_out.squeeze(0), "c": c_out.squeeze(0)}
        return combined, state_out

    @override(TorchRLModule)
    def _forward(self, batch, **kwargs):
        """General forward pass, process sequence data"""
        obs_dict = batch[Columns.OBS]
        state_in = batch[Columns.STATE_IN]
        combined, state_out = self._process_sequence(obs_dict, state_in)

        logits = self._logits(combined)  # [B, T, num_actions]
        action_mask = obs_dict["action_mask"]  # [B, T, num_actions]
        logits = logits.masked_fill(action_mask == 0, -1e9)

        return {
            Columns.ACTION_DIST_INPUTS: logits,
            Columns.STATE_OUT: state_out,
        }



    @override(TorchRLModule)
    def _forward_train(self, batch, **kwargs):
        obs_dict = batch[Columns.OBS]
        state_in = batch[Columns.STATE_IN]
        combined, state_out = self._process_sequence(obs_dict, state_in)

        logits = self._logits(combined)  # [B, T, num_actions]
        action_mask = obs_dict["action_mask"]  # [B, T, num_actions] (ensure it is 3-D)
        logits = logits.masked_fill(action_mask == 0, -1e9)

        # B, T, A = logits.shape  # A = self.num_actions
        # logits_flat = logits.reshape(B * T, A)  # [B*T, num_actions]

        # # Flatten actions (key fix)
        # actions = batch[Columns.ACTIONS]  # input shape should be [B, T]
        # actions_flat = actions.reshape(B * T)  # flatten to [B*T]

        # embeddings_flat = combined.reshape(B * T, combined.shape[2])  # [B*T, embed_dim]

        return {
            Columns.ACTION_DIST_INPUTS: logits,
            Columns.EMBEDDINGS: combined,
            Columns.STATE_OUT: state_out,
        }



    @override(TorchRLModule)
    def _forward_inference(self, batch, **kwargs):
        """Forward pass for inference, process single timestep"""
        obs_dict = batch[Columns.OBS]
        state_in = batch[Columns.STATE_IN]

        # Handle possible numpy inputs
        for key in ["cargo_obs", "xy", "action_mask"]:
            if isinstance(obs_dict[key], np.ndarray):
                obs_dict[key] = torch.tensor(obs_dict[key])

        combined, state_out = self._process_single_timestep(obs_dict, state_in)
        combined = combined.squeeze(1)  # [B, hidden_size + 2*xy_max]

        logits = self._logits(combined)  # [B, num_actions]
        action_mask = obs_dict["action_mask"]  # [B, num_actions]
        if len(action_mask.shape) == 1:
            action_mask = action_mask.unsqueeze(0)
        logits = logits.masked_fill(action_mask == 0, -1e9)

        actions = torch.argmax(logits, dim=-1)  # [B]
        if actions.shape[0] == 1:
            actions = actions.squeeze(0)

        return {
            Columns.ACTIONS: actions,
            Columns.STATE_OUT: state_out,
        }

    @override(ValueFunctionAPI)
    def compute_values(self, batch, embeddings: Optional[TensorType] = None) -> TensorType:
        """Compute state values, support sequences and single timesteps"""
        if embeddings is None:
            obs_dict = batch[Columns.OBS]
            state_in = batch[Columns.STATE_IN]
            if len(obs_dict["cargo_obs"].shape) == 5:  # [B, T, H, W, C]
                embeddings, _ = self._process_sequence(obs_dict, state_in)
            else:  # [B, H, W, C]
                embeddings, _ = self._process_single_timestep(obs_dict, state_in)
                embeddings = embeddings.squeeze(1)  # [B, hidden_size + 2*xy_max]

        values = self._values(embeddings).squeeze(-1)  # [B, T] or [B]
        return values