TypeError with custom env on SAC Algorithm

How severe does this issue affect your experience of using Ray?

  • High: It blocks me to complete my task.

Versions / Dependencies
ray 2.8.1 (tried with 2.7.0 too)
pettingzoo 1.24.1
torch 2.0.1+cu118

I’ve created a custom environment, and it works with some algorithms from RLlib, but if I try it with SAC I get the following error:

TypeError: TorchModelV2.__init__() got an unexpected keyword argument 'policy_model_config'

I can’t figure out why I get this only for SAC, because I don’t set ‘policy_model_config’ anywhere in the SACConfig. Meanwhile, running the same setup on CartPole-v1 works without any problems.

This is the code I’m using:

import irs
import os
import ray
from ray import tune
from ray.rllib.algorithms.sac import SACConfig
from ray.rllib.env import PettingZooEnv
from ray.rllib.models import ModelCatalog
from ray.rllib.examples.models.action_mask_model import TorchActionMaskModel
from ray.rllib.utils.framework import try_import_torch
from ray.tune.registry import register_env
from ray.rllib.utils import check_env
from supersuit.multiagent_wrappers import pad_action_space_v0

# Import torch through RLlib's helper (returns (None, None) if unavailable).
torch, nn = try_import_torch()

# Expose the action-masking model to RLlib under a string key so configs
# can refer to it as "am_model".
ModelCatalog.register_custom_model("am_model", TorchActionMaskModel)

env_name = "irs"


def env_creator():
    """Build a brand-new instance of the custom PettingZoo environment."""
    return irs.env()


# A separate throwaway instance, used only to read the per-agent spaces and
# to run the env checker — never handed to the trainer.
test_env = PettingZooEnv(pad_action_space_v0(env_creator()))

# BUG FIX: the original code pre-built a single `env` object and registered
# `lambda config: env`, so every rollout worker would step the SAME shared
# instance. The registered creator must construct a fresh, fully-wrapped env
# on every call.
register_env(
    env_name,
    lambda config: PettingZooEnv(pad_action_space_v0(env_creator())),
)

# PettingZooEnv exposes dict-like per-agent spaces; keyed by agent id below.
obs_space = test_env.observation_space
act_space = test_env.action_space

check_env(test_env)

# Restart Ray so no stale session from a previous run is reused.
ray.shutdown()
ray.init()

alg_name = "SAC"

config = (
    SACConfig()
    .environment(env=env_name, disable_env_checking=True)
    .training(
        model={"custom_model": "am_model"},
    )
    .multi_agent(
        policies={
            "atk": (None, obs_space["atk"], act_space["atk"], {}),
            "def": (None, obs_space["def"], act_space["def"], {}),
        },
        policy_mapping_fn=(lambda agent_id, *args, **kwargs: agent_id),
    )
    .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
    .framework(framework="torch")
)

# Kick off training through Ray Tune: stop after 1000 sampled timesteps,
# writing a checkpoint every 10 training iterations.
stop_criteria = {"timesteps_total": 1000}
run_kwargs = dict(
    name="SAC",
    stop=stop_criteria,
    checkpoint_freq=10,
    config=config.to_dict(),
)
results = tune.run(alg_name, **run_kwargs)

Any idea what I can do to fix this problem?