How to use Custom Model in MultiAgent PPO Policy

How severely does this issue affect your experience of using Ray?

  • None: Just asking a question out of curiosity.
  • Low: It annoys or frustrates me for a moment.
  • Medium: It adds significant difficulty to completing my task, but I can work around it.
  • High: It blocks me from completing my task.

Hello, I am implementing a custom model in multi-agent PPO, and I've run into some trouble.

import ray
from ray.rllib.models import ModelCatalog
from ray.rllib.algorithms.ppo import PPOConfig
from env import RLlibSMAC
from ray import air, tune
from ray.tune.registry import register_env
from model import TorchTransformerActionMaskModel

def policy_mapping_func(agent_id, *args, **kwargs):
    return agent_id

if __name__ == "__main__":
    ray.init()
    smac_env = RLlibSMAC('3m')
    env_info = smac_env.get_env_info()

    def env_creator(args):
        return RLlibSMAC('3m')
    
    register_env('smac_3m', env_creator)

    ModelCatalog.register_custom_model(
        "TorchMaskedModel",
        TorchTransformerActionMaskModel
    )

    config = (
        PPOConfig()
        .environment('smac_3m')
        .framework('torch')
        .rollouts(num_envs_per_worker=10, batch_mode="complete_episodes")
        .training(
            model={
                # "use_attention": True,
                # "attention_dim": 30,
                # "attention_num_transformer_units": 6,
                # 'attention_num_heads': 10, 
                "custom_model": "TorchMaskedModel",
                "custom_model_config": {
                    'attention_num_transformer_units': 6, 
                    'attention_dim': 30, 
                    'attention_num_heads': 2, 
                    'attention_head_dim': 64, 
                    'attention_memory_inference': 50, 
                    'attention_memory_training': 50, 
                    'attention_position_wise_mlp_dim': 32, 
                    'attention_init_gru_gate_bias': 2.0, 
                    'attention_use_n_prev_actions': 0, 
                    'attention_use_n_prev_rewards': 0,
                    'max_seq_len': 128,
                    'fcnet_hiddens': [256, 128, 9], 
                    'fcnet_activation': 'tanh',
                },
            },
        )
        .resources(num_gpus=0)
        .multi_agent(
            policies={
                "agent_{}".format(i): (
                    None,
                    env_info["observation_space"],
                    env_info["action_space"],
                    {},
                )
                for i in range(env_info["num_agents"])
            },
            policy_mapping_fn=policy_mapping_func,
        )
        .rl_module(_enable_rl_module_api=False)
    )

    stop = {"training_iteration": 100,
            "timesteps_total": 5000000,}

    tune.Tuner(
        'PPO',
        run_config=air.RunConfig(stop=stop),
        param_space=config.to_dict(),
    ).fit()


(The commented-out part is the default attention-model config.)
When I use the default model and set use_attention=True (the environment has invalid actions, so I just add a mask in AttentionWrapper), it works fine.
(AttentionWrapper lives in ray/rllib/models/torch/attention_net.py in the ray-project/ray repo on GitHub.)
(The mask is added like in the action_masking.py example in ray-project/ray on GitHub; the pattern is sketched below.)
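The masking pattern inside the wrapper's forward() is roughly the following sketch; "action_mask" is the key in my Dict observation space, and model_out stands for the wrapper's output logits:

# Inside the wrapper's forward(), after computing the output logits:
# push the logits of invalid actions toward -inf so they are never sampled.
action_mask = input_dict["obs"]["action_mask"]
inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
model_out = model_out + inf_mask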
When I replaced the model with a custom model (which is almost the same as AttentionWrapper), an error occurred.

File "/home/anaconda3/envs/SmacEnv/lib/python3.9/site-packages/ray/rllib/models/modelv2.py", line 259, in __call__
    res = self.forward(restored, state or [], seq_lens)
  File "/home/SmacTest/model.py", line 276, in forward
    self._features, memory_outs = self.gtrxl(input_dict, state, seq_lens)
  File "/home/anaconda3/envs/SmacEnv/lib/python3.9/site-packages/ray/rllib/models/modelv2.py", line 259, in __call__
    res = self.forward(restored, state or [], seq_lens)
  File "/home/anaconda3/envs/SmacEnv/lib/python3.9/site-packages/ray/rllib/models/torch/attention_net.py", line 214, in forward
    T = observations.shape[0] // B
AttributeError: 'collections.OrderedDict' object has no attribute 'shape'

My custom model is shown below.

from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.torch_utils import FLOAT_MIN
from ray.rllib.models.torch.attention_net import GTrXLNet
import gymnasium as gym
from gymnasium.spaces import Box, Discrete, MultiDiscrete
import numpy as np
import tree  # pip install dm_tree
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.misc import SlimFC


from ray.rllib.models.torch.recurrent_net import RecurrentNetwork
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import override
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
from ray.rllib.utils.typing import ModelConfigDict, TensorType, List
from typing import Dict, Optional, Union

tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()



class TorchTransformerActionMaskModel(TorchModelV2, nn.Module):
    """GTrXL-based model that applies an action mask to the output logits."""

    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str,
    ):

        nn.Module.__init__(self)
        super().__init__(obs_space, action_space, None, model_config, name)

        self.use_n_prev_actions = model_config["attention_use_n_prev_actions"]
        self.use_n_prev_rewards = model_config["attention_use_n_prev_rewards"]

        self.action_space_struct = get_base_struct_from_space(self.action_space)
        self.action_dim = 0

        
        for space in tree.flatten(self.action_space_struct):
            if isinstance(space, Discrete):
                self.action_dim += space.n
            elif isinstance(space, MultiDiscrete):
                self.action_dim += np.sum(space.nvec)
            elif space.shape is not None:
                self.action_dim += int(np.product(space.shape))
            else:
                self.action_dim += int(len(space))

        # Add prev-action/reward nodes to input to LSTM.
        if self.use_n_prev_actions:
            self.num_outputs += self.use_n_prev_actions * self.action_dim
        if self.use_n_prev_rewards:
            self.num_outputs += self.use_n_prev_rewards

        cfg = model_config
        self.attention_dim = cfg["custom_model_config"]["attention_dim"]
        # self.attention_dim = cfg["attention_dim"]

        if self.num_outputs is not None:
            in_space = gym.spaces.Box(
                float("-inf"), float("inf"), shape=(self.num_outputs,), dtype=np.float32
            )
        else:
            in_space = obs_space

        self.gtrxl = GTrXLNet(
            in_space,
            action_space,
            None,
            model_config,
            "gtrxl",
            num_transformer_units=cfg["custom_model_config"]["attention_num_transformer_units"],
            attention_dim=cfg["custom_model_config"]["attention_dim"],
            num_heads=cfg["custom_model_config"]["attention_num_heads"],
            head_dim=cfg["custom_model_config"]["attention_head_dim"],
            memory_inference=cfg["custom_model_config"]["attention_memory_inference"],
            memory_training=cfg["custom_model_config"]["attention_memory_training"],
            position_wise_mlp_dim=cfg["custom_model_config"]["attention_position_wise_mlp_dim"],
            init_gru_gate_bias=cfg["custom_model_config"]["attention_init_gru_gate_bias"],
        )

        # Set final num_outputs to correct value (depending on action space).
        self.num_outputs = num_outputs

        # Postprocess GTrXL output with another hidden layer and compute
        # values.
        self._logits_branch = SlimFC(
            in_size=self.attention_dim,
            out_size=self.num_outputs,
            activation_fn=None,
            initializer=torch.nn.init.xavier_uniform_,
        )
        self._value_branch = SlimFC(
            in_size=self.attention_dim,
            out_size=1,
            activation_fn=None,
            initializer=torch.nn.init.xavier_uniform_,
        )

        self.view_requirements = self.gtrxl.view_requirements
        self.view_requirements["obs"].space = self.obs_space

        # Add prev-a/r to this model's view, if required.
        if self.use_n_prev_actions:
            self.view_requirements[SampleBatch.PREV_ACTIONS] = ViewRequirement(
                SampleBatch.ACTIONS,
                space=self.action_space,
                shift="-{}:-1".format(self.use_n_prev_actions),
            )
        if self.use_n_prev_rewards:
            self.view_requirements[SampleBatch.PREV_REWARDS] = ViewRequirement(
                SampleBatch.REWARDS, shift="-{}:-1".format(self.use_n_prev_rewards)
            )

    @override(RecurrentNetwork)
    def forward(
        self,
        input_dict: Dict[str, TensorType],
        state: List[TensorType],
        seq_lens: TensorType,
    ) -> (TensorType, List[TensorType]):
        assert seq_lens is not None
        # Unlike AttentionWrapper, there is no wrapped model here: the
        # (Dict) observation in `input_dict` is fed straight to GTrXL.
        action_mask = input_dict["obs"]["action_mask"]

        self._features, memory_outs = self.gtrxl(input_dict, state, seq_lens)
        model_out = self._logits_branch(self._features)
        inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
        model_out = model_out + inf_mask
        return model_out, memory_outs

    @override(ModelV2)
    def get_initial_state(self) -> Union[List[np.ndarray], List[TensorType]]:
        return [
            torch.zeros(
                self.gtrxl.view_requirements["state_in_{}".format(i)].space.shape
            )
            for i in range(self.gtrxl.num_transformer_units)
        ]

    @override(ModelV2)
    def value_function(self) -> TensorType:
        assert self._features is not None, "Must call forward() first!"
        return torch.reshape(self._value_branch(self._features), [-1])


I tried to fix the problem, and by comparing the inputs of AttentionWrapper and my custom model, I found a difference:

SampleBatch(32 (seqs=4): ['obs', 'new_obs', 'actions', 'prev_actions', 'rewards', 'prev_rewards', 'terminateds', 'truncateds', 'infos', 'eps_id', 'unroll_id', 'agent_index', 't', 'state_in_0', 'state_out_0', 'state_in_1', 'state_out_1', 'state_in_2', 'state_out_2', 'state_in_3', 'state_out_3', 'state_in_4', 'state_out_4', 'state_in_5', 'state_out_5', 'vf_preds', 'action_dist_inputs', 'action_prob', 'action_logp', 'advantages', 'value_targets', 'obs_flat'])
SampleBatch(32: ['obs', 'new_obs', 'actions', 'prev_actions', 'rewards', 'prev_rewards', 'terminateds', 'truncateds', 'infos', 'eps_id', 'unroll_id', 'agent_index', 't', 'state_in_0', 'state_out_0', 'state_in_1', 'state_out_1', 'state_in_2', 'state_out_2', 'state_in_3', 'state_out_3', 'state_in_4', 'state_out_4', 'state_in_5', 'state_out_5', 'obs_flat'])

The top is the input_dict of the default model, and the bottom is the input_dict of the custom model. The custom model's input_dict has no seqs count and is missing the post-processed fields: 'vf_preds', 'action_dist_inputs', 'action_prob', 'action_logp', 'advantages', and 'value_targets'.
So I think maybe I need to define something extra in Tune, but I don't know what I should define. Thank you all for your help.
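(In case it is useful for comparing the two setups: the view requirements of a built policy's model can be dumped like in the sketch below, assuming the PPOConfig from my script above; "agent_0" is one of the policy IDs defined in .multi_agent().)

algo = config.build()
policy = algo.get_policy("agent_0")  # any of the per-agent policies
# Each ViewRequirement says which column RLlib collects and with what shift.
for key, vr in policy.model.view_requirements.items():
    print(key, "shift:", vr.shift, "space:", getattr(vr, "space", None))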

Usually, you’d want your custom model to be wrapped by the attention wrapper.
So you define your own model like you would otherwise and simply set use_attention=True.
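Roughly like this (a minimal sketch; MyMaskedModel is a placeholder for your own non-attention model, e.g. one that just applies the action mask):

ModelCatalog.register_custom_model("my_masked_model", MyMaskedModel)

config = (
    PPOConfig()
    .environment('smac_3m')
    .framework('torch')
    .training(
        model={
            "custom_model": "my_masked_model",
            # With use_attention=True, RLlib wraps the custom model in the
            # GTrXL attention wrapper automatically.
            "use_attention": True,
            "attention_dim": 30,
            "attention_num_transformer_units": 6,
            "attention_num_heads": 2,
        },
    )
)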

Thanks for your reply. I solved my problem by copying the GTrXL code from attention_net.py into my custom model and modifying it.
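(For context, the traceback above comes from GTrXLNet.forward() calling .shape on input_dict["obs"], which here is a Dict/OrderedDict rather than a flat tensor. A minimal sketch of the kind of change needed in the copied forward(), assuming the raw observation tensor sits under an "observations" key of the Dict obs; use whatever key your observation space actually defines:)

# In the copied GTrXL forward(): use the flat observation tensor from the
# Dict obs instead of the whole dict (the key name here is an assumption).
observations = input_dict[SampleBatch.OBS]["observations"]
B = len(seq_lens)
T = observations.shape[0] // B  # the line that raised AttributeError before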

@pkgunboat I am facing the same problem. Could you give some insight into what you changed in the GTrXL class?