After updating from Ray 1.0.1 to 1.2, custom model stops working

I could not find an answer in the changelogs between the two versions, but after updating from 1.0.1 to 1.2.x, my custom model stopped working:

At the beginning of training it raises:

TypeError: __init__() got an unexpected keyword argument 'drop_rate'

Here it refers to the “drop_rate” parameter that I pass to my custom model through `custom_model_config`:

        "model": {
            "custom_model": "TorchRNN",
            "custom_model_config": {
                "drop_rate": 0.5,
                "training": True
            },
ModelCatalog.register_custom_model("TorchRNN", TorchRNNModel)
class TorchRNNModel(TorchRNN, nn.Module):
    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 lstm_state_size=16,
                 fcnet_hiddens=[16,8],
                 drop_rate=0.5,
                 training=True):

I worked around it by downgrading back to 1.0.1, but I would like to know whether anything has changed in the way custom models work in recent versions.

Entire log:

Failure # 1 (occurred at 2021-02-15_17-13-53)
Traceback (most recent call last):
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/ray/tune/trial_runner.py", line 586, in _process_trial
    results = self.trial_executor.fetch_result(trial)
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/ray/tune/ray_trial_executor.py", line 609, in fetch_result
    result = ray.get(trial_future[0], timeout=DEFAULT_GET_TIMEOUT)
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/ray/_private/client_mode_hook.py", line 47, in wrapper
    return func(*args, **kwargs)
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/ray/worker.py", line 1456, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(TypeError): ray::IMPALA.train_buffered() (pid=6668, ip=172.31.22.215)
  File "python/ray/_raylet.pyx", line 439, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 473, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 476, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 480, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 432, in ray._raylet.execute_task.function_executor
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/ray/rllib/agents/trainer_template.py", line 107, in __init__
    Trainer.__init__(self, config, env, logger_creator)
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/ray/rllib/agents/trainer.py", line 486, in __init__
    super().__init__(config, logger_creator)
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/ray/tune/trainable.py", line 97, in __init__
    self.setup(copy.deepcopy(self.config))
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/ray/rllib/agents/trainer.py", line 654, in setup
    self._init(self.config, self.env_creator)
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/ray/rllib/agents/trainer_template.py", line 139, in _init
    num_workers=self.config["num_workers"])
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/ray/rllib/agents/trainer.py", line 731, in _make_workers
    logdir=self.logdir)
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/ray/rllib/evaluation/worker_set.py", line 81, in __init__
    lambda p, pid: (pid, p.observation_space, p.action_space)))
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/ray/_private/client_mode_hook.py", line 47, in wrapper
    return func(*args, **kwargs)
ray.exceptions.RayTaskError(TypeError): ray::RolloutWorker.foreach_policy() (pid=6673, ip=172.31.22.215)
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/ray/rllib/models/catalog.py", line 469, in get_model_v2
    **customized_model_kwargs)
TypeError: __init__() got an unexpected keyword argument 'drop_rate'

During handling of the above exception, another exception occurred:

ray::RolloutWorker.foreach_policy() (pid=6673, ip=172.31.22.215)
  File "python/ray/_raylet.pyx", line 439, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 473, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 476, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 480, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 432, in ray._raylet.execute_task.function_executor
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/ray/rllib/evaluation/rollout_worker.py", line 480, in __init__
    policy_dict, policy_config)
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1111, in _build_policy_map
    policy_map[name] = cls(obs_space, act_space, merged_conf)
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/ray/rllib/policy/policy_template.py", line 234, in __init__
    framework=framework)
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/ray/rllib/models/catalog.py", line 475, in get_model_v2
    **model_kwargs)
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/ray/rllib/models/torch/recurrent_net.py", line 120, in __init__
    model_config, name)
  File "/home/ec2-user/lr-londres/londres/models/torch_rnn_model.py", line 44, in __init__
    self.action_branch = nn.Linear(self.lstm_state_size, num_outputs)
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 78, in __init__
    self.weight = Parameter(torch.Tensor(out_features, in_features))
TypeError: new() received an invalid combination of arguments - got (NoneType, int), but expected one of:
 * (*, torch.device device)
      didn't match because some of the arguments have invalid types: (!NoneType!, !int!)
 * (torch.Storage storage)
 * (Tensor other)
 * (tuple of ints size, *, torch.device device)
 * (object data, *, torch.device device)

Thanks for the question, @LecJackS! I cannot reproduce your problem, though. The following works fine:

By the way, in the future please provide a concise, self-contained reproduction script so I can debug quickly. It’s really hard to copy-paste scattered, non-consecutive snippets into a script that reproduces the issue.

import ray
from ray import tune
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.framework import try_import_torch

torch, nn = try_import_torch()



class TorchRNNModel(TorchModelV2, nn.Module):
    """Minimal Torch custom model used to reproduce the ``drop_rate`` report.

    Despite its name, this repro model contains no recurrent layer: it is a
    single linear layer from the flattened observation to one output per
    action. The extra keyword arguments mirror the reporter's constructor so
    that the ``custom_model_config`` entries (``drop_rate``, ``training``)
    are accepted as keyword arguments; this model does not use them.
    """

    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 lstm_state_size=16,
                 fcnet_hiddens=(16, 8),
                 drop_rate=0.5,
                 training=True):
        """Build the single linear layer.

        Args:
            obs_space: Observation space; its flattened size is the layer's
                input width.
            action_space: Action space exposing ``.n`` (e.g. ``Discrete``);
                ``action_space.n`` is the layer's output width.
            num_outputs, model_config, name: Forwarded unchanged to
                ``TorchModelV2.__init__``.
            lstm_state_size, fcnet_hiddens, drop_rate, training: Accepted for
                signature compatibility with the reporter's model; unused.
                ``fcnet_hiddens`` defaults to a tuple to avoid a shared
                mutable default.
        """
        nn.Module.__init__(self)
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)

        # forward() consumes ``obs_flat``, so the input width must be the
        # flattened observation size, not just the first dimension.
        in_features = 1
        for dim in obs_space.shape:
            in_features *= dim
        self.layer = nn.Linear(in_features, action_space.n)

        # Populated by forward(); value_function() reads it afterwards.
        self._value = None

    def forward(self, input_dict, state, seq_lens):
        """Compute output logits from the flattened observation.

        The model is stateless: ``state`` is returned unchanged and
        ``seq_lens`` is ignored. As a repro-only shortcut, the value estimate
        is the first logit column rather than a dedicated value head.
        """
        logits = self.layer(input_dict["obs_flat"])
        self._value = logits[:, 0]
        return logits, state

    def value_function(self):
        """Return the value estimate cached by the most recent forward()."""
        assert self._value is not None, "must call forward() first"
        return self._value


if __name__ == "__main__":
    ray.init()

    # Expose the model to RLlib under the name referenced by "custom_model".
    ModelCatalog.register_custom_model("my_model", TorchRNNModel)

    tune.run(
        "PPO",
        # One iteration is enough to build the model on the workers and run
        # forward()/value_function(). Without a stop criterion tune.run keeps
        # training until it is interrupted.
        stop={"training_iteration": 1},
        config={
            "env": "CartPole-v0",
            "model": {
                "custom_model": "my_model",
                # These entries reach TorchRNNModel.__init__ as keyword
                # arguments, so the constructor must accept them.
                "custom_model_config": {
                    "drop_rate": 0.5,
                    "training": True
                },
                # Do not also set "use_lstm": True for a custom model that
                # handles its own recurrence. In the original report, a
                # leftover "use_lstm": True made RLlib wrap the custom model,
                # producing the "unexpected keyword argument 'drop_rate'"
                # error.
            },
            "framework": "torch",
        },
    )

    ray.shutdown()

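Found the bug: a leftover `"use_lstm": True` in the model config. I wasn’t using RLlib’s built-in LSTM option at all; my custom model sets up its own LSTM. With that flag set, however, RLlib wraps the custom model in its own recurrent wrapper (note `recurrent_net.py` in the traceback). Judging by the traceback, that wrapper does not accept the `drop_rate` keyword from `custom_model_config`, and the inner model ends up being constructed with `num_outputs=None`. That is why `nn.Linear(self.lstm_state_size, num_outputs)` then fails. Removing `"use_lstm": True` fixes it.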

1 Like