How severely does this issue affect your experience of using Ray?
- High: It blocks me from completing my task.
Versions / Dependencies
ray 2.8.1 (tried with 2.7.0 too)
pettingzoo 1.24.1
torch 2.0.1+cu118
I’ve created a custom environment that works with some RLlib algorithms, but when I try it with SAC I get the following error:
TypeError: TorchModelV2.__init__() got an unexpected keyword argument 'policy_model_config'
I can’t figure out why I get this only with SAC, because I never set 'policy_model_config' in the SACConfig. Meanwhile, running SAC on CartPole-v1 works without any problems.
This is the code I’m using:
import irs
import os
import ray
from ray import tune
from ray.rllib.algorithms.sac import SACConfig
from ray.rllib.env import PettingZooEnv
from ray.rllib.models import ModelCatalog
from ray.rllib.examples.models.action_mask_model import TorchActionMaskModel
from ray.rllib.utils.framework import try_import_torch
from ray.tune.registry import register_env
from ray.rllib.utils import check_env
from supersuit.multiagent_wrappers import pad_action_space_v0
# Name under which the environment is registered and later referenced.
env_name = "irs"
# Expose RLlib's example action-mask model under the key "am_model" so it can
# be selected through the "custom_model" entry of a model config.
ModelCatalog.register_custom_model("am_model", TorchActionMaskModel)
# Obtain the torch modules through RLlib's framework import helper.
torch, nn = try_import_torch()
def env_creator():
    """Return the environment object produced by ``irs.env()``."""
    return irs.env()
# Standalone instance used only to read the spaces and to validate the
# environment below; it is never handed to RLlib.
test_env = PettingZooEnv(pad_action_space_v0(env_creator()))
# Register a factory instead of a pre-built instance: each time the creator is
# invoked it builds its own wrapped environment, so no single environment
# object (and its internal state) is shared between callers.
register_env(
    env_name,
    lambda config: PettingZooEnv(pad_action_space_v0(env_creator())),
)
# Per-agent spaces, indexed by agent id when the policies are declared.
obs_space = test_env.observation_space
act_space = test_env.action_space
check_env(test_env)
# Shut down any Ray runtime already attached to this process, then start one.
ray.shutdown()
ray.init()
alg_name = "SAC"

# Build the SAC configuration step by step; each builder call returns the
# config object, so rebinding `config` mirrors a single chained expression.
config = SACConfig()
config = config.environment(env=env_name, disable_env_checking=True)
# NOTE(review): the reported TypeError about 'policy_model_config' presumably
# originates from combining SAC with this custom model entry -- SAC appears to
# pass extra keyword arguments to the model that TorchActionMaskModel forwards
# to TorchModelV2.__init__. Confirm against RLlib's SAC model-building code.
config = config.training(model={"custom_model": "am_model"})

# One policy spec per agent id, each using that agent's own observation and
# action space; the mapping function sends every agent to the policy that
# carries its id.
agent_policies = {
    "atk": (None, obs_space["atk"], act_space["atk"], {}),
    "def": (None, obs_space["def"], act_space["def"], {}),
}
config = config.multi_agent(
    policies=agent_policies,
    policy_mapping_fn=(lambda agent_id, *args, **kwargs: agent_id),
)

# GPU count comes from the RLLIB_NUM_GPUS environment variable (default 0).
gpu_count = int(os.environ.get("RLLIB_NUM_GPUS", "0"))
config = config.resources(num_gpus=gpu_count)
config = config.framework(framework="torch")
# Stopping criterion handed to Tune: end the run at 1000 total timesteps.
stop_criteria = {"timesteps_total": 1000}

# Launch the experiment with the algorithm selected by name and the
# configuration converted to a plain dict.
results = tune.run(
    alg_name,
    config=config.to_dict(),
    name="SAC",
    stop=stop_criteria,
    checkpoint_freq=10,
)
Any idea what I can do to fix this problem?