Question about how to use custom models

This is a follow-up to this post. I’ve been trying to run PPO with a custom model and a custom env, but it seems the custom model is not being used. Here is my code:

import ray
from ray import tune
from img_env import ImageReferentialGame
from models.speaker_listener_net import CommAgent

from ray.rllib.models import ModelCatalog
from ray.rllib.policy.policy import PolicySpec


# Register custom model and envs
ModelCatalog.register_custom_model("comm_agent", CommAgent)

config = {
    "env": ImageReferentialGame,
    "env_config": {
        "n_vocab": 10,
        "max_len": 20,
        "n_distractors": 1,
        "n_agents": 2,
        "data_root": "/home/xl3942/Documents/TorchData/"
    },
    "_disable_preprocessor_api": True,
    "multiagent": {
        "policies":{
            "agent1": PolicySpec(),
            "agent2": PolicySpec(),
        },
        "model": {"custom_model": "comm_agent"},
        "custom_model_config": {
            "hidden": 64,
            "n_vocab": 10,
            "max_len": 20,
            "n_distractors": 1
        },
        "policy_mapping_fn": (lambda aid, **kwargs: aid)
    },
    "framework": "torch",
    "num_gpus": 1,
    "num_workers": 0,
}

ray.init()
results = tune.run("PPO", config=config, verbose=1)

I’m getting an error from complex_input_net, which suggests the custom model is not being used. Two questions about this:

  1. Is this way of setting up a custom model correct?
  2. Setting aside the issue of the custom model not being used, why would there still be an obs dimension mismatch? The error I get is (truncating part of the trace for simplicity):
  File "/home/xl3942/anaconda3/envs/CommAgent/lib/python3.8/site-packages/ray/rllib/execution/rollout_ops.py", line 75, in sampler
    yield workers.local_worker().sample()
  File "/home/xl3942/anaconda3/envs/CommAgent/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 753, in sample
    batches = [self.input_reader.next()]
  File "/home/xl3942/anaconda3/envs/CommAgent/lib/python3.8/site-packages/ray/rllib/evaluation/sampler.py", line 103, in next
    batches = [self.get_data()]
  File "/home/xl3942/anaconda3/envs/CommAgent/lib/python3.8/site-packages/ray/rllib/evaluation/sampler.py", line 233, in get_data
    item = next(self._env_runner)
  File "/home/xl3942/anaconda3/envs/CommAgent/lib/python3.8/site-packages/ray/rllib/evaluation/sampler.py", line 622, in _env_runner
    eval_results = _do_policy_eval(
  File "/home/xl3942/anaconda3/envs/CommAgent/lib/python3.8/site-packages/ray/rllib/evaluation/sampler.py", line 1035, in _do_policy_eval
    policy.compute_actions_from_input_dict(
  File "/home/xl3942/anaconda3/envs/CommAgent/lib/python3.8/site-packages/ray/rllib/policy/torch_policy.py", line 300, in compute_actions_from_input_dict
    return self._compute_action_helper(input_dict, state_batches,
  File "/home/xl3942/anaconda3/envs/CommAgent/lib/python3.8/site-packages/ray/rllib/utils/threading.py", line 21, in wrapper
    return func(self, *a, **k)
  File "/home/xl3942/anaconda3/envs/CommAgent/lib/python3.8/site-packages/ray/rllib/policy/torch_policy.py", line 364, in _compute_action_helper
    dist_inputs, state_out = self.model(input_dict, state_batches,
  File "/home/xl3942/anaconda3/envs/CommAgent/lib/python3.8/site-packages/ray/rllib/models/modelv2.py", line 243, in __call__
    res = self.forward(restored, state or [], seq_lens)
  File "/home/xl3942/anaconda3/envs/CommAgent/lib/python3.8/site-packages/ray/rllib/models/torch/complex_input_net.py", line 158, in forward
    outs.append(torch.reshape(component, [-1, self.flatten[i]]))
RuntimeError: shape '[-1, 6144]' is invalid for input of size 9216

I thought the complex input net was supposed to handle tuple obs spaces?

Hi @Aceticia,

The model key is a top-level key in the config, but you have put it in the multiagent sub-dictionary.
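As a side note: since both policies here share the same model, PolicySpec() with no arguments is fine. If agent1 and agent2 should ever need different model settings, PolicySpec also takes a per-policy config dict that is merged on top of the trainer-level config. A sketch (not from the original code; the hidden=128 override is just a made-up example):

from ray.rllib.policy.policy import PolicySpec

# Per-policy overrides: the config dict passed to PolicySpec is merged
# on top of the trainer-level config for that policy only.
policies = {
    "agent1": PolicySpec(),  # inherits the top-level "model" settings
    "agent2": PolicySpec(config={
        "model": {
            "custom_model": "comm_agent",
            # hypothetical override: a wider hidden layer for agent2
            "custom_model_config": {"hidden": 128},
        },
    }),
}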


Hey @Aceticia , try:

config = {
    "env": ImageReferentialGame,
    "env_config": {
        "n_vocab": 10,
        "max_len": 20,
        "n_distractors": 1,
        "n_agents": 2,
        "data_root": "/home/xl3942/Documents/TorchData/"
    },
    "model": {
        "custom_model": "comm_agent",
        "custom_model_config": {
            "hidden": 64,
            "n_vocab": 10,
            "max_len": 20,
            "n_distractors": 1
        },
    },
    "_disable_preprocessor_api": True,
    "multiagent": {
        "policies":{
            "agent1": PolicySpec(),
            "agent2": PolicySpec(),
        },
        "policy_mapping_fn": (lambda aid, **kwargs: aid)
    },
    "framework": "torch",
    "num_gpus": 1,
    "num_workers": 0,
}

instead.
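A quick way to confirm that the custom model is actually picked up is to build the trainer directly and inspect a policy’s model class. A sketch using the PPOTrainer API (the exact import path may differ across RLlib versions):

from ray.rllib.agents.ppo import PPOTrainer

# Build the trainer once and check which model class each policy got.
# With the "model" key in the right place this should print CommAgent;
# with the misplaced key it prints ComplexInputNetwork instead.
trainer = PPOTrainer(config=config)
print(type(trainer.get_policy("agent1").model))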

Also note that with the custom model in place, the ComplexInputNet will not be used anymore. It’s one of RLlib’s default models (alongside FCNet and VisionNet), which are only used if the custom_model key is not defined.
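For reference, here is a minimal sketch of the interface RLlib expects from a custom torch model. The real CommAgent in models/speaker_listener_net.py isn’t shown in this thread, so the network body below is a placeholder; the point is that RLlib forwards the entries of custom_model_config as extra keyword arguments to the constructor:

import numpy as np
import torch
from torch import nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2

class CommAgent(TorchModelV2, nn.Module):
    # The "custom_model_config" entries arrive here as kwargs.
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, hidden=64, n_vocab=10, max_len=20, n_distractors=1):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        # Placeholder body; the real speaker/listener nets would go here.
        # This assumes a flat Box obs space for simplicity -- with
        # "_disable_preprocessor_api": True and a Tuple obs space, forward()
        # receives the unflattened tuple and must unpack it itself.
        in_size = int(np.prod(obs_space.shape))
        self._body = nn.Linear(in_size, hidden)
        self._logits = nn.Linear(hidden, num_outputs)
        self._value = nn.Linear(hidden, 1)
        self._features = None

    def forward(self, input_dict, state, seq_lens):
        feats = torch.relu(self._body(input_dict["obs"].float().flatten(1)))
        self._features = feats
        return self._logits(feats), state

    def value_function(self):
        return self._value(self._features).squeeze(1)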
