How to use action masking with an LSTM or attention network and a nested Dict action space?

    # Number of discrete action types: 4 strategies + close + hold + adjust hedge.
    self.discrete_features_count = 7
    self.trade_sizes_actions_count = 21
    self.max_days_to_expiry = 200

    # Define the action space.
    # Every component is Discrete(n), so a sampled value is an integer index
    # in [0, n - 1]:
    #   'type'           -> 0 .. discrete_features_count - 1
    #   'trade_size'     -> 0 .. trade_sizes_actions_count - 1
    #   'days_to_expiry' -> 0 .. max_days_to_expiry - 1
    # NOTE(review): the mapping from a 'trade_size' index to an actual trade
    # size is not visible here; 21 indices in 5% steps would presumably span
    # 0%..100% - confirm against the code that consumes the action.
    self.action_space = spaces.Dict({
        'type': spaces.Discrete(self.discrete_features_count),  # specific action type
        'trade_size': spaces.Discrete(self.trade_sizes_actions_count),  # trade size index
        'days_to_expiry': spaces.Discrete(self.max_days_to_expiry),  # days-to-expiry index
    })

    # Size of the joint action space if the three components were flattened
    # into a single Discrete space.
    self.total_actions = (
        self.discrete_features_count
        * self.trade_sizes_actions_count
        * self.max_days_to_expiry
    )

    # One mask vector per action component; 1 = allowed, 0 = masked out.
    # Everything starts out allowed.
    self.current_action_mask = {
        'type': np.ones(self.discrete_features_count, dtype=np.int8),
        'trade_size': np.ones(self.trade_sizes_actions_count, dtype=np.int8),
        'days_to_expiry': np.ones(self.max_days_to_expiry, dtype=np.int8),
    }

    # The observation carries the real observations plus an action mask whose
    # keys and lengths mirror the components of the action space.
    self.observation_space = spaces.Dict({
        "observations": spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(self.window_size, self.num_features),
            dtype=np.float32,
        ),
        "action_mask": spaces.Dict({
            'type': spaces.Box(low=0, high=1, shape=(self.action_space['type'].n,), dtype=np.int8),
            'trade_size': spaces.Box(low=0, high=1, shape=(self.action_space['trade_size'].n,), dtype=np.int8),
            'days_to_expiry': spaces.Box(low=0, high=1, shape=(self.action_space['days_to_expiry'].n,), dtype=np.int8),
        }),
    })

I can't get the forward function to correctly extract the masked actions; the nested action dict seems to be causing the issue.

def forward(self, input_dict, state, seq_lens):
    """Run the sub-model on observations plus action masks and mask its logits.

    The per-component masks under ``input_dict["obs"]["action_mask"]``
    (``type``, ``trade_size``, ``days_to_expiry``) are concatenated into one
    mask, appended to the flattened observations and passed to
    ``self.torch_sub_model``. The log of the mask, clamped at ``FLOAT_MIN``,
    is then added to the leading logit columns so that entries whose mask
    value is 0 become very unlikely to be selected.

    Args:
        input_dict: Mapping with an ``"obs"`` entry holding ``"observations"``
            and the nested ``"action_mask"`` dict.
        state: Recurrent state; forwarded to the sub-model and returned
            unchanged.
        seq_lens: Sequence lengths; forwarded to the sub-model.

    Returns:
        A ``(masked_logits, state)`` tuple.
    """
    obs = input_dict["obs"]
    masks = obs["action_mask"]

    # Flatten everything after the leading dimension. The leading size is
    # taken from the tensor itself instead of being hard-coded.
    # NOTE(review): assumes dim 0 of "observations" is the batch dimension
    # - confirm.
    flattened_obs = obs["observations"].float().flatten(start_dim=1)

    # Combine the separate masks into a single mask tensor, in the order
    # type | trade_size | days_to_expiry.
    combined_mask = torch.cat(
        [
            masks["type"].float(),
            masks["trade_size"].float(),
            masks["days_to_expiry"].float(),
        ],
        dim=-1,
    )

    # Feed the observations together with the mask to the wrapped model.
    obs_with_mask = torch.cat([flattened_obs, combined_mask], dim=-1)
    logits, _ = self.torch_sub_model({"obs": obs_with_mask}, state, seq_lens)

    # log(1) = 0 leaves allowed entries untouched; log(0) = -inf is clamped
    # to FLOAT_MIN so masked entries stay finite.
    inf_mask = torch.clamp(torch.log(combined_mask), min=FLOAT_MIN)

    # The logits may be wider than the mask, so start from zeros shaped like
    # the logits and write the mask into the leading columns only.
    # NOTE(review): this assumes the first `mask_dim` logits are laid out in
    # the same order as `combined_mask` - confirm against the action
    # distribution in use.
    complete_inf_mask = torch.zeros_like(logits)
    mask_dim = inf_mask.shape[1]
    complete_inf_mask[:, :mask_dim] = inf_mask

    masked_logits = logits + complete_inf_mask
    return masked_logits, state

(APPO pid=1304387) File “/home/ray/anaconda3/lib/python3.8/site-packages/ray/rllib/models/torch/”, line 255, in forward
(APPO pid=1304387) wrapped_out = torch.cat([wrapped_out] + prev_a_r, dim=1)
(APPO pid=1304387) RuntimeError: Tensors must have same number of dimensions: got 2 and 3

This is the closest I can get.

Any help?