Thanks for your reply. I will attach part of my code along with the error messages for reference.
code:
from typing import Any, Dict, Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import TensorType
torch, nn = try_import_torch()
# Define CNNModule for feature extraction
class CNNModule(nn.Module):
    """Convolutional feature extractor mapping an image batch to flat vectors.

    The network stacks three conv -> batch-norm -> ReLU stages (32, 64 and
    128 channels), applies 2x2 max-pooling after the second and the third
    stage, collapses the spatial axes with global average pooling and ends
    with dropout followed by a ReLU-activated linear projection.

    Args:
        input_channels: Number of channels of the input images.
        output_dim: Length of the feature vector produced per image.
    """

    def __init__(self, input_channels=1, output_dim=128):
        # The constructor has to be the dunder ``__init__``; a method that is
        # merely called ``init`` is never run on instantiation, which would
        # leave the module without any layers.
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout = nn.Dropout(p=0.5)
        # Global pooling makes the head independent of the input resolution.
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(128, output_dim)

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Tensor of shape [N, input_channels, H, W].

        Returns:
            Tensor of shape [N, output_dim]; values are non-negative because
            the last operation is a ReLU.
        """
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool1(x)
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.pool2(x)
        x = self.global_pool(x)  # [N, 128, 1, 1]
        x = x.view(x.size(0), -1)  # [N, 128]
        x = self.dropout(x)
        return F.relu(self.fc(x))
# Modified ParametricActionsRLModule
class ParametricActionsRLModule(TorchRLModule, ValueFunctionAPI):
    """Recurrent (CNN + LSTM) policy/value module with action masking.

    Observations are dicts with three entries:

    * ``"cargo_obs"``: channels-last image(s), encoded by ``CNNModule``.
    * ``"xy"``: two integer coordinates, each one-hot encoded with
      ``xy_max`` classes.
    * ``"action_mask"``: entries equal to 0 mark actions whose logits are
      pushed to ``-1e9``.

    The LSTM output is concatenated with the one-hot coordinates; that
    embedding feeds both the action-logit head and the value head.
    """

    def __init__(
        self,
        *,
        observation_space,
        action_space,
        model_config,
        inference_only=False,
        learner_only=False,
        catalog_class=None
    ):
        # Per the RLlib RLModule docs, the base constructor stores the
        # spaces and ``model_config`` on ``self`` and then invokes
        # ``setup()``. Re-assigning them and calling ``setup()`` a second
        # time would only rebuild (and re-initialise) every layer.
        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
            model_config=model_config,
            inference_only=inference_only,
            learner_only=learner_only,
            catalog_class=catalog_class,
        )
        # Kept so code that reads ``module.catalog_class`` keeps working.
        self.catalog_class = catalog_class

    @override(TorchRLModule)
    def setup(self):
        """Build the CNN encoder, the LSTM and the logit/value heads."""
        # Configuration parameters
        self.xy_max = self.model_config.get("xy_max", 20)
        # Channels-last layout is assumed for "cargo_obs": [H, W, C].
        input_channels = self.observation_space["cargo_obs"].shape[-1]
        cnn_output_dim = self.model_config.get("cnn_output_dim", 128)
        lstm_hidden_size = self.model_config.get("lstm_hidden_size", 256)

        self.cnn = CNNModule(input_channels=input_channels, output_dim=cnn_output_dim)
        self._lstm = nn.LSTM(cnn_output_dim, lstm_hidden_size, batch_first=True)

        # LSTM output + one-hot x + one-hot y.
        combined_dim = lstm_hidden_size + 2 * self.xy_max
        self._logits = nn.Linear(combined_dim, self.action_space.n)
        self._values = nn.Linear(combined_dim, 1)

    @override(TorchRLModule)
    def get_initial_state(self) -> Dict[str, np.ndarray]:
        """Return the zero LSTM state (no batch axis) under keys "h" and "c"."""
        hidden = self._lstm.hidden_size
        return {
            "h": np.zeros(shape=(hidden,), dtype=np.float32),
            "c": np.zeros(shape=(hidden,), dtype=np.float32),
        }

    @staticmethod
    def _fit_state_to_batch(state, batch_size):
        """Truncate or zero-pad ``state`` to ``batch_size`` rows for the LSTM.

        Args:
            state: Tensor of shape [B_state, hidden_size].
            batch_size: Number of rows the LSTM input has.

        Returns:
            Tensor of shape [1, batch_size, hidden_size].
        """
        # NOTE(review): silently resizing the state hides any mismatch
        # between the state batch and the observation batch - confirm this
        # is really intended.
        rows = state.size(0)
        if rows > batch_size:
            state = state[:batch_size, :]
        elif rows < batch_size:
            # ``new_zeros`` inherits dtype and device from ``state``.
            padding = state.new_zeros(batch_size - rows, state.size(1))
            state = torch.cat([state, padding], dim=0)
        return state.unsqueeze(0)

    def _run_lstm(self, features, state_in):
        """Run the LSTM over ``features`` [B, T, D] starting from ``state_in``.

        Returns:
            Tuple ``(lstm_out, state_out)`` with ``lstm_out`` of shape
            [B, T, hidden_size] and ``state_out`` a dict with "h"/"c" of
            shape [B, hidden_size].
        """
        batch_size = features.size(0)
        h_in = self._fit_state_to_batch(state_in["h"], batch_size)
        c_in = self._fit_state_to_batch(state_in["c"], batch_size)
        lstm_out, (h_out, c_out) = self._lstm(features, (h_in, c_in))
        return lstm_out, {"h": h_out.squeeze(0), "c": c_out.squeeze(0)}

    def _one_hot_xy(self, xy):
        """One-hot encode coordinates ``xy`` [..., 2] into [..., 2 * xy_max]."""
        onehot_x = F.one_hot(xy[..., 0].long(), num_classes=self.xy_max)
        onehot_y = F.one_hot(xy[..., 1].long(), num_classes=self.xy_max)
        return torch.cat([onehot_x, onehot_y], dim=-1)

    def _masked_logits(self, embeddings, action_mask):
        """Compute action logits and push masked-out (mask == 0) ones to -1e9."""
        logits = self._logits(embeddings)
        return logits.masked_fill(action_mask == 0, -1e9)

    def _process_sequence(self, obs_dict, state_in):
        """Embed a batch of sequences.

        Args:
            obs_dict: Dict with "cargo_obs" [B, T, H, W, C] and "xy" [B, T, 2].
            state_in: Dict with "h" and "c", each [B_state, hidden_size].

        Returns:
            Tuple ``(combined, state_out)``; ``combined`` has shape
            [B, T, hidden_size + 2 * xy_max].
        """
        cargo_obs = obs_dict["cargo_obs"]
        B, T, H, W, C = cargo_obs.shape
        # ``reshape`` (unlike ``view``) also accepts non-contiguous input.
        frames = cargo_obs.reshape(B * T, H, W, C).permute(0, 3, 1, 2)  # [B*T, C, H, W]
        features = self.cnn(frames).view(B, T, -1)  # [B, T, cnn_output_dim]

        lstm_out, state_out = self._run_lstm(features, state_in)
        onehot_xy = self._one_hot_xy(obs_dict["xy"]).to(lstm_out.dtype)  # [B, T, 2*xy_max]
        combined = torch.cat([lstm_out, onehot_xy], dim=2)
        return combined, state_out

    def _process_single_timestep(self, obs_dict, state_in):
        """Embed a batch of single observations (no time axis).

        Args:
            obs_dict: Dict with "cargo_obs" [B, H, W, C] (or [H, W, C]) and
                "xy" [B, 2] (or [2]); unbatched inputs get a batch axis.
            state_in: Dict with "h" and "c", each [B_state, hidden_size].

        Returns:
            Tuple ``(combined, state_out)``; ``combined`` has shape
            [B, 1, hidden_size + 2 * xy_max].
        """
        cargo_obs = obs_dict["cargo_obs"]
        if cargo_obs.dim() == 3:
            cargo_obs = cargo_obs.unsqueeze(0)
        frames = cargo_obs.permute(0, 3, 1, 2)  # [B, C, H, W]
        features = self.cnn(frames).unsqueeze(1)  # [B, 1, cnn_output_dim]

        lstm_out, state_out = self._run_lstm(features, state_in)

        xy = obs_dict["xy"]
        if xy.dim() == 1:
            xy = xy.unsqueeze(0)
        onehot_xy = self._one_hot_xy(xy).unsqueeze(1).to(lstm_out.dtype)  # [B, 1, 2*xy_max]
        combined = torch.cat([lstm_out, onehot_xy], dim=2)
        return combined, state_out

    def _sequence_outputs(self, batch):
        """Shared sequence pass: return ``(logits, embeddings, state_out)``."""
        obs_dict = batch[Columns.OBS]
        combined, state_out = self._process_sequence(obs_dict, batch[Columns.STATE_IN])
        # Mask is expected as [B, T, num_actions] (must be 3-D).
        logits = self._masked_logits(combined, obs_dict["action_mask"])
        return logits, combined, state_out

    @override(TorchRLModule)
    def _forward(self, batch, **kwargs):
        """General forward pass over sequence data.

        Returns:
            Dict with masked logits [B, T, num_actions] and the next state.
        """
        logits, _, state_out = self._sequence_outputs(batch)
        return {
            Columns.ACTION_DIST_INPUTS: logits,
            Columns.STATE_OUT: state_out,
        }

    @override(TorchRLModule)
    def _forward_train(self, batch, **kwargs):
        """Training forward pass; additionally returns the embeddings.

        Returns:
            Dict with masked logits [B, T, num_actions], the embeddings
            [B, T, hidden_size + 2 * xy_max] and the next state.
        """
        # NOTE(review): the logits inherit B and T from "cargo_obs". The PPO
        # loss evaluates them against batch[Columns.ACTIONS], so both must
        # share the same [B, T]; check the incoming shapes if log_prob
        # reports a broadcast error.
        logits, combined, state_out = self._sequence_outputs(batch)
        return {
            Columns.ACTION_DIST_INPUTS: logits,
            Columns.EMBEDDINGS: combined,
            Columns.STATE_OUT: state_out,
        }

    @override(TorchRLModule)
    def _forward_inference(self, batch, **kwargs):
        """Greedy single-timestep inference.

        Accepts numpy or tensor observations without a time axis and returns
        the arg-max action over the masked logits.

        Returns:
            Dict with ``Columns.ACTIONS`` ([B], or a scalar when B == 1) and
            ``Columns.STATE_OUT``.
        """
        # Work on a shallow copy so the caller's observation dict is not
        # modified when numpy arrays are converted to tensors.
        obs_dict = dict(batch[Columns.OBS])
        for key in ("cargo_obs", "xy", "action_mask"):
            if isinstance(obs_dict[key], np.ndarray):
                obs_dict[key] = torch.tensor(obs_dict[key])

        combined, state_out = self._process_single_timestep(
            obs_dict, batch[Columns.STATE_IN]
        )
        combined = combined.squeeze(1)  # [B, hidden_size + 2*xy_max]

        action_mask = obs_dict["action_mask"]  # [B, num_actions]
        if action_mask.dim() == 1:
            action_mask = action_mask.unsqueeze(0)
        logits = self._masked_logits(combined, action_mask)  # [B, num_actions]

        actions = torch.argmax(logits, dim=-1)  # [B]
        # A batch of one is returned as a scalar action; STATE_OUT keeps its
        # batch axis.
        if actions.shape[0] == 1:
            actions = actions.squeeze(0)
        return {
            Columns.ACTIONS: actions,
            Columns.STATE_OUT: state_out,
        }

    @override(ValueFunctionAPI)
    def compute_values(self, batch, embeddings: Optional[TensorType] = None) -> TensorType:
        """Compute state values for sequences or single timesteps.

        Args:
            batch: Batch with ``Columns.OBS`` and ``Columns.STATE_IN``; only
                read when ``embeddings`` is None.
            embeddings: Optional precomputed embeddings (e.g. the
                ``Columns.EMBEDDINGS`` output of ``_forward_train``).

        Returns:
            Values of shape [B, T] for sequences or [B] for single timesteps.
        """
        if embeddings is None:
            obs_dict = batch[Columns.OBS]
            state_in = batch[Columns.STATE_IN]
            if obs_dict["cargo_obs"].dim() == 5:  # [B, T, H, W, C]
                embeddings, _ = self._process_sequence(obs_dict, state_in)
            else:  # [B, H, W, C]
                embeddings, _ = self._process_single_timestep(obs_dict, state_in)
                embeddings = embeddings.squeeze(1)  # [B, hidden_size + 2*xy_max]
        return self._values(embeddings).squeeze(-1)
Failure # 1 (occurred at 2025-04-17_15-12-11)
ray::PPO.train() (pid=13172, ip=127.0.0.1, actor_id=efe3bc76e838d6be71a0de8601000000, repr=PPO(env=my_ShuttleGridEnv; env-runners=1; learners=1; multi-agent=False))
File “python\ray\_raylet.pyx”, line 1883, in ray._raylet.execute_task
File “python\ray\_raylet.pyx”, line 1824, in ray._raylet.execute_task.function_executor
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\_private\function_manager.py”, line 696, in actor_method_executor
return method(__ray_actor, *args, **kwargs)
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py”, line 463, in _resume_span
return method(self, *_args, **_kwargs)
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\tune\trainable\trainable.py”, line 331, in train
raise skipped from exception_cause(skipped)
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\tune\trainable\trainable.py”, line 328, in train
result = self.step()
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py”, line 463, in _resume_span
return method(self, *_args, **_kwargs)
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\algorithms\algorithm.py”, line 999, in step
train_results, train_iter_ctx = self._run_one_training_iteration()
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py”, line 463, in _resume_span
return method(self, *_args, **_kwargs)
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\algorithms\algorithm.py”, line 3350, in _run_one_training_iteration
training_step_return_value = self.training_step()
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py”, line 463, in _resume_span
return method(self, *_args, **_kwargs)
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\algorithms\ppo\ppo.py”, line 428, in training_step
learner_results = self.learner_group.update_from_episodes(
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\learner_group.py”, line 327, in update_from_episodes
return self._update(
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\learner_group.py”, line 601, in _update
results = self._get_results(
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\learner_group.py”, line 615, in _get_results
raise result_or_error
ray.exceptions.RayTaskError(ValueError): ray::_WrappedExecutable.apply() (pid=24044, ip=127.0.0.1, actor_id=a8ea6098076b82b63b49a13501000000, repr=<ray.train._internal.worker_group._WrappedExecutable object at 0x000001CADB7195B0>)
File “python\ray\_raylet.pyx”, line 1883, in ray._raylet.execute_task
File “python\ray\_raylet.pyx”, line 1824, in ray._raylet.execute_task.function_executor
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\_private\function_manager.py”, line 696, in actor_method_executor
return method(__ray_actor, *args, **kwargs)
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py”, line 463, in _resume_span
return method(self, *_args, **_kwargs)
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\learner.py”, line 1641, in apply
return func(self, *_args, **_kwargs)
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\learner_group.py”, line 385, in _learner_update
result = _learner.update_from_episodes(
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py”, line 463, in _resume_span
return method(self, *_args, **_kwargs)
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\learner.py”, line 1086, in update_from_episodes
self._update_from_batch_or_episodes(
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py”, line 463, in _resume_span
return method(self, *_args, **_kwargs)
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\learner.py”, line 1423, in _update_from_batch_or_episodes
fwd_out, loss_per_module, tensor_metrics = self._update(
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py”, line 463, in _resume_span
return method(self, *_args, **_kwargs)
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\torch\torch_learner.py”, line 497, in _update
return self._possibly_compiled_update(batch)
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py”, line 463, in _resume_span
return method(self, *_args, **_kwargs)
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\torch\torch_learner.py”, line 152, in _uncompiled_update
loss_per_module = self.compute_losses(fwd_out=fwd_out, batch=batch)
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py”, line 463, in _resume_span
return method(self, *_args, **_kwargs)
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\core\learner\learner.py”, line 924, in compute_losses
loss = self.compute_loss_for_module(
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\util\tracing\tracing_helper.py”, line 463, in _resume_span
return method(self, *_args, **_kwargs)
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\algorithms\ppo\torch\ppo_torch_learner.py”, line 75, in compute_loss_for_module
curr_action_dist.logp(batch[Columns.ACTIONS]) - batch[Columns.ACTION_LOGP]
File “D:\Anaconda\envs\singleray\lib\site-packages\ray\rllib\models\torch\torch_distributions.py”, line 38, in logp
return self._dist.log_prob(value, **kwargs)
File “D:\Anaconda\envs\singleray\lib\site-packages\torch\distributions\categorical.py”, line 137, in log_prob
self._validate_sample(value)
File “D:\Anaconda\envs\singleray\lib\site-packages\torch\distributions\distribution.py”, line 297, in _validate_sample
raise ValueError(
ValueError: Value is not broadcastable with batch_shape+event_shape: torch.Size([5, 50]) vs torch.Size([5, 100]).