Using custom_model_config with DQN and a custom Keras model

Hi - I'm learning Ray and Tune and have been working through the MultiAgentArena example from the 2021 Ray Summit tutorial (sven1977/rllib_tutorials on GitHub, ray_summit_2021 folder).

I'm trying to use custom_model_config with a custom Keras model and DQN, but building the model fails with the error “__init__() got an unexpected keyword argument '<custom_model_config key>'”. The error always names whichever key is in that config dict. When I pass an empty custom_model_config dict (e.g. "custom_model_config": {}), there is no error.
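The error itself is ordinary Python behavior: a constructor with a fixed parameter list rejects any keyword it does not name, while one that takes **kwargs accepts it. The minimal sketch below reproduces both cases, assuming (as the `**customized_model_kwargs` line in catalog.py in the trace suggests) that the entries of custom_model_config are merged into the model constructor's keyword arguments. `FixedInit`, `FlexibleInit` and `build` are hypothetical names for illustration, not RLlib API.

```python
class FixedInit:
    """Constructor with a fixed parameter list (no **kwargs)."""

    def __init__(self, name, q_hiddens=(256,)):
        self.name = name
        self.q_hiddens = q_hiddens


class FlexibleInit:
    """Constructor that accepts arbitrary extra keyword arguments."""

    def __init__(self, name, **kw):
        self.name = name
        self.extra = kw


def build(cls, custom_model_config, **model_kwargs):
    # Assumption: custom options are merged into the constructor kwargs,
    # mirroring the `**customized_model_kwargs` call seen in the trace.
    kwargs = dict(model_kwargs, **custom_model_config)
    return cls("my_model", **kwargs)


# An empty custom_model_config adds no keywords, so both classes work.
build(FixedInit, {})
build(FlexibleInit, {})

# A non-empty one only works if the constructor accepts **kwargs.
print(build(FlexibleInit, {"layers": [128, 128]}).extra)  # {'layers': [128, 128]}
try:
    build(FixedInit, {"layers": [128, 128]})
except TypeError as e:
    print(e)  # ... got an unexpected keyword argument 'layers'
```

This matches the symptom that an empty dict is harmless: no extra keyword ever reaches the constructor.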

I've used custom_model_config successfully with Torch for both DQN and PPO, and with Keras for PPO, but the same approach fails with Keras and DQN. Suspecting an algorithm difference, I've tried a couple of different ways of passing custom_model_config as a kwarg, without success.
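One possible reason DQN behaves differently, offered as an assumption rather than a confirmed diagnosis: if DQN combines the custom model with its own Q-model interface class through multiple inheritance, with the interface first, then the interface's `__init__` runs first. If that `__init__` has a fixed keyword list, it rejects the custom key even though `MyKerasModel` itself accepts `**kw`. The plain-Python sketch below models that situation; `BaseModel`, `CustomModel`, `QInterface` and `wrap` are hypothetical stand-ins, not RLlib classes.

```python
class BaseModel:
    """Stand-in for a shared base class (TFModelV2 in the question)."""

    def __init__(self, name):
        self.name = name


class CustomModel(BaseModel):
    """Stand-in for a user model that accepts extra options via **kw."""

    def __init__(self, name, **kw):
        super().__init__(name)
        self.custom_options = kw


class QInterface(BaseModel):
    """Hypothetical Q-model interface with a fixed keyword list."""

    def __init__(self, name, q_hiddens=(256,)):
        # Cooperative super(): in a combined class this calls the next
        # __init__ in the MRO (CustomModel), without any extra kwargs.
        super().__init__(name)
        self.q_hiddens = q_hiddens


def wrap(model_cls, interface):
    # Hypothetical helper: combine interface and user model by
    # multiple inheritance, interface first.
    if issubclass(model_cls, interface):
        return model_cls
    return type(model_cls.__name__ + "_as_" + interface.__name__,
                (interface, model_cls), {})


Wrapped = wrap(CustomModel, QInterface)
print([c.__name__ for c in Wrapped.__mro__])
# ['CustomModel_as_QInterface', 'QInterface', 'CustomModel', 'BaseModel', 'object']

# Without custom options, the wrapped class builds fine.
print(Wrapped("m", q_hiddens=(256,)).custom_options)  # {}

# CustomModel alone accepts the custom key...
CustomModel("m", layers=[128, 128])

# ...but wrapped, QInterface.__init__ runs first and rejects it.
try:
    Wrapped("m", q_hiddens=(256,), layers=[128, 128])
except TypeError as e:
    print(e)  # ... got an unexpected keyword argument 'layers'
```

Under this assumption, PPO would not hit the problem because no fixed-signature interface sits in front of the custom model.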

Here’s a sample of the config:

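```python
from ray import tune

# MultiAgentArena, policies, policy_mapping_fn and MyKerasModel are
# defined earlier in the tutorial notebook.
config = {
    "env": MultiAgentArena,
    "multiagent": {
        "policies": policies,
        "policy_mapping_fn": policy_mapping_fn,
    },
    "model": {
        "custom_model": MyKerasModel,
        # Any non-empty dict here triggers the error with Keras + DQN.
        "custom_model_config": {"layers": [128, 128]},
    },
    "num_workers": 2,
    "framework": "tf",
    "lr": tune.grid_search([0.0001, 0.5]),
    "train_batch_size": tune.grid_search([128, 256]),
    # ... (remaining settings omitted from this sample)
}
```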

And here are the relevant lines of the model (both snippets are taken almost directly from the linked .ipynb tutorial):

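```python
from ray.rllib.models.tf.tf_modelv2 import TFModelV2


class MyKerasModel(TFModelV2):
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, **kw):
        # **kw collects any extra keyword arguments passed to the model.
        super(MyKerasModel, self).__init__(obs_space, action_space,
                                           num_outputs, model_config, name)
        # ... (rest of the model as in the tutorial)
```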
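Assuming DQN combines the custom model with a fixed-signature Q-model interface (the `build_q_model` frame in the trace hints at a DQN-specific model-building step), one workaround pattern is to have the custom model subclass that interface directly, so no combined class is created, and to pop its own options out of `**kw` before calling `super().__init__`. The plain-Python sketch below shows the pattern; `BaseModel`, `QInterface`, `wrap` and `MyQModel` are hypothetical stand-ins. In RLlib the TF interface for DQN is presumably `DistributionalQTFModel`, but check that against your Ray version before relying on it.

```python
class BaseModel:
    """Stand-in for TFModelV2."""

    def __init__(self, name):
        self.name = name


class QInterface(BaseModel):
    """Hypothetical Q-model interface with a fixed keyword list."""

    def __init__(self, name, q_hiddens=(256,)):
        super().__init__(name)
        self.q_hiddens = q_hiddens


def wrap(model_cls, interface):
    # Hypothetical helper: models that already subclass the interface
    # are used as-is; others are combined with it, interface first.
    if issubclass(model_cls, interface):
        return model_cls
    return type(model_cls.__name__ + "_as_" + interface.__name__,
                (interface, model_cls), {})


class MyQModel(QInterface):
    """Subclasses the interface and strips its own options from **kw."""

    def __init__(self, name, **kw):
        # Remove the custom option before forwarding the rest, so the
        # interface only receives keywords it knows about.
        self.layers = kw.pop("layers", [256, 256])
        super().__init__(name, **kw)


model_cls = wrap(MyQModel, QInterface)
model = model_cls("m", q_hiddens=(64,), layers=[128, 128])
print(model_cls is MyQModel, model.layers, model.q_hiddens)
# True [128, 128] (64,)
```

Popping rather than reading the key matters here: if `layers` stayed in `kw`, the forwarded call to the interface would fail exactly like the original error.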

Any pointers would be helpful. Thanks in advance.

Hi @Booker7,

Welcome to the forum.

Do you have a stack trace of the error you could share?

Thanks for taking a look.
Trace follows:

```
(pid=1325) 2021-10-30 19:51:25,328	ERROR worker.py:421 -- Exception raised in creation task: The actor died because of an error raised in its creation task, ray::DQN.__init__() (pid=1325, ip=10.228.104.183)
(pid=1325)   File "python/ray/_raylet.pyx", line 534, in ray._raylet.execute_task
(pid=1325)   File "python/ray/_raylet.pyx", line 484, in ray._raylet.execute_task.function_executor
(pid=1325)   File "/home/gamma/anaconda3/envs/ray_tst/lib/python3.7/site-packages/ray/_private/function_manager.py", line 563, in actor_method_executor
(pid=1325)     return method(__ray_actor, *args, **kwargs)
(pid=1325)   File "/home/gamma/anaconda3/envs/ray_tst/lib/python3.7/site-packages/ray/rllib/agents/trainer_template.py", line 123, in __init__
(pid=1325)     Trainer.__init__(self, config, env, logger_creator)
(pid=1325)   File "/home/gamma/anaconda3/envs/ray_tst/lib/python3.7/site-packages/ray/rllib/agents/trainer.py", line 584, in __init__
(pid=1325)     super().__init__(config, logger_creator)
(pid=1325)   File "/home/gamma/anaconda3/envs/ray_tst/lib/python3.7/site-packages/ray/tune/trainable.py", line 103, in __init__
(pid=1325)     self.setup(copy.deepcopy(self.config))
(pid=1325)   File "/home/gamma/anaconda3/envs/ray_tst/lib/python3.7/site-packages/ray/rllib/agents/trainer.py", line 731, in setup
(pid=1325)     self._init(self.config, self.env_creator)
(pid=1325)   File "/home/gamma/anaconda3/envs/ray_tst/lib/python3.7/site-packages/ray/rllib/agents/trainer_template.py", line 152, in _init
(pid=1325)     num_workers=self.config["num_workers"])
(pid=1325)   File "/home/gamma/anaconda3/envs/ray_tst/lib/python3.7/site-packages/ray/rllib/agents/trainer.py", line 819, in _make_workers
(pid=1325)     logdir=self.logdir)
(pid=1325)   File "/home/gamma/anaconda3/envs/ray_tst/lib/python3.7/site-packages/ray/rllib/evaluation/worker_set.py", line 86, in __init__
(pid=1325)     lambda p, pid: (pid, p.observation_space, p.action_space)))
(pid=1325)   File "/home/gamma/anaconda3/envs/ray_tst/lib/python3.7/site-packages/ray/_private/client_mode_hook.py", line 82, in wrapper
(pid=1325)     return func(*args, **kwargs)
(pid=1325) ray.exceptions.RayActorError: The actor died because of an error raised in its creation task, ray::RolloutWorker.__init__() (pid=1313, ip=10.228.104.183)
(pid=1325)   File "python/ray/_raylet.pyx", line 534, in ray._raylet.execute_task
(pid=1325)   File "python/ray/_raylet.pyx", line 484, in ray._raylet.execute_task.function_executor
(pid=1325)   File "/home/gamma/anaconda3/envs/ray_tst/lib/python3.7/site-packages/ray/_private/function_manager.py", line 563, in actor_method_executor
(pid=1325)     return method(__ray_actor, *args, **kwargs)
(pid=1325)   File "/home/gamma/anaconda3/envs/ray_tst/lib/python3.7/site-packages/ray/rllib/evaluation/rollout_worker.py", line 557, in __init__
(pid=1325)     self._build_policy_map(policy_dict, policy_config)
(pid=1325)   File "/home/gamma/anaconda3/envs/ray_tst/lib/python3.7/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1345, in _build_policy_map
(pid=1325)     policy_map[name] = cls(obs_space, act_space, merged_conf)
(pid=1325)   File "/home/gamma/anaconda3/envs/ray_tst/lib/python3.7/site-packages/ray/rllib/policy/tf_policy_template.py", line 251, in __init__
(pid=1325)     get_batch_divisibility_req=get_batch_divisibility_req,
(pid=1325)   File "/home/gamma/anaconda3/envs/ray_tst/lib/python3.7/site-packages/ray/rllib/policy/dynamic_tf_policy.py", line 175, in __init__
(pid=1325)     self.model = make_model(self, obs_space, action_space, config)
(pid=1325)   File "/home/gamma/anaconda3/envs/ray_tst/lib/python3.7/site-packages/ray/rllib/agents/dqn/dqn_tf_policy.py", line 186, in build_q_model
(pid=1325)     or config["exploration_config"]["type"] == "ParameterNoise")
(pid=1325)   File "/home/gamma/anaconda3/envs/ray_tst/lib/python3.7/site-packages/ray/rllib/models/catalog.py", line 470, in get_model_v2
(pid=1325)     **customized_model_kwargs,
(pid=1325) TypeError: __init__() got an unexpected keyword argument 'layer'
(pid=1313) 2021-10-30 19:51:25,322	ERROR worker.py:421 -- Exception raised in creation task: The actor died because of an error raised in its creation task, ray::RolloutWorker.__init__() (pid=1313, ip=10.228.104.183)
(pid=1313)   File "python/ray/_raylet.pyx", line 534, in ray._raylet.execute_task
(pid=1313)   File "python/ray/_raylet.pyx", line 484, in ray._raylet.execute_task.function_executor
(pid=1313)   File "/home/gamma/anaconda3/envs/ray_tst/lib/python3.7/site-packages/ray/_private/function_manager.py", line 563, in actor_method_executor
(pid=1313)     return method(__ray_actor, *args, **kwargs)
(pid=1313)   File "/home/gamma/anaconda3/envs/ray_tst/lib/python3.7/site-packages/ray/rllib/evaluation/rollout_worker.py", line 557, in __init__
(pid=1313)     self._build_policy_map(policy_dict, policy_config)
(pid=1313)   File "/home/gamma/anaconda3/envs/ray_tst/lib/python3.7/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1345, in _build_policy_map
(pid=1313)     policy_map[name] = cls(obs_space, act_space, merged_conf)
(pid=1313)   File "/home/gamma/anaconda3/envs/ray_tst/lib/python3.7/site-packages/ray/rllib/policy/tf_policy_template.py", line 251, in __init__
(pid=1313)     get_batch_divisibility_req=get_batch_divisibility_req,
(pid=1313)   File "/home/gamma/anaconda3/envs/ray_tst/lib/python3.7/site-packages/ray/rllib/policy/dynamic_tf_policy.py", line 175, in __init__
(pid=1313)     self.model = make_model(self, obs_space, action_space, config)
(pid=1313)   File "/home/gamma/anaconda3/envs/ray_tst/lib/python3.7/site-packages/ray/rllib/agents/dqn/dqn_tf_policy.py", line 186, in build_q_model
(pid=1313)     or config["exploration_config"]["type"] == "ParameterNoise")
(pid=1313)   File "/home/gamma/anaconda3/envs/ray_tst/lib/python3.7/site-packages/ray/rllib/models/catalog.py", line 470, in get_model_v2
(pid=1313)     **customized_model_kwargs,
(pid=1313) TypeError: __init__() got an unexpected keyword argument 'layer'
Traceback (most recent call last):
  File "arena_orig.py", line 464, in <module>
    "training_iteration": 5,
  File "/home/gamma/anaconda3/envs/ray_tst/lib/python3.7/site-packages/ray/tune/tune.py", line 544, in run
    raise TuneError("Trials did not complete", incomplete_trials)
ray.tune.error.TuneError: ('Trials did not complete', [DQN_MultiAgentArena_480c6_00000, DQN_MultiAgentArena_480c6_00001, DQN_MultiAgentArena_480c6_00002, DQN_MultiAgentArena_480c6_00003])
```