How to use a custom model with the IMPALA trainer (first time using RLlib)

Hi everyone, we’re trying to use the rllib implementation of IMPALA with a custom model and a custom environment. Currently, we’re facing strange errors when using a custom model (simple Actor Critic with FC layers).

Running the code below raises `ValueError: Shape (49, 4, 1) must have rank 2` inside the v-trace algorithm.

Unfortunately, there are no examples for IMPALA with a custom model on the repository. Does anyone have an idea what is missing in our code?

import tensorflow as tf
import ray
import ray.rllib
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models import ModelCatalog
from ray.rllib.agents.impala import ImpalaTrainer

class ActorCritic(TFModelV2):
    """Small actor-critic network: one shared dense layer, two linear heads.

    The actor head emits ``num_outputs`` values and the critic head emits a
    single value per input row.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        """Build the Keras model.

        All five arguments are forwarded unchanged to ``TFModelV2.__init__``.
        ``obs_space.shape`` sizes the input layer and ``num_outputs`` sizes
        the actor head.
        """
        super().__init__(
            obs_space=obs_space, action_space=action_space,
            num_outputs=num_outputs, model_config=model_config,
            name=name)

        inputs = tf.keras.layers.Input(shape=obs_space.shape)
        # Shared trunk used by both heads.
        base = tf.keras.layers.Dense(128, activation=tf.nn.relu)(inputs)
        # TODO why is the action_space different than the num_outputs?
        actor = tf.keras.layers.Dense(num_outputs)(base)
        critic = tf.keras.layers.Dense(1)(base)
        self.model = tf.keras.Model(inputs=inputs, outputs=[actor, critic])
        # Filled in by forward(); read back by value_function().
        self.critic_out = None

    def forward(self, input_dict, state, seq_lens):
        """Run the network on ``input_dict["obs_flat"]``.

        Returns the actor output together with an empty state list. The
        critic output is cached on ``self.critic_out`` as a side effect.
        ``state`` and ``seq_lens`` are not used.
        """
        actor_out, self.critic_out = self.model(input_dict["obs_flat"])
        return actor_out, []  # state

    def value_function(self):
        """Return the critic output cached by the most recent forward() call."""
        # NOTE: the critic head is Dense(1), so this tensor keeps a trailing
        # dimension of size 1. This unsqueezed shape is what the
        # "Shape (49, 4, 1) must have rank 2" error in this thread is about.
        return self.critic_out

if __name__ == "__main__":
    # Start Ray before the trainer creates its rollout workers.
    ray.init()

    # custom model
    # Register the class under the name that the "custom_model" entry of the
    # config below refers to.
    ModelCatalog.register_custom_model(
        "actor_critic", ActorCritic)

    config = {
        "env": "Pendulum-v0",
        "lr": 0.001,
        "num_gpus": 0,
        "num_workers": 1,
        "num_envs_per_worker": 1,
        "model": {
            "custom_model": "actor_critic",
            # arguments passed to constructor of custom model
            "custom_model_config": {},},
        # enable eager execution
        "framework": "tfe",
        # "framework": "tf",
        "log_level": "INFO",
    }

    trainer = ImpalaTrainer(config=config)

Also, it would be awesome if someone could post a working example of IMPALA with a custom model and a custom env!

Hi @thomasbbrunner,

Welcome to the forum. Do you have a stack trace with the error you are getting?

Hi @mannyv, thanks for the quick reply. Yes, here is the complete output:

2021-09-08 10:59:58.958761: W tensorflow/stream_executor/platform/default/] Could not load dynamic library ''; dlerror: cannot open shared object file: No such file or directory
2021-09-08 10:59:58.958786: I tensorflow/stream_executor/cuda/] Ignore above cudart dlerror if you do not have a GPU set up on your machine.
2021-09-08 11:00:02,391	INFO -- Executing eagerly, with eager_tracing=False
(pid=360458) 2021-09-08 11:00:04,783	INFO -- Validating sub-env at vector index=0 ... (ok)
(pid=360458) 2021-09-08 11:00:04,785	INFO -- TF-eager Policy (worker=1) running on CPU.
(pid=360458) 2021-09-08 11:00:04,786	INFO -- Wrapping <class '__main__.ActorCritic'> as None
Traceback (most recent call last):
  File "/home/user/code/project/./src/common/", line 59, in <module>
    trainer = ImpalaTrainer(config=config)
  File "/home/user/.local/lib/python3.9/site-packages/ray/rllib/agents/", line 137, in __init__
    Trainer.__init__(self, config, env, logger_creator)
  File "/home/user/.local/lib/python3.9/site-packages/ray/rllib/agents/", line 603, in __init__
    super().__init__(config, logger_creator)
  File "/home/user/.local/lib/python3.9/site-packages/ray/tune/", line 105, in __init__
  File "/home/user/.local/lib/python3.9/site-packages/ray/rllib/agents/", line 147, in setup
  File "/home/user/.local/lib/python3.9/site-packages/ray/rllib/agents/", line 748, in setup
    self._init(self.config, self.env_creator)
  File "/home/user/.local/lib/python3.9/site-packages/ray/rllib/agents/", line 171, in _init
    self.workers = self._make_workers(
  File "/home/user/.local/lib/python3.9/site-packages/ray/rllib/agents/", line 830, in _make_workers
    return WorkerSet(
  File "/home/user/.local/lib/python3.9/site-packages/ray/rllib/evaluation/", line 83, in __init__
    remote_spaces = ray.get(self.remote_workers(
  File "/home/user/.local/lib/python3.9/site-packages/ray/_private/", line 89, in wrapper
    return func(*args, **kwargs)
  File "/home/user/.local/lib/python3.9/site-packages/ray/", line 1623, in get
    raise value
ray.exceptions.RayActorError: The actor died because of an error raised in its creation task, ray::RolloutWorker.__init__() (pid=360458, ip=
  File "/home/user/.local/lib/python3.9/site-packages/ray/rllib/evaluation/", line 573, in __init__
  File "/home/user/.local/lib/python3.9/site-packages/ray/rllib/evaluation/", line 1371, in _build_policy_map
    self.policy_map.create_policy(name, orig_cls, obs_space, act_space,
  File "/home/user/.local/lib/python3.9/site-packages/ray/rllib/policy/", line 132, in create_policy
    class_(observation_space, action_space, merged_config)
  File "/home/user/.local/lib/python3.9/site-packages/ray/rllib/policy/", line 329, in __init__
  File "/home/user/.local/lib/python3.9/site-packages/ray/rllib/policy/", line 793, in _initialize_loss_from_dummy_batch
    self._loss(self, self.model, self.dist_class, train_batch)
  File "/home/user/.local/lib/python3.9/site-packages/ray/rllib/agents/impala/", line 195, in build_vtrace_loss
    policy.loss = VTraceLoss(
  File "/home/user/.local/lib/python3.9/site-packages/ray/rllib/agents/impala/", line 78, in __init__
    self.vtrace_returns = vtrace.multi_from_logits(
  File "/home/user/.local/lib/python3.9/site-packages/ray/rllib/agents/impala/", line 246, in multi_from_logits
    vtrace_returns = from_importance_weights(
  File "/home/user/.local/lib/python3.9/site-packages/ray/rllib/agents/impala/", line 326, in from_importance_weights
  File "/home/user/.local/lib/python3.9/site-packages/tensorflow/python/framework/", line 1041, in assert_has_rank
    raise ValueError("Shape %s must have rank %d" % (self, rank))
ValueError: Shape (49, 4, 1) must have rank 2
(pid=360458) 2021-09-08 11:00:05,042	ERROR -- Exception raised in creation task: The actor died because of an error raised in its creation task, ray::RolloutWorker.__init__() (pid=360458, ip=
(pid=360458)   File "/home/user/.local/lib/python3.9/site-packages/ray/rllib/evaluation/", line 573, in __init__
(pid=360458)     self._build_policy_map(
(pid=360458)   File "/home/user/.local/lib/python3.9/site-packages/ray/rllib/evaluation/", line 1371, in _build_policy_map
(pid=360458)     self.policy_map.create_policy(name, orig_cls, obs_space, act_space,
(pid=360458)   File "/home/user/.local/lib/python3.9/site-packages/ray/rllib/policy/", line 132, in create_policy
(pid=360458)     class_(observation_space, action_space, merged_config)
(pid=360458)   File "/home/user/.local/lib/python3.9/site-packages/ray/rllib/policy/", line 329, in __init__
(pid=360458)     self._initialize_loss_from_dummy_batch(
(pid=360458)   File "/home/user/.local/lib/python3.9/site-packages/ray/rllib/policy/", line 793, in _initialize_loss_from_dummy_batch
(pid=360458)     self._loss(self, self.model, self.dist_class, train_batch)
(pid=360458)   File "/home/user/.local/lib/python3.9/site-packages/ray/rllib/agents/impala/", line 195, in build_vtrace_loss
(pid=360458)     policy.loss = VTraceLoss(
(pid=360458)   File "/home/user/.local/lib/python3.9/site-packages/ray/rllib/agents/impala/", line 78, in __init__
(pid=360458)     self.vtrace_returns = vtrace.multi_from_logits(
(pid=360458)   File "/home/user/.local/lib/python3.9/site-packages/ray/rllib/agents/impala/", line 246, in multi_from_logits
(pid=360458)     vtrace_returns = from_importance_weights(
(pid=360458)   File "/home/user/.local/lib/python3.9/site-packages/ray/rllib/agents/impala/", line 326, in from_importance_weights
(pid=360458)     values.shape.assert_has_rank(rho_rank)
(pid=360458)   File "/home/user/.local/lib/python3.9/site-packages/tensorflow/python/framework/", line 1041, in assert_has_rank
(pid=360458)     raise ValueError("Shape %s must have rank %d" % (self, rank))
(pid=360458) ValueError: Shape (49, 4, 1) must have rank 2

Hi @thomasbbrunner,

The issue is in the value_function. You need to squeeze out the trailing singleton dimension. You can fix it by following the example in the ray-project/ray repository on GitHub (at commit 3f89f35e5269c8a9391fb98a535cde7ffd6bcd9d):

    def value_function(self):
        """Return the cached critic output flattened to a 1-D tensor."""
        # Reshaping to [-1] drops the trailing singleton dimension.
        flat_values = tf.reshape(self.critic_out, [-1])
        return flat_values

You’re right! This solves the issue I was having!

I had other issues, but was able to solve them. Thanks a lot!