Customize the SACTorchModel

I need to apply some custom changes to the SACTorchModel: specifically, I want to modify the actor network and the critic network for a multi-agent SAC (MASAC) algorithm. However, I don't know how to do this via the custom_model option, especially for a multi-agent environment. My test code is below (it runs with errors under the custom SAC model framework).

Thanks!

import torch
import torch.nn as nn
import ray
import gymnasium as gym
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.algorithms.sac import SACConfig

class CustomSACModel(TorchModelV2, nn.Module):
    """自定义 SAC 模型,包含策略网络和两个 Q 网络。"""
    def __init__(self, obs_space, action_space, num_outputs, model_config, name, **kwargs):
        super(CustomSACModel, self).__init__(obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)
        
        self.obs_dim = obs_space.shape[0]
        self.act_dim = action_space.shape[0]
        
        # Policy network (actor)
        self.actor_net = nn.Sequential(
            nn.Linear(self.obs_dim, 300),
            nn.ReLU(),
            nn.Linear(300, 300),
            nn.ReLU(),
            nn.Linear(300, self.act_dim * 2)  # outputs the action mean and log_std
        )
        
        # First Q network (critic 1)
        self.q_net1 = nn.Sequential(
            nn.Linear(self.obs_dim + self.act_dim, 300),
            nn.ReLU(),
            nn.Linear(300, 300),
            nn.ReLU(),
            nn.Linear(300, 1)
        )

        # Second Q network (critic 2)
        self.q_net2 = nn.Sequential(
            nn.Linear(self.obs_dim + self.act_dim, 300),
            nn.ReLU(),
            nn.Linear(300, 300),
            nn.ReLU(),
            nn.Linear(300, 1)
        )

    def forward(self, input_dict, state, seq_lens):
        obs = input_dict["obs"]
        mean, log_std = self.actor_net(obs).chunk(2, dim=-1)
        log_std = torch.clamp(log_std, -20, 2)
        std = torch.exp(log_std)
        return mean, std

    def get_q_values(self, obs, action):
        q_input = torch.cat([obs, action], dim=-1)
        q1 = self.q_net1(q_input)
        q2 = self.q_net2(q_input)
        return q1, q2

    def get_action(self, obs):
        mean, std = self.forward({"obs": obs}, None, None)
        normal = torch.distributions.Normal(mean, std)
        z = normal.rsample()  # reparameterization trick
        action = torch.tanh(z)  # squash the action into [-1, 1]
        return action
    
# Register the custom model
ModelCatalog.register_custom_model("custom_sac_model", CustomSACModel)
dummy_model_config = {"custom_model": "custom_sac_model"}
model = ModelCatalog.get_model_v2(
    obs_space=gym.spaces.Box(low=-1, high=1, shape=(4,)),  # dummy observation space
    action_space=gym.spaces.Box(low=-1, high=1, shape=(4,)),
    num_outputs=2,
    model_config=dummy_model_config,
    framework="torch")
# Initialize Ray
ray.shutdown()
ray.init()

# Create the SAC config
config = (
    SACConfig()
    .resources(num_gpus=1)
    .environment("Pendulum-v1")  # 使用连续动作空间的环境
    .framework("torch")
    .api_stack(enable_rl_module_and_learner=False, enable_env_runner_and_connector_v2=False)
    .training(
        model={"custom_model": "custom_sac_model"},
        replay_buffer_config={
            "type": "PrioritizedReplayBuffer",
            "alpha": 0.6,
        },
        # further training parameters can be added as needed
    )
)

# Build the SAC algorithm
trainer = config.build()

# Train
for i in range(100):
    result = trainer.train()
    print(f"Iteration {i}: reward = {result['env_runners']['episode_reward_mean']}")

@uihij thanks for raising this issue and welcome to this forum.

To better understand your issue, it always helps if you also state the error you see when executing your script. My guess is that it is an assertion error.

This assertion error is raised because you need to subclass SACTorchModel instead of TorchModelV2. Could you check whether this helps (it of course includes overriding the methods of that class where needed)?
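
To illustrate, here is a minimal, untested sketch of that subclassing. The hook names build_policy_model / build_q_model and the concat_obs_and_actions flag refer to the old-stack SACTorchModel in recent Ray 2.x releases and may differ slightly in your version; the network sizes simply mirror your script:

import torch.nn as nn
from gymnasium.spaces import Box
from ray.rllib.algorithms.sac.sac_torch_model import SACTorchModel
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.catalog import ModelCatalog

class _MLP(TorchModelV2, nn.Module):
    """Plain MLP wrapped as a ModelV2, reused for the actor and both critics."""
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)
        self.net = nn.Sequential(
            nn.Linear(obs_space.shape[0], 300), nn.ReLU(),
            nn.Linear(300, 300), nn.ReLU(),
            nn.Linear(300, num_outputs),
        )

    def forward(self, input_dict, state, seq_lens):
        return self.net(input_dict["obs"]), state

class CustomSACModel(SACTorchModel):
    """Keeps SACTorchModel's loss plumbing and only swaps the sub-networks."""
    def build_policy_model(self, obs_space, num_outputs, policy_model_config, name):
        # For continuous actions, num_outputs is already 2 * act_dim
        # (mean and log_std of the squashed Gaussian).
        return _MLP(obs_space, self.action_space, num_outputs, policy_model_config, name)

    def build_q_model(self, obs_space, action_space, num_outputs, q_model_config, name):
        # SAC feeds concat([obs, action]) into the Q nets for 1-D Box observations.
        self.concat_obs_and_actions = True
        input_space = Box(
            float("-inf"), float("inf"),
            shape=(obs_space.shape[0] + action_space.shape[0],),
        )
        return _MLP(input_space, action_space, num_outputs, q_model_config, name)

ModelCatalog.register_custom_model("custom_sac_model", CustomSACModel)

Registering it under the same name and keeping model={"custom_model": "custom_sac_model"} in your SACConfig should then get you past the assertion.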

Furthermore, you are working on the old API stack, which will soon be deprecated and no longer supported. We highly advise new users to jump straight onto the new API stack, which comes with better features and full support.
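
Enabling the new stack is just a config flip, sketched below. Note, though, that "custom_model" is an old-stack concept; on the new stack, custom networks are supplied as RLModules via config.rl_module(...) instead:

# Hedged sketch: the same SACConfig flipped onto the new API stack.
# model={"custom_model": ...} is ignored there; a custom network would be
# provided as an RLModule instead.
config = (
    SACConfig()
    .environment("Pendulum-v1")
    .framework("torch")
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
)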

Thank you so much for your advice. I do need to subclass SACTorchModel for my purpose. But regarding the new stack: my environment is multi-agent, and I don't know how to deploy such a multi-agent environment on the new stack with a SAC config. A multi-agent demo for the SAC config would be great.
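
(For reference, a minimal multi-agent SAC setup on the old API stack might look like the sketch below. The use of make_multi_agent, the "multi_pendulum" name, and the per-agent policy IDs are illustrative assumptions, not an official demo:)

from ray.rllib.env.multi_agent_env import make_multi_agent
from ray.tune.registry import register_env
from ray.rllib.algorithms.sac import SACConfig

# make_multi_agent wraps a single-agent env into a multi-agent one;
# agent IDs are the integers 0 .. num_agents - 1.
MultiAgentPendulum = make_multi_agent("Pendulum-v1")
register_env("multi_pendulum", lambda cfg: MultiAgentPendulum({"num_agents": 2}))

config = (
    SACConfig()
    .environment("multi_pendulum")
    .framework("torch")
    .api_stack(
        enable_rl_module_and_learner=False,
        enable_env_runner_and_connector_v2=False,
    )
    .multi_agent(
        # One policy per agent; both use the same custom model here.
        policies={"policy_0", "policy_1"},
        policy_mapping_fn=lambda agent_id, episode=None, worker=None, **kw: f"policy_{agent_id}",
    )
    .training(model={"custom_model": "custom_sac_model"})
)
algo = config.build()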