Understanding the expected output shapes of a recurrent model with a Dict action space

Hello Ray/RLlib Community,

I am encountering a ValueError when running a custom model with the PPO algorithm in RLlib. The error reports non-broadcastable shapes, torch.Size([32]) vs. torch.Size([4]), raised while the action distribution computes the log-probabilities of the sampled actions. I am seeking insights or suggestions to resolve this issue.

Here’s a summary of my setup:

  • I’ve created a custom environment (DummyEnv) and a custom model (TemporalFusionTransformer).
  • The environment has a multi-part observation space and a multi-part action space.
  • The model processes these observations and generates actions, but I’m facing an error related to the action distribution (see the sketch just below).
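
For reference, here is how I currently understand the flat action layout RLlib expects for this Dict action space (my reading of the default TorchMultiActionDistribution on the RLlib version I am using, so please correct me if it is wrong):

from gymnasium.spaces import Box, Dict, Discrete
from ray.rllib.models import ModelCatalog

action_space = Dict({
    "asset_selection": Discrete(3),        # TorchCategorical -> 3 logits
    "order_amount": Box(-1.0, 1.0, (1,)),  # TorchDiagGaussian -> 2 outputs (mean, log-std)
    "order_cancellation": Discrete(2),     # TorchCategorical -> 2 logits
})

# ModelCatalog reports the flat model-output size that the default
# multi-action distribution needs: 3 + 2 + 2 = 7 (Dict keys sorted alphabetically).
# NB: if I read torch_action_dist.py correctly, TorchDiagGaussian treats the
# second chunk as log-std, not std.
dist_class, num_outputs = ModelCatalog.get_action_dist(
    action_space, config={}, framework="torch"
)
print(num_outputs)  # 7

The full reproduction script:
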
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import ray
from ray import tune
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import override
from ray.rllib.algorithms.ppo import PPOConfig
from gymnasium.spaces import Box, Discrete, Dict
from ray.tune.registry import register_env
from ray.rllib.models import ModelCatalog
from ray.rllib.policy.rnn_sequencing import add_time_dimension
from ray.rllib.models.modelv2 import restore_original_dimensions

# Simplified Dummy Environment
class DummyEnv(gym.Env):
    def __init__(self, config):
        super(DummyEnv, self).__init__()
        self.observation_space = Dict({
            'market_data': Box(low=0, high=1, shape=(3, 100, 5), dtype=np.float32),
            'news_embeddings': Box(low=-np.inf, high=np.inf, shape=(10, 1, 221), dtype=np.float32)
        })
        self.action_space = Dict({
            "asset_selection": Discrete(3),
            "order_amount": Box(low=-1, high=1, shape=(1,)),
            "order_cancellation": Discrete(2)
        })

    def reset(self, *, seed=None, options=None):
        return {
            'market_data': np.random.rand(*self.observation_space['market_data'].shape).astype(np.float32),  # match the Box dtype
            'news_embeddings': np.random.rand(*self.observation_space['news_embeddings'].shape).astype(np.float32)
        }, {}

    def step(self, action):
        obs = self.reset()[0]
        reward = np.random.rand()
        done = False
        info = {}
        truncated = False
        return obs, reward, done, truncated, info

class Attention(nn.Module):
    def __init__(self, hidden_size):
        super(Attention, self).__init__()
        self.attention = nn.Linear(hidden_size, 1)

    def forward(self, context):
        attention_weights = torch.softmax(self.attention(context), dim=1)
        weighted_context = attention_weights * context
        return torch.sum(weighted_context, dim=1)

class TemporalFusionTransformer(TorchModelV2, nn.Module):
    @override(ModelV2)
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        super(TemporalFusionTransformer, self).__init__(obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)
        print('action_space: ', action_space)
        self.order = list(sorted(self.action_space.spaces.keys()))
        self._last_decision_output = None
        # Retrieve custom model configuration
        custom_config = model_config.get("custom_model_config", {})
        num_assets = custom_config.get("num_assets", 3)  # num_assets lives in custom_model_config
        news_embedding_shape = custom_config.get("news_embedding_shape")
        financial_data_shape = custom_config.get("financial_data_shape")
        hidden_size = custom_config.get("hidden_size", 256)
        lstm_hidden_size = custom_config.get("lstm_hidden_size", 256)
        dropout = custom_config.get("dropout", 0.4)
        self.lstm_hidden_size = lstm_hidden_size
        # Define layers and network architecture
        self.news_lstm = nn.LSTM(input_size=news_embedding_shape[2], hidden_size=hidden_size, batch_first=True)
        self.news_attention = Attention(hidden_size)
        self.financial_variable_selection = nn.Linear(financial_data_shape[2], 5)
        self.financial_encoder = nn.LSTM(input_size=5, hidden_size=hidden_size, batch_first=True)
        self.financial_gate = nn.Sequential(nn.Linear(5, 5), nn.Sigmoid())
        self.financial_attention = Attention(hidden_size)
        self.news_norm = nn.BatchNorm1d(hidden_size)
        self.financial_norm = nn.BatchNorm1d(hidden_size)
        self.dropout = nn.Dropout(p=dropout)

        # Output layers for different actions
        self.decoder_asset_selection = nn.Linear(lstm_hidden_size, num_assets)  # One-hot output for asset selection
        self.decoder_order_amount = nn.Linear(lstm_hidden_size, 2)  # Outputs mean and std for order amount
        self.decoder_order_cancellation = nn.Linear(lstm_hidden_size, 2)  # One-hot output for order cancellation

        self.value_head = nn.Linear(self.lstm_hidden_size, 1)
        self.lstm_state_size = lstm_hidden_size
        self.state = None

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        """Adds time dimension to batch before sending inputs to forward_rnn().

        You should implement forward_rnn() in your subclass."""

        flat_inputs = input_dict["obs_flat"].float()
        # Note that max_seq_len != input_dict.max_seq_len != seq_lens.max()
        # as input_dict may have extra zero-padding beyond seq_lens.max().
        # Use add_time_dimension to handle this
        self.time_major = self.model_config.get("_time_major", False)
        inputs = add_time_dimension(
            flat_inputs,
            seq_lens=seq_lens,
            framework="torch",
            time_major=self.time_major,
        )
        output, new_state = self.forward_rnn(inputs, state, seq_lens)
        output = torch.reshape(output, [-1, self.num_outputs])
        return output, new_state

    @override(ModelV2)
    def get_initial_state(self, batch_size=None):
        if batch_size is None:
            batch_size = 1

        # Initialize hidden state (h) and cell state (c)
        h = torch.zeros(1, batch_size, self.lstm_hidden_size)
        c = torch.zeros(1, batch_size, self.lstm_hidden_size)
        return [h, c]

    def dict_action2preprocessed_action(self, action):
        """
        Preprocesses action components.
        - 'asset_selection': Convert logits to discrete actions.
        - 'order_amount': Use as is (mean and std).
        - 'order_cancellation': Convert logits to discrete actions.
        """
        # Process 'asset_selection'
        asset_selection = action['asset_selection']  # [batch_size, 3]

        # Use 'order_amount' as is (mean and std)
        order_amount = action['order_amount'] #[batch_size, 2]

        # Process 'order_cancellation'
        order_cancellation = action['order_cancellation'] #[batch_size, 2]

        print("asset_selection_logits shape:", asset_selection.shape)
        print("order_amount_output shape:", order_amount.shape)
        print("order_cancellation_logits shape:", order_cancellation.shape)
        flat_logits = torch.cat([asset_selection, order_amount, order_cancellation], dim=-1)
        print("Concatenated logits shape:", flat_logits.shape)

        return flat_logits

    def forward_rnn(self, inputs, state, seq_lens):
        # inputs: Tensor of shape [B, T, obs_size]
        #print('inputs', inputs)
        original_obs = restore_original_dimensions(inputs, self.obs_space.original_space, "torch")
        #print('original_obs: ', original_obs)
        #original_obs = restore_original_dimensions(inputs["obs"], self.obs_space.original_space, "[tf|torch]")
        news_embeddings_through_time = original_obs['news_embeddings']
        financial_data_through_time = original_obs['market_data']

        batch_size, seq_len, _ = inputs.size()
        outputs = []

        # Initialize or reset state according to the current batch size
        if self.state is None or self.state[0].size(1) != batch_size:
            self.state = self.get_initial_state(batch_size)

        for t in range(seq_len):
            input_t = inputs[:, t, :]
            # Reshape and split the flattened observation
            news_embeddings = news_embeddings_through_time[:, t, :]
            financial_data = financial_data_through_time[:, t, :]
            #print(f'news_embeddings @t={t}: ', news_embeddings)
            #print(f'financial_data @t={t}: ', financial_data)

            news_embeddings = news_embeddings.squeeze(2)
            #news_embeddings = input_t[:, :news_embedding_size].view(batch_size, -1, news_embedding_feature_size)
            #financial_data = input_t[:, news_embedding_size:].view(batch_size, num_assets, -1)
            # Process news context
            # Process and attend to news context
            news_output, _ = self.news_lstm(news_embeddings)
            news_context = self.news_attention(news_output)

            # Financial variable selection and gating
            selected_financial_data = self.financial_variable_selection(financial_data)
            gated_output = self.financial_gate(financial_data)


            # Multiplication operation
            gated_financial_data = gated_output * selected_financial_data


            # Fresh names avoid shadowing the outer batch_size/seq_len
            # of the [B, T, obs_size] rollout tensor
            bsz, n_assets, window_len, n_features = gated_financial_data.size()
            gated_financial_data = gated_financial_data.view(bsz * n_assets, window_len, n_features)


            # Process and attend to financial context
            financial_output, _ = self.financial_encoder(gated_financial_data)
            financial_context = self.financial_attention(financial_output)

            news_context = self.news_norm(news_context)
            financial_context = self.financial_norm(financial_context)

            news_context_flat = news_context.view(batch_size, -1)
            financial_context_flat = financial_context.view(batch_size, -1)
            combined_context = torch.cat([news_context_flat, financial_context_flat], dim=1)
            fusion_input_size = news_context_flat.shape[1] + financial_context_flat.shape[1]

            # Print shapes for debugging
            print('Flattened news context shape:', news_context_flat.shape)
            print('Flattened financial context shape:', financial_context_flat.shape)
            print('Expected fusion input size:', fusion_input_size)
            print("Combined context size (before fusion gate):", combined_context.shape)

            # Dynamically initialize fusion_gate and decision_lstm if not done already
            if not hasattr(self, 'fusion_gate') or not hasattr(self, 'decision_lstm'):
                fusion_input_size = combined_context.shape[1]

                # Initialize fusion gate
                self.fusion_gate = nn.Sequential(
                    nn.Linear(fusion_input_size, fusion_input_size),
                    nn.Sigmoid()
                ).to(combined_context.device)  # Ensure the module is on the correct device

                # Initialize LSTM for decision-making
                self.decision_lstm = nn.LSTM(
                    input_size=fusion_input_size,
                    hidden_size=self.lstm_hidden_size,
                    batch_first=True
                ).to(combined_context.device)  # Ensure the module is on the correct device

            # Apply fusion gate
            gate = self.fusion_gate(combined_context)
            combined_context = gate * combined_context
            combined_context = self.dropout(combined_context)

            combined_context = combined_context.unsqueeze(1)  # Add time dimension
            print('shape of combined_context before decision_output : ', combined_context.shape)
            #print('shape of self.state before decision_output : ', self.state.shape)

            # LSTM for action decision
            decision_output, new_lstm_state = self.decision_lstm(combined_context, self.state)
            self.state = new_lstm_state

            decision_output = decision_output.squeeze(1)  # Remove the time dimension

            # Set _last_decision_output to the output of the LSTM layer
            self._last_decision_output = decision_output

            # Get logits for each action
            asset_selection_logits = self.decoder_asset_selection(decision_output)

            # Split the order_amount output into mean and std
            order_amount_output = self.decoder_order_amount(decision_output)
            order_amount_mean = order_amount_output[:, :1]  # First value for mean
            order_amount_std = nn.functional.softplus(order_amount_output[:, 1:])  # Second value for std (ensuring it's positive)

            # Concatenate mean and std for order_amount
            order_amount = torch.cat([order_amount_mean, order_amount_std], dim=-1)

            order_cancellation_logits = self.decoder_order_cancellation(decision_output)

            # Apply softmax to get probabilities (one-hot encoded format)
            asset_selection = torch.nn.functional.softmax(asset_selection_logits, dim=-1)
            order_cancellation = torch.nn.functional.softmax(order_cancellation_logits, dim=-1)

            # Construct action dictionary for this time step
            action_dict = {
                "asset_selection": asset_selection,
                "order_amount": order_amount,
                "order_cancellation": order_cancellation
            }

            # Append the action dictionary to the outputs list
        print('self.num_outputs: ', self.num_outputs)
        action = self.dict_action2preprocessed_action(action_dict)

        return action, state

    def value_function(self):
        return self.value_head(self._last_decision_output).squeeze(1)



ModelCatalog.register_custom_model("TemporalFusionTransformer", TemporalFusionTransformer)

# Define separate environment configurations for training and evaluation
env_config_training = {
    "test_env": False
}

env_config_evaluation = {
    "test_env": True
}

# Register the environment
register_env("DummyEnv", DummyEnv)

hyperparam_space = {
    "lr": 1e-5, #tune.loguniform(1e-5, 1e-2),
    "train_batch_size": 128, #128,
    "gamma": 0.9, #tune.uniform(0.9, 0.999),
    "exploration_config": {
        "type": "EpsilonGreedy",
        "initial_epsilon": 0.1, #tune.uniform(0.1, 1.0),
        "final_epsilon": 0.01,#tune.uniform(0.01, 0.1),
        "epsilon_timesteps": 10000,#tune.choice([10000, 20000, 50000]),
    },

    "model": {
        "custom_model": "TemporalFusionTransformer",
        "custom_model_config": {
            "news_embedding_shape": (10, 1, 221),
            "financial_data_shape": (3, 100, 5),
            "hidden_size": 64, #tune.choice([32, 64, 128]),
            "dropout": 0.1, #tune.uniform(0.1, 0.5),
            "num_lstm_layers": 1, #tune.choice([1, 2, 3]),
            "lstm_hidden_size": 32, #tune.choice([32, 64, 128]),
            "num_assets": 3,
        }
    }
}

# Configuration for PPO with your custom model and hyperparameter space
config = {
    "framework": "torch",
    "env": "DummyEnv",
    "num_workers": 0,
    "num_envs_per_worker": 1,
    "rollout_fragment_length": 128,
    "batch_mode": "truncate_episodes",
    **hyperparam_space
}

torch.set_num_threads(1)
ray.init(num_cpus=1)

analysis = tune.run(
    "PPO",  # Or your chosen algorithm
    config=config
)
Running this, the trial fails with the following traceback:

ray.tune.error.TuneError: ('Trials did not complete', [PPO_DummyEnv_93348_00000])
(PPO pid=11876) Exception raised in creation task: The actor died because of an error raised in its creation task, ray::PPO.__init__() (pid=11876, ip=127.0.0.1, actor_id=8a655176ffec8fee7bc93a5e01000000, repr=PPO)
(PPO pid=11876)   File "python\ray\_raylet.pyx", line 1813, in ray._raylet.execute_task
(PPO pid=11876)   File "python\ray\_raylet.pyx", line 1754, in ray._raylet.execute_task.function_executor
(PPO pid=11876)   File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\_private\function_manager.py", line 726, in actor_method_executor
(PPO pid=11876)     return method(__ray_actor, *args, **kwargs)
(PPO pid=11876)   File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\util\tracing\tracing_helper.py", line 467, in _resume_span
(PPO pid=11876)     return method(self, *_args, **_kwargs)
(PPO pid=11876)   File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\rllib\algorithms\algorithm.py", line 516, in __init__
(PPO pid=11876)     super().__init__(
(PPO pid=11876)   File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\tune\trainable\trainable.py", line 161, in __init__
(PPO pid=11876)     self.setup(copy.deepcopy(self.config))
(PPO pid=11876)   File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\util\tracing\tracing_helper.py", line 467, in _resume_span
(PPO pid=11876)     return method(self, *_args, **_kwargs)
(PPO pid=11876)   File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\rllib\algorithms\algorithm.py", line 638, in setup
(PPO pid=11876)     self.workers = WorkerSet(
(PPO pid=11876)   File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\rllib\evaluation\worker_set.py", line 159, in __init__
(PPO pid=11876)     self._setup(
(PPO pid=11876)   File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\rllib\evaluation\worker_set.py", line 249, in _setup
(PPO pid=11876)     self._local_worker = self._make_worker(
(PPO pid=11876)   File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\rllib\evaluation\worker_set.py", line 950, in _make_worker
(PPO pid=11876)     worker = cls(
(PPO pid=11876)   File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 535, in __init__
(PPO pid=11876)     self._update_policy_map(policy_dict=self.policy_dict)
(PPO pid=11876)   File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 1746, in _update_policy_map
(PPO pid=11876)     self._build_policy_map(
(PPO pid=11876)   File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 1857, in _build_policy_map
(PPO pid=11876)     new_policy = create_policy_for_framework(
(PPO pid=11876)   File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\rllib\utils\policy.py", line 141, in create_policy_for_framework
(PPO pid=11876)     return policy_class(observation_space, action_space, merged_config)
(PPO pid=11876)   File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\rllib\algorithms\ppo\ppo_torch_policy.py", line 64, in __init__
(PPO pid=11876)     self._initialize_loss_from_dummy_batch()
(PPO pid=11876)   File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\rllib\policy\policy.py", line 1520, in _initialize_loss_from_dummy_batch
(PPO pid=11876)     self.loss(self.model, self.dist_class, train_batch)
(PPO pid=11876)   File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\rllib\algorithms\ppo\ppo_torch_policy.py", line 112, in loss
(PPO pid=11876)     curr_action_dist.logp(train_batch[SampleBatch.ACTIONS])
(PPO pid=11876)   File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\rllib\models\torch\torch_action_dist.py", line 553, in logp
(PPO pid=11876)     flat_logps = tree.map_structure(map_, split_x, self.flat_child_distributions)
(PPO pid=11876)   File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\tree\__init__.py", line 435, in map_structure
(PPO pid=11876)     [func(*args) for args in zip(*map(flatten, structures))])
(PPO pid=11876)   File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\tree\__init__.py", line 435, in <listcomp>
(PPO pid=11876)     [func(*args) for args in zip(*map(flatten, structures))])
(PPO pid=11876)   File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\rllib\models\torch\torch_action_dist.py", line 549, in map_
(PPO pid=11876)     return dist.logp(val)
(PPO pid=11876)   File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\rllib\models\torch\torch_action_dist.py", line 37, in logp
(PPO pid=11876)     return self.dist.log_prob(actions)
(PPO pid=11876)   File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\torch\distributions\categorical.py", line 137, in log_prob
(PPO pid=11876)     self._validate_sample(value)
(PPO pid=11876)   File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\torch\distributions\distribution.py", line 297, in _validate_sample
(PPO pid=11876)     raise ValueError(
(PPO pid=11876) ValueError: Value is not broadcastable with batch_shape+event_shape: torch.Size([32]) vs torch.Size([4]).

In this setup, num_outputs is 7.
I have checked the shape of the action returned by the forward_rnn function, and it seems to be [batch_size, 7], which is correct if I understand correctly.
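
If I reproduce just the failing call outside RLlib, the numbers line up with the traceback: the Categorical distribution is built from one logits row per padded sequence, while PPO's loss feeds it one sampled action per environment step. The sizes below (4 sequences, 32 steps) are my guess at how the dummy batch was chopped up:

import torch
from torch.distributions import Categorical

logits = torch.randn(4, 3)                   # one row per sequence (last timestep only)
actions = torch.zeros(32, dtype=torch.long)  # one action per environment step

dist = Categorical(logits=logits)
dist.log_prob(actions)
# ValueError: Value is not broadcastable with batch_shape+event_shape:
# torch.Size([32]) vs torch.Size([4]).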

I would greatly appreciate any guidance or suggestions on how to resolve this error. It seems to be related to the shape of the actions being returned by the model, but I am unsure how to correctly format them for the PPO algorithm.

Thank you in advance for your help!

I understood that the output of the forward_rnn function should have shape [batch_size, timesteps, num_outputs] instead of [batch_size, num_outputs]; forward() then reshapes it to [batch_size * timesteps, num_outputs].

Also I think the value function should return a tensor of shape [batch_size, 1].

Here is the updated code.

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import ray
from ray import tune
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import override
from ray.rllib.algorithms.ppo import PPOConfig
from gymnasium.spaces import Box, Discrete, Dict
from ray.tune.registry import register_env
from ray.rllib.models import ModelCatalog
from ray.rllib.policy.rnn_sequencing import add_time_dimension
from ray.rllib.models.modelv2 import restore_original_dimensions

# Simplified Dummy Environment
class DummyEnv(gym.Env):
    def __init__(self, config):
        super(DummyEnv, self).__init__()
        self.observation_space = Dict({
            'market_data': Box(low=0, high=1, shape=(3, 100, 5), dtype=np.float32),
            'news_embeddings': Box(low=-np.inf, high=np.inf, shape=(10, 1, 221), dtype=np.float32)
        })
        self.action_space = Dict({
            "asset_selection": Discrete(3),
            "order_amount": Box(low=-1, high=1, shape=(1,)),
            "order_cancellation": Discrete(2)
        })

    def reset(self, *, seed=None, options=None):
        return {
            'market_data': np.random.rand(*self.observation_space['market_data'].shape).astype(np.float32),  # match the Box dtype
            'news_embeddings': np.random.rand(*self.observation_space['news_embeddings'].shape).astype(np.float32)
        }, {}

    def step(self, action):
        obs = self.reset()[0]
        reward = np.random.rand()
        done = False
        info = {}
        truncated = False
        return obs, reward, done, truncated, info

class Attention(nn.Module):
    def __init__(self, hidden_size):
        super(Attention, self).__init__()
        self.attention = nn.Linear(hidden_size, 1)

    def forward(self, context):
        attention_weights = torch.softmax(self.attention(context), dim=1)
        weighted_context = attention_weights * context
        return torch.sum(weighted_context, dim=1)

class TemporalFusionTransformer(TorchModelV2, nn.Module):
    @override(ModelV2)
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        super(TemporalFusionTransformer, self).__init__(obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)
        print('action_space: ', action_space)
        self.order = list(sorted(self.action_space.spaces.keys()))
        self._last_decision_output = None
        # Retrieve custom model configuration
        custom_config = model_config.get("custom_model_config", {})
        num_assets = custom_config.get("num_assets", 3)  # num_assets lives in custom_model_config
        news_embedding_shape = custom_config.get("news_embedding_shape")
        financial_data_shape = custom_config.get("financial_data_shape")
        hidden_size = custom_config.get("hidden_size", 256)
        lstm_hidden_size = custom_config.get("lstm_hidden_size", 256)
        dropout = custom_config.get("dropout", 0.4)
        self.lstm_hidden_size = lstm_hidden_size
        # Define layers and network architecture
        self.news_lstm = nn.LSTM(input_size=news_embedding_shape[2], hidden_size=hidden_size, batch_first=True)
        self.news_attention = Attention(hidden_size)
        self.financial_variable_selection = nn.Linear(financial_data_shape[2], 5)
        self.financial_encoder = nn.LSTM(input_size=5, hidden_size=hidden_size, batch_first=True)
        self.financial_gate = nn.Sequential(nn.Linear(5, 5), nn.Sigmoid())
        self.financial_attention = Attention(hidden_size)
        self.news_norm = nn.BatchNorm1d(hidden_size)
        self.financial_norm = nn.BatchNorm1d(hidden_size)
        self.dropout = nn.Dropout(p=dropout)

        # Output layers for different actions
        self.decoder_asset_selection = nn.Linear(lstm_hidden_size, num_assets)  # One-hot output for asset selection
        self.decoder_order_amount = nn.Linear(lstm_hidden_size, 2)  # Outputs mean and std for order amount
        self.decoder_order_cancellation = nn.Linear(lstm_hidden_size, 2)  # One-hot output for order cancellation

        self.value_head = nn.Linear(self.lstm_hidden_size, 1)
        self.lstm_state_size = lstm_hidden_size
        self.state = None

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        """Adds time dimension to batch before sending inputs to forward_rnn().

        You should implement forward_rnn() in your subclass."""

        flat_inputs = input_dict["obs_flat"].float()
        # Note that max_seq_len != input_dict.max_seq_len != seq_lens.max()
        # as input_dict may have extra zero-padding beyond seq_lens.max().
        # Use add_time_dimension to handle this
        self.time_major = self.model_config.get("_time_major", False)
        inputs = add_time_dimension(
            flat_inputs,
            seq_lens=seq_lens,
            framework="torch",
            time_major=self.time_major,
        )
        output, new_state = self.forward_rnn(inputs, state, seq_lens)
        output = torch.reshape(output, [-1, self.num_outputs])
        return output, new_state

    @override(ModelV2)
    def get_initial_state(self, batch_size=None):
        if batch_size is None:
            batch_size = 1

        # Initialize hidden state (h) and cell state (c)
        h = torch.zeros(1, batch_size, self.lstm_hidden_size)
        c = torch.zeros(1, batch_size, self.lstm_hidden_size)
        return [h, c]

    def dict_action2preprocessed_action(self, action):
        """
        Preprocesses action components.
        - 'asset_selection': Convert logits to discrete actions.
        - 'order_amount': Use as is (mean and std).
        - 'order_cancellation': Convert logits to discrete actions.
        """
        # Process 'asset_selection'
        asset_selection = action['asset_selection']  # [batch_size, 3]

        # Use 'order_amount' as is (mean and std)
        order_amount = action['order_amount'] #[batch_size, 2]

        # Process 'order_cancellation'
        order_cancellation = action['order_cancellation'] #[batch_size, 2]

        print("asset_selection_logits shape:", asset_selection.shape)
        print("order_amount_output shape:", order_amount.shape)
        print("order_cancellation_logits shape:", order_cancellation.shape)
        flat_logits = torch.cat([asset_selection, order_amount, order_cancellation], dim=-1)
        print("Concatenated logits shape:", flat_logits.shape)

        return flat_logits

    def forward_rnn(self, inputs, state, seq_lens):
        # inputs: Tensor of shape [B, T, obs_size]
        #print('inputs', inputs)
        original_obs = restore_original_dimensions(inputs, self.obs_space.original_space, "torch")
        #print('original_obs: ', original_obs)
        #original_obs = restore_original_dimensions(inputs["obs"], self.obs_space.original_space, "[tf|torch]")
        news_embeddings_through_time = original_obs['news_embeddings']
        financial_data_through_time = original_obs['market_data']

        batch_size, seq_len, _ = inputs.size()
        print('batch_size: ', batch_size)
        input('press a key to continue...')
        timestep_outputs = []  # List to store outputs at each timestep

        # Initialize or reset state according to the current batch size
        if self.state is None or self.state[0].size(1) != batch_size:
            self.state = self.get_initial_state(batch_size)

        for t in range(seq_len):
            input_t = inputs[:, t, :]
            print("Input tensor shape at time t:", input_t.shape)

            # Reshape and split the flattened observation
            news_embeddings = news_embeddings_through_time[:, t, :]
            financial_data = financial_data_through_time[:, t, :]
            #print(f'news_embeddings @t={t}: ', news_embeddings)
            #print(f'financial_data @t={t}: ', financial_data)

            news_embeddings = news_embeddings.squeeze(2)
            #news_embeddings = input_t[:, :news_embedding_size].view(batch_size, -1, news_embedding_feature_size)
            #financial_data = input_t[:, news_embedding_size:].view(batch_size, num_assets, -1)
            # Process news context
            # Process and attend to news context
            news_output, _ = self.news_lstm(news_embeddings)
            news_context = self.news_attention(news_output)

            # Financial variable selection and gating
            selected_financial_data = self.financial_variable_selection(financial_data)
            gated_output = self.financial_gate(financial_data)


            # Multiplication operation
            gated_financial_data = gated_output * selected_financial_data


            # Fresh names avoid shadowing the outer batch_size/seq_len
            # of the [B, T, obs_size] rollout tensor
            bsz, n_assets, window_len, n_features = gated_financial_data.size()
            gated_financial_data = gated_financial_data.view(bsz * n_assets, window_len, n_features)


            # Process and attend to financial context
            financial_output, _ = self.financial_encoder(gated_financial_data)
            financial_context = self.financial_attention(financial_output)

            news_context = self.news_norm(news_context)
            financial_context = self.financial_norm(financial_context)

            news_context_flat = news_context.view(batch_size, -1)
            financial_context_flat = financial_context.view(batch_size, -1)
            combined_context = torch.cat([news_context_flat, financial_context_flat], dim=1)
            fusion_input_size = news_context_flat.shape[1] + financial_context_flat.shape[1]

            # Print shapes for debugging
            print('Flattened news context shape:', news_context_flat.shape)
            print('Flattened financial context shape:', financial_context_flat.shape)
            print('Expected fusion input size:', fusion_input_size)
            print("Combined context size (before fusion gate):", combined_context.shape)

            # Dynamically initialize fusion_gate and decision_lstm if not done already
            if not hasattr(self, 'fusion_gate') or not hasattr(self, 'decision_lstm'):
                fusion_input_size = combined_context.shape[1]

                # Initialize fusion gate
                self.fusion_gate = nn.Sequential(
                    nn.Linear(fusion_input_size, fusion_input_size),
                    nn.Sigmoid()
                ).to(combined_context.device)  # Ensure the module is on the correct device

                # Initialize LSTM for decision-making
                self.decision_lstm = nn.LSTM(
                    input_size=fusion_input_size,
                    hidden_size=self.lstm_hidden_size,
                    batch_first=True
                ).to(combined_context.device)  # Ensure the module is on the correct device

            # Apply fusion gate
            gate = self.fusion_gate(combined_context)
            combined_context = gate * combined_context
            combined_context = self.dropout(combined_context)

            combined_context = combined_context.unsqueeze(1)  # Add time dimension
            print('shape of combined_context before decision_output : ', combined_context.shape)
            #print('shape of self.state before decision_output : ', self.state.shape)

            # LSTM for action decision
            decision_output, new_lstm_state = self.decision_lstm(combined_context, self.state)
            self.state = new_lstm_state

            decision_output = decision_output.squeeze(1)  # Remove the time dimension
            print("Decision output shape:", decision_output.shape)

            # Set _last_decision_output to the output of the LSTM layer
            self._last_decision_output = decision_output

            # Get logits for each action
            asset_selection_logits = self.decoder_asset_selection(decision_output)

            # Split the order_amount output into mean and std
            order_amount_output = self.decoder_order_amount(decision_output)
            order_amount_mean = order_amount_output[:, :1]  # First value for mean
            order_amount_std = nn.functional.softplus(order_amount_output[:, 1:])  # Second value for std (ensuring it's positive)

            # Concatenate mean and std for order_amount
            order_amount = torch.cat([order_amount_mean, order_amount_std], dim=-1)

            order_cancellation_logits = self.decoder_order_cancellation(decision_output)

            # Apply softmax to get probabilities (one-hot encoded format)
            asset_selection = torch.nn.functional.softmax(asset_selection_logits, dim=-1)
            order_cancellation = torch.nn.functional.softmax(order_cancellation_logits, dim=-1)

            # Construct action dictionary for this time step
            action_dict = {
                "asset_selection": asset_selection,
                "order_amount": order_amount,
                "order_cancellation": order_cancellation
            }

            # Append the action dictionary to the outputs list
            print('self.num_outputs: ', self.num_outputs)
            action = self.dict_action2preprocessed_action(action_dict)
            timestep_outputs.append(action)


        return torch.stack(timestep_outputs, dim=1), state

    def value_function(self):
        print('value_function output shape: ', self.value_head(self._last_decision_output).shape)
        return self.value_head(self._last_decision_output)



ModelCatalog.register_custom_model("TemporalFusionTransformer", TemporalFusionTransformer)

# Define separate environment configurations for training and evaluation
env_config_training = {
    "test_env": False
}

env_config_evaluation = {
    "test_env": True
}

# Register the environment
register_env("DummyEnv", DummyEnv)

hyperparam_space = {
    "lr": 1e-5, #tune.loguniform(1e-5, 1e-2),
    "train_batch_size": 128, #128,
    "gamma": 0.9, #tune.uniform(0.9, 0.999),
    "exploration_config": {
        "type": "EpsilonGreedy",
        "initial_epsilon": 0.1, #tune.uniform(0.1, 1.0),
        "final_epsilon": 0.01,#tune.uniform(0.01, 0.1),
        "epsilon_timesteps": 10000,#tune.choice([10000, 20000, 50000]),
    },

    "model": {
        "custom_model": "TemporalFusionTransformer",
        "custom_model_config": {
            "news_embedding_shape": (10, 1, 221),
            "financial_data_shape": (3, 100, 5),
            "hidden_size": 64, #tune.choice([32, 64, 128]),
            "dropout": 0.1, #tune.uniform(0.1, 0.5),
            "num_lstm_layers": 1, #tune.choice([1, 2, 3]),
            "lstm_hidden_size": 32, #tune.choice([32, 64, 128]),
            "num_assets": 3,
        }
    }
}

# Configuration for PPO with your custom model and hyperparameter space
config = {
    "framework": "torch",
    "env": "DummyEnv",
    "num_workers": 0,
    "num_envs_per_worker": 1,
    "rollout_fragment_length": 128,
    "batch_mode": "truncate_episodes",
    **hyperparam_space
}

torch.set_num_threads(1)
ray.init(num_cpus=1)

analysis = tune.run(
    "PPO",  # Or your chosen algorithm
    config=config
)

With this code, I get an error later, in the ‘_initialize_loss_from_dummy_batch’ function:

self._update_policy_map(policy_dict=self.policy_dict)
  File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 1746, in _update_policy_map
    self._build_policy_map(
  File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 1857, in _build_policy_map
    new_policy = create_policy_for_framework(
  File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\rllib\utils\policy.py", line 141, in create_policy_for_framework
    return policy_class(observation_space, action_space, merged_config)
  File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\rllib\algorithms\ppo\ppo_torch_policy.py", line 64, in __init__
    self._initialize_loss_from_dummy_batch()
  File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\rllib\policy\policy.py", line 1481, in _initialize_loss_from_dummy_batch
    postprocessed_batch = self.postprocess_trajectory(self._dummy_batch)
  File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\rllib\algorithms\ppo\ppo_torch_policy.py", line 219, in postprocess_trajectory
    return compute_gae_for_sample_batch(
  File "C:\Users\lesei\anaconda3_2\envs\newshunter\lib\site-packages\ray\rllib\evaluation\postprocessing.py", line 195, in compute_gae_for_sample_batch
    assert vf_preds.shape == rewards.shape

I changed the value_function to return a [batch_size * timesteps] tensor and the training now starts.
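
For anyone hitting the same thing, here is the shape contract I converged on, as a minimal sketch with made-up sizes (B is the number of padded sequences, T the padded max sequence length). This is just what made PPO's dummy-batch checks pass for me, not an authoritative statement of the API:

import torch

B, T, num_outputs = 4, 8, 7                  # hypothetical sizes
rnn_out = torch.randn(B, T, num_outputs)     # what forward_rnn() returns
flat_out = rnn_out.reshape(-1, num_outputs)  # what forward() hands to the action dist: [B*T, 7]
values = torch.randn(B, T).reshape(-1)       # what value_function() returns: [B*T]
assert flat_out.shape[0] == values.shape[0] == B * T  # must match rewards.shape in compute_gae_for_sample_batch

The updated forward_rnn and value_function: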

    def forward_rnn(self, inputs, state, seq_lens):
        # inputs: Tensor of shape [B, T, obs_size]
        original_obs = restore_original_dimensions(inputs, self.obs_space.original_space, "torch")
        news_embeddings_through_time = original_obs['news_embeddings']
        financial_data_through_time = original_obs['market_data']

        batch_size, seq_len, _ = inputs.size()
        timestep_outputs = []  # List to store the flat action logits at each timestep
        decision_outputs = []  # List to store the LSTM outputs for the value function
        # Initialize or reset state according to the current batch size
        if self.state is None or self.state[0].size(1) != batch_size:
            self.state = self.get_initial_state(batch_size)

        for t in range(seq_len):
            # Slice out the observation components for this timestep
            news_embeddings = news_embeddings_through_time[:, t, :]
            financial_data = financial_data_through_time[:, t, :]

            # [B, 10, 1, 221] -> [B, 10, 221]
            news_embeddings = news_embeddings.squeeze(2)
            # Process news context
            # Process and attend to news context
            news_output, _ = self.news_lstm(news_embeddings)
            news_context = self.news_attention(news_output)

            # Financial variable selection and gating
            selected_financial_data = self.financial_variable_selection(financial_data)
            gated_output = self.financial_gate(financial_data)


            # Multiplication operation
            gated_financial_data = gated_output * selected_financial_data


            # Fresh names avoid shadowing the outer batch_size/seq_len
            # of the [B, T, obs_size] rollout tensor
            bsz, n_assets, window_len, n_features = gated_financial_data.size()
            gated_financial_data = gated_financial_data.view(bsz * n_assets, window_len, n_features)


            # Process and attend to financial context
            financial_output, _ = self.financial_encoder(gated_financial_data)
            financial_context = self.financial_attention(financial_output)

            news_context = self.news_norm(news_context)
            financial_context = self.financial_norm(financial_context)

            news_context_flat = news_context.view(batch_size, -1)
            financial_context_flat = financial_context.view(batch_size, -1)
            combined_context = torch.cat([news_context_flat, financial_context_flat], dim=1)

            # Dynamically initialize fusion_gate and decision_lstm if not done already
            if not hasattr(self, 'fusion_gate') or not hasattr(self, 'decision_lstm'):
                fusion_input_size = combined_context.shape[1]

                # Initialize fusion gate
                self.fusion_gate = nn.Sequential(
                    nn.Linear(fusion_input_size, fusion_input_size),
                    nn.Sigmoid()
                ).to(combined_context.device)  # Ensure the module is on the correct device

                # Initialize LSTM for decision-making
                self.decision_lstm = nn.LSTM(
                    input_size=fusion_input_size,
                    hidden_size=self.lstm_hidden_size,
                    batch_first=True
                ).to(combined_context.device)  # Ensure the module is on the correct device

            # Apply fusion gate
            gate = self.fusion_gate(combined_context)
            combined_context = gate * combined_context
            combined_context = self.dropout(combined_context)

            combined_context = combined_context.unsqueeze(1)  # Add time dimension

            # LSTM for action decision
            decision_output, new_lstm_state = self.decision_lstm(combined_context, self.state)
            decision_outputs.append(decision_output)  # keep the [B, 1, hidden] output for the value function
            self.state = new_lstm_state

            decision_output = decision_output.squeeze(1)  # Remove the time dimension

            # Set _last_decision_output to the output of the LSTM layer
            self._last_decision_output = decision_output

            # Get logits for each action
            asset_selection_logits = self.decoder_asset_selection(decision_output)

            # Split the order_amount output into mean and std
            order_amount_output = self.decoder_order_amount(decision_output)
            order_amount_mean = order_amount_output[:, :1]  # First value for mean
            order_amount_std = nn.functional.softplus(order_amount_output[:, 1:])  # Second value for std (ensuring it's positive)

            # Concatenate mean and std for order_amount
            order_amount = torch.cat([order_amount_mean, order_amount_std], dim=-1)

            order_cancellation_logits = self.decoder_order_cancellation(decision_output)

            # Apply softmax to get probabilities (one-hot encoded format)
            asset_selection = torch.nn.functional.softmax(asset_selection_logits, dim=-1)
            order_cancellation = torch.nn.functional.softmax(order_cancellation_logits, dim=-1)

            # Construct action dictionary for this time step
            action_dict = {
                "asset_selection": asset_selection,
                "order_amount": order_amount,
                "order_cancellation": order_cancellation
            }

            # Flatten the action dict into a [B, num_outputs] logits row for this timestep
            action = self.dict_action2preprocessed_action(action_dict)
            timestep_outputs.append(action)

        self._last_decision_outputs = decision_outputs
        return torch.stack(timestep_outputs, dim=1), state

    def value_function(self):
        if self._last_decision_outputs is None or len(self._last_decision_outputs) == 0:
            raise ValueError("No decision outputs stored. Call forward_rnn first.")

        # Assuming _last_decision_outputs is a list of decision outputs at each timestep
        # Compute value for each timestep
        values_per_timestep = [self.value_head(decision_output).squeeze(-1) for decision_output in self._last_decision_outputs]

        # Concatenate to a [B, T] tensor, then flatten to [B * T]
        values_flattened = torch.cat(values_per_timestep, dim=1).view(-1)

        return values_flattened
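
One detail I double-checked: torch.cat(values_per_timestep, dim=1).view(-1) flattens batch-major, which matches the torch.reshape(output, [-1, self.num_outputs]) in forward(), so the value predictions line up row-for-row with the flattened logits and rewards. A toy check with made-up sizes:

import torch

B, T = 2, 3
per_step = [torch.full((B, 1), float(t)) for t in range(T)]  # value at timestep t
flat = torch.cat(per_step, dim=1).view(-1)
print(flat)  # tensor([0., 1., 2., 0., 1., 2.]), i.e. all of sequence 0 first, then sequence 1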