Hello Ray/RLlib Community,
I am encountering a ValueError when running a custom model with the PPO algorithm in RLlib. The error message suggests a mismatch in broadcastable shapes, specifically between torch.Size([32]) and torch.Size([4]). I am seeking insights or suggestions to resolve this issue.
Here’s a summary of my setup:
- I’ve created a custom environment (DummyEnv) and a custom model (TemporalFusionTransformer).
- The environment has a multi-part observation space and a multi-part action space.
- The model processes these observations and generates actions, but I’m facing an error related to the action distribution.
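Below is a self-contained script that reproduces the problem: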
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import ray
from ray import tune
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import override
from ray.rllib.algorithms.ppo import PPOConfig
from gymnasium.spaces import Box, Discrete, Dict
from ray.tune.registry import register_env
from ray.rllib.models import ModelCatalog
from ray.rllib.policy.rnn_sequencing import add_time_dimension
from ray.rllib.models.modelv2 import restore_original_dimensions
# Simplified Dummy Environment
class DummyEnv(gym.Env):
    def __init__(self, config):
        super(DummyEnv, self).__init__()
        self.observation_space = Dict({
            'market_data': Box(low=0, high=1, shape=(3, 100, 5), dtype=np.float32),
            'news_embeddings': Box(low=-np.inf, high=np.inf, shape=(10, 1, 221), dtype=np.float32)
        })
        self.action_space = Dict({
            "asset_selection": Discrete(3),
            "order_amount": Box(low=-1, high=1, shape=(1,)),
            "order_cancellation": Discrete(2)
        })

    def reset(self, *, seed=None, options=None):
        return {
            'market_data': np.random.rand(*self.observation_space['market_data'].shape),
            'news_embeddings': np.random.rand(*self.observation_space['news_embeddings'].shape)
        }, {}

    def step(self, action):
        obs = self.reset()[0]
        reward = np.random.rand()
        done = False
        info = {}
        truncated = False
        return obs, reward, done, truncated, info
class Attention(nn.Module):
    def __init__(self, hidden_size):
        super(Attention, self).__init__()
        self.attention = nn.Linear(hidden_size, 1)

    def forward(self, context):
        attention_weights = torch.softmax(self.attention(context), dim=1)
        weighted_context = attention_weights * context
        return torch.sum(weighted_context, dim=1)
class TemporalFusionTransformer(TorchModelV2, nn.Module):
    @override(ModelV2)
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        super(TemporalFusionTransformer, self).__init__(obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)
        print('action_space: ', action_space)
        self.order = list(sorted(self.action_space.spaces.keys()))
        self._last_decision_output = None
        # Retrieve custom model configuration
        num_assets = model_config.get("num_assets", 3)
        custom_config = model_config.get("custom_model_config", {})
        news_embedding_shape = custom_config.get("news_embedding_shape")
        financial_data_shape = custom_config.get("financial_data_shape")
        hidden_size = custom_config.get("hidden_size", 256)
        lstm_hidden_size = custom_config.get("lstm_hidden_size", 256)
        dropout = custom_config.get("dropout", 0.4)
        self.lstm_hidden_size = lstm_hidden_size
        # Define layers and network architecture
        self.news_lstm = nn.LSTM(input_size=news_embedding_shape[2], hidden_size=hidden_size, batch_first=True)
        self.news_attention = Attention(hidden_size)
        self.financial_variable_selection = nn.Linear(financial_data_shape[2], 5)
        self.financial_encoder = nn.LSTM(input_size=5, hidden_size=hidden_size, batch_first=True)
        self.financial_gate = nn.Sequential(nn.Linear(5, 5), nn.Sigmoid())
        self.financial_attention = Attention(hidden_size)
        self.news_norm = nn.BatchNorm1d(hidden_size)
        self.financial_norm = nn.BatchNorm1d(hidden_size)
        self.dropout = nn.Dropout(p=dropout)
        # Output layers for different actions
        self.decoder_asset_selection = nn.Linear(lstm_hidden_size, num_assets) # One-hot output for asset selection
        self.decoder_order_amount = nn.Linear(lstm_hidden_size, 2) # Outputs mean and std for order amount
        self.decoder_order_cancellation = nn.Linear(lstm_hidden_size, 2) # One-hot output for order cancellation
        self.value_head = nn.Linear(self.lstm_hidden_size, 1)
        self.lstm_state_size = lstm_hidden_size
        self.state = None

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        """Adds time dimension to batch before sending inputs to forward_rnn().
        You should implement forward_rnn() in your subclass."""
        flat_inputs = input_dict["obs_flat"].float()
        # Note that max_seq_len != input_dict.max_seq_len != seq_lens.max()
        # as input_dict may have extra zero-padding beyond seq_lens.max().
        # Use add_time_dimension to handle this
        self.time_major = self.model_config.get("_time_major", False)
        inputs = add_time_dimension(
            flat_inputs,
            seq_lens=seq_lens,
            framework="torch",
            time_major=self.time_major,
        )
        output, new_state = self.forward_rnn(inputs, state, seq_lens)
        output = torch.reshape(output, [-1, self.num_outputs])
        return output, new_state

    @override(ModelV2)
    def get_initial_state(self, batch_size=None):
        if batch_size is None:
            batch_size = 1
        # Initialize hidden state (h) and cell state (c)
        h = torch.zeros(1, batch_size, self.lstm_hidden_size)
        c = torch.zeros(1, batch_size, self.lstm_hidden_size)
        return [h, c]

    def dict_action2preprocessed_action(self, action):
        """
        Preprocesses action components.
        - 'asset_selection': Convert logits to discrete actions.
        - 'order_amount': Use as is (mean and std).
        - 'order_cancellation': Convert logits to discrete actions.
        """
        # Process 'asset_selection'
        asset_selection = action['asset_selection'] #[batch_size, 3]
        # Use 'order_amount' as is (mean and std)
        order_amount = action['order_amount'] #[batch_size, 2]
        # Process 'order_cancellation'
        order_cancellation = action['order_cancellation'] #[batch_size, 2]
        print("asset_selection_logits shape:", asset_selection.shape)
        print("order_amount_output shape:", order_amount.shape)
        print("order_cancellation_logits shape:", order_cancellation.shape)
        flat_logits = torch.cat([asset_selection, order_amount, order_cancellation], dim=-1)
        print("Concatenated logits shape:", flat_logits.shape)
        return flat_logits

    def forward_rnn(self, inputs, state, seq_lens):
        # inputs: Tensor of shape [B, T, obs_size]
        #print('inputs', inputs)
        original_obs = restore_original_dimensions(inputs, self.obs_space.original_space, "torch")
        #print('original_obs: ', original_obs)
        #original_obs = restore_original_dimensions(inputs["obs"], self.obs_space.original_space, "[tf|torch]")
        news_embeddings_trough_time = original_obs['news_embeddings']
        financial_data_trough_time = original_obs['market_data']
        batch_size, seq_len, _ = inputs.size()
        outputs = []
        # Initialize or reset state according to the current batch size
        if self.state is None or self.state[0].size(1) != batch_size:
            self.state = self.get_initial_state(batch_size)
        for t in range(seq_len):
            input_t = inputs[:, t, :]
            # Reshape and split the flattened observation
            news_embeddings = news_embeddings_trough_time[:, t, :]
            financial_data = financial_data_trough_time[:, t, :]
            #print(f'news_embeddings @t={t}: ', news_embeddings)
            #print(f'financial_data @t={t}: ', financial_data)
            news_embeddings = news_embeddings.squeeze(2)
            #news_embeddings = input_t[:, :news_embedding_size].view(batch_size, -1, news_embedding_feature_size)
            #financial_data = input_t[:, news_embedding_size:].view(batch_size, num_assets, -1)
            # Process and attend to news context
            news_output, _ = self.news_lstm(news_embeddings)
            news_context = self.news_attention(news_output)
            # Financial variable selection and gating
            selected_financial_data = self.financial_variable_selection(financial_data)
            gated_output = self.financial_gate(financial_data)
            # Multiplication operation
            gated_financial_data = gated_output * selected_financial_data
            batch_size, num_assets, seq_len, num_features = gated_financial_data.size()
            gated_financial_data = gated_financial_data.view(batch_size * num_assets, seq_len, num_features)
            # Process and attend to financial context
            financial_output, _ = self.financial_encoder(gated_financial_data)
            financial_context = self.financial_attention(financial_output)
            news_context = self.news_norm(news_context)
            financial_context = self.financial_norm(financial_context)
            news_context_flat = news_context.view(batch_size, -1)
            financial_context_flat = financial_context.view(batch_size, -1)
            combined_context = torch.cat([news_context_flat, financial_context_flat], dim=1)
            fusion_input_size = news_context_flat.shape[1] + financial_context_flat.shape[1]
            # Print shapes for debugging
            print('Flattened news context shape:', news_context_flat.shape)
            print('Flattened financial context shape:', financial_context_flat.shape)
            print('Expected fusion input size:', fusion_input_size)
            print("Combined context size (before fusion gate):", combined_context.shape)
            # Dynamically initialize fusion_gate and decision_lstm if not done already
            if not hasattr(self, 'fusion_gate') or not hasattr(self, 'decision_lstm'):
                fusion_input_size = combined_context.shape[1]
                # Initialize fusion gate
                self.fusion_gate = nn.Sequential(
                    nn.Linear(fusion_input_size, fusion_input_size),
                    nn.Sigmoid()
                ).to(combined_context.device) # Ensure the module is on the correct device
                # Initialize LSTM for decision-making
                self.decision_lstm = nn.LSTM(
                    input_size=fusion_input_size,
                    hidden_size=self.lstm_hidden_size,
                    batch_first=True
                ).to(combined_context.device) # Ensure the module is on the correct device
            # Apply fusion gate
            gate = self.fusion_gate(combined_context)
            combined_context = gate * combined_context
            combined_context = self.dropout(combined_context)
            combined_context = combined_context.unsqueeze(1) # Add time dimension
            print('shape of combined_context before decision_output : ', combined_context.shape)
            #print('shape of self.state before decision_output : ', self.state.shape)
            # LSTM for action decision
            decision_output, new_lstm_state = self.decision_lstm(combined_context, self.state)
            self.state = new_lstm_state
            decision_output = decision_output.squeeze(1) # Remove the time dimension
            # Set _last_decision_output to the output of the LSTM layer
            self._last_decision_output = decision_output
            # Get logits for each action
            asset_selection_logits = self.decoder_asset_selection(decision_output)
            # Split the order_amount output into mean and std
            order_amount_output = self.decoder_order_amount(decision_output)
            order_amount_mean = order_amount_output[:, :1] # First value for mean
            order_amount_std = nn.functional.softplus(order_amount_output[:, 1:]) # Second value for std (ensuring it's positive)
            # Concatenate mean and std for order_amount
            order_amount = torch.cat([order_amount_mean, order_amount_std], dim=-1)
            order_cancellation_logits = self.decoder_order_cancellation(decision_output)
            # Apply softmax to get probabilities (one-hot encoded format)
            asset_selection = torch.nn.functional.softmax(asset_selection_logits, dim=-1)
            order_cancellation = torch.nn.functional.softmax(order_cancellation_logits, dim=-1)
            # Construct action dictionary for this time step
            action_dict = {
                "asset_selection": asset_selection,
                "order_amount": order_amount,
                "order_cancellation": order_cancellation
            }
        print('self.num_outputs: ', self.num_outputs)
        action = self.dict_action2preprocessed_action(action_dict)
        return action, state

    def value_function(self):
        return self.value_head(self._last_decision_output).squeeze(1)
ModelCatalog.register_custom_model("TemporalFusionTransformer", TemporalFusionTransformer)
# Define separate environment configurations for training and evaluation
env_config_training = {
    "test_env": False
}
env_config_evaluation = {
    "test_env": True
}
# Register the environment
register_env("DummyEnv", DummyEnv)
hyperparam_space = {
    "lr": 1e-5, #tune.loguniform(1e-5, 1e-2),
    "train_batch_size": 128, #128,
    "gamma": 0.9, #tune.uniform(0.9, 0.999),
    "exploration_config": {
        "type": "EpsilonGreedy",
        "initial_epsilon": 0.1, #tune.uniform(0.1, 1.0),
        "final_epsilon": 0.01, #tune.uniform(0.01, 0.1),
        "epsilon_timesteps": 10000, #tune.choice([10000, 20000, 50000]),
    },
    "model": {
        "custom_model": "TemporalFusionTransformer",
        "custom_model_config": {
            "news_embedding_shape": (10, 1, 221),
            "financial_data_shape": (3, 100, 5),
            "hidden_size": 64, #tune.choice([32, 64, 128]),
            "dropout": 0.1, #tune.uniform(0.1, 0.5),
            "num_lstm_layers": 1, #tune.choice([1, 2, 3]),
            "lstm_hidden_size": 32, #tune.choice([32, 64, 128]),
            "num_assets": 3,
        }
    }
}
# Configuration for PPO with your custom model and hyperparameter space
config = {
    "framework": "torch",
    "env": "DummyEnv",
    "num_workers": 0,
    "num_envs_per_worker": 1,
    "rollout_fragment_length": 128,
    "batch_mode": "truncate_episodes",
    **hyperparam_space
}
torch.set_num_threads(1)
ray.init(num_cpus=1)
analysis = tune.run(
    "PPO", # Or your chosen algorithm
    config=config
)
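Running this script fails with the following error: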
ray.tune.error.TuneError: ('Trials did not complete', [PPO_DummyEnv_93348_00000])
(PPO pid=11876) Exception raised in creation task: The actor died because of an error raised in its creation task, ray::PPO.__init__() (pid=11876, ip=127.0.0.1, actor_id=8a655176ffec8fee7bc93a5e01000000, repr=PPO)
(PPO pid=11876) File "python\ray\_raylet.pyx", line 1813, in ray._raylet.execute_task
(PPO pid=11876) File "python\ray\_raylet.pyx", line 1754, in ray._raylet.execute_task.function_executor
(PPO pid=11876) File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\_private\function_manager.py", line 726, in actor_method_executor
(PPO pid=11876) return method(__ray_actor, *args, **kwargs)
(PPO pid=11876) File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\util\tracing\tracing_helper.py", line 467, in _resume_span
(PPO pid=11876) return method(self, *_args, **_kwargs)
(PPO pid=11876) File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\rllib\algorithms\algorithm.py", line 516, in __init__
(PPO pid=11876) super().__init__(
(PPO pid=11876) File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\tune\trainable\trainable.py", line 161, in __init__
(PPO pid=11876) self.setup(copy.deepcopy(self.config))
(PPO pid=11876) File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\util\tracing\tracing_helper.py", line 467, in _resume_span
(PPO pid=11876) return method(self, *_args, **_kwargs)
(PPO pid=11876) File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\rllib\algorithms\algorithm.py", line 638, in setup
(PPO pid=11876) self.workers = WorkerSet(
(PPO pid=11876) File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\rllib\evaluation\worker_set.py", line 159, in __init__
(PPO pid=11876) self._setup(
(PPO pid=11876) File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\rllib\evaluation\worker_set.py", line 249, in _setup
(PPO pid=11876) self._local_worker = self._make_worker(
(PPO pid=11876) File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\rllib\evaluation\worker_set.py", line 950, in _make_worker
(PPO pid=11876) worker = cls(
(PPO pid=11876) File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 535, in __init__
(PPO pid=11876) self._update_policy_map(policy_dict=self.policy_dict)
(PPO pid=11876) File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 1746, in _update_policy_map
(PPO pid=11876) self._build_policy_map(
(PPO pid=11876) File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 1857, in _build_policy_map
(PPO pid=11876) new_policy = create_policy_for_framework(
(PPO pid=11876) File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\rllib\utils\policy.py", line 141, in create_policy_for_framework
(PPO pid=11876) return policy_class(observation_space, action_space, merged_config)
(PPO pid=11876) File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\rllib\algorithms\ppo\ppo_torch_policy.py", line 64, in __init__
(PPO pid=11876) self._initialize_loss_from_dummy_batch()
(PPO pid=11876) File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\rllib\policy\policy.py", line 1520, in _initialize_loss_from_dummy_batch
(PPO pid=11876) self.loss(self.model, self.dist_class, train_batch)
(PPO pid=11876) File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\rllib\algorithms\ppo\ppo_torch_policy.py", line 112, in loss
(PPO pid=11876) curr_action_dist.logp(train_batch[SampleBatch.ACTIONS])
(PPO pid=11876) File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\rllib\models\torch\torch_action_dist.py", line 553, in logp
(PPO pid=11876) flat_logps = tree.map_structure(map_, split_x, self.flat_child_distributions)
(PPO pid=11876) File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\tree\__init__.py", line 435, in map_structure
(PPO pid=11876) [func(*args) for args in zip(*map(flatten, structures))])
(PPO pid=11876) File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\tree\__init__.py", line 435, in <listcomp>
(PPO pid=11876) [func(*args) for args in zip(*map(flatten, structures))])
(PPO pid=11876) File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\rllib\models\torch\torch_action_dist.py", line 549, in map_
(PPO pid=11876) return dist.logp(val)
(PPO pid=11876) File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\rllib\models\torch\torch_action_dist.py", line 37, in logp
(PPO pid=11876) return self.dist.log_prob(actions)
(PPO pid=11876) File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\torch\distributions\categorical.py", line 137, in log_prob
(PPO pid=11876) self._validate_sample(value)
(PPO pid=11876) File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\torch\distributions\distribution.py", line 297, in _validate_sample
(PPO pid=11876) raise ValueError(
(PPO pid=11876) ValueError: Value is not broadcastable with batch_shape+event_shape: torch.Size([32]) vs torch.Size([4]).
In this setup, num_outputs is 7.
I have checked the shape of the action tensor returned by the forward_rnn function, and it appears to be [batch_size, 7], which is correct if I understand correctly.
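For reference, this is how I understand the 7 outputs map onto the Dict action space (my own reading, so please correct me if it is wrong): 3 logits for asset_selection, 2 inputs for the order_amount Gaussian, and 2 logits for order_cancellation, i.e. 3 + 2 + 2 = 7. A quick standalone check against the catalog seems to agree:

```python
from gymnasium.spaces import Box, Dict, Discrete
from ray.rllib.models import ModelCatalog

# Same action space as in DummyEnv above.
action_space = Dict({
    "asset_selection": Discrete(3),                   # 3 logits (Categorical)
    "order_amount": Box(low=-1, high=1, shape=(1,)),  # 2 inputs (DiagGaussian)
    "order_cancellation": Discrete(2),                # 2 logits (Categorical)
})

# Ask RLlib how many flat model outputs the action distribution expects.
dist_class, dist_dim = ModelCatalog.get_action_dist(
    action_space, config={}, framework="torch"
)
print(dist_dim)  # prints 7 for me, matching num_outputs in the model
```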
I would greatly appreciate any guidance or suggestions on how to resolve this error. It seems to be related to the shape of the actions being returned by the model, but I am unsure how to correctly format them for the PPO algorithm.
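One thing I am starting to suspect (this is only a guess on my part): forward_rnn loops over the time dimension but only returns the logits of the last timestep, so after torch.reshape(output, [-1, self.num_outputs]) in forward() the leading dimension is B (the number of sequences) rather than B * T (the number of padded timesteps the sampled actions are stored with). That would line up with the 32 vs. 4 in the error if the dummy train batch holds 32 rows split into 4 sequences. A tiny standalone example of the shape difference I mean (plain PyTorch, no RLlib involved):

```python
import torch

B, T, num_outputs = 4, 8, 7  # e.g. 4 sequences of 8 padded timesteps, 7 flat action inputs

# What forward_rnn currently produces: one [B, num_outputs] tensor for the whole batch.
last_step_only = torch.zeros(B, num_outputs)
print(last_step_only.reshape(-1, num_outputs).shape)  # torch.Size([4, 7])  -> batch_shape [4]

# What I believe a recurrent model should return: one output per timestep.
per_timestep = torch.zeros(B, T, num_outputs)
print(per_timestep.reshape(-1, num_outputs).shape)    # torch.Size([32, 7]) -> matches the 32 actions
```

If that is indeed the problem, I assume the fix would be to append each timestep's flattened logits to the (currently unused) outputs list and torch.stack them along dim=1 before returning, but I would appreciate confirmation that this is the expected contract for forward_rnn.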
Thank you in advance for your help!