Running a custom attention net with RNNSAC

Hi Ray team,

I’m currently trying to run GTrXL (which I turned into a custom model) with RNNSAC. It ran fine with PPO, but with RNNSAC I get an `AssertionError` from `assert seq_lens is not None` in the model’s `forward` method. I assume the model is not receiving the right input, which probably has something to do with the policy itself. Could you guide me on how to approach this? The `forward` method of TrXLNet looks very different, which is understandable given GTrXL’s more complex output.
Would it be better to implement this as a `RecurrentNetwork` subclass rather than overriding `ModelV2` directly?

It would be great if you could provide an example of an attention net used with RNNSAC.

Hi @Puttatida_M

Do you have a stack trace of the error or a reproduction script you could share?

Hi mannyv,

The thing is, I am also using a custom RNN version of SAC: I’m working with TensorFlow, but the RNNSAC example was only available in Torch, so I converted the code. As far as I know, the attention net cannot be run with the built-in SAC provided by Ray, since that is designed for feed-forward architectures (correct me if I’m wrong).

Below is the custom attention_net.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
Created on Mon Oct 18 16:15:10 2021

@author: gael
from gym.spaces import Box, Discrete, MultiDiscrete
import numpy as np
import gym
from typing import Any, Dict, Optional, Type, Union

from ray.rllib.models.modelv2 import ModelV2
from import GRUGate, RelativeMultiHeadAttention, SkipConnection
from import TFModelV2
from import RecurrentNetwork
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf
import tensorflow as tf
from tensorflow import keras
from ray.rllib.utils.tf_ops import one_hot
from ray.rllib.utils.typing import ModelConfigDict, TensorType, List

# Two stacked Dense layers applied to each timestep of the attention output
# (see the class docstring).
# NOTE(review): this class is incomplete as pasted and cannot run as-is:
# its docstring is never closed, the __init__ signature is cut off after
# `output_activation` (no closing "):"), no super().__init__() call is
# visible, and the arguments of the hidden-layer Dense(...) are missing.
# The docstring's "[1]" reference is also not defined in the visible text.
class PositionwiseFeedforward(keras.layers.Layer if tf else object):
    """A 2x linear layer with ReLU activation in between described in [1].
    Each timestep coming from the attention head will be passed through this
    layer separately.

    def __init__(self,
                 out_dim: int,
                 hidden_dim: int,
                 output_activation: Optional[Any] = None,

        # Hidden layer; its size and activation were in the lost lines
        # (presumably hidden_dim units with ReLU, per the docstring -- confirm).
        self._hidden_layer = tf.keras.layers.Dense(

        # Output layer: out_dim units with the given output_activation.
        self._output_layer = tf.keras.layers.Dense(
            out_dim, activation=output_activation)

    def call(self, inputs: TensorType, **kwargs) -> TensorType:
        # Apply the hidden layer, then the output layer. Any extra keyword
        # arguments are accepted and discarded.
        del kwargs
        output = self._hidden_layer(inputs)
        return self._output_layer(output)

# Custom GTrXL-style attention model built on RLlib's RecurrentNetwork; its
# sizes are read from model_config["custom_model_config"].
#
# NOTE(review): this class is incomplete as pasted and cannot run as-is:
#   * the __init__ docstring beginning "Initializes object." is never closed;
#   * input_layer and memory_ins are each defined twice, and both memory_ins
#     list comprehensions are truncated (no Input(...) call, no closing "]");
#   * both SkipConnection(...) calls receive only name=..., so the layer each
#     one wraps (presumably RelativeMultiHeadAttention and a
#     PositionwiseFeedforward MLP, possibly with a GRUGate fan-in) is missing;
#   * the view_requirements dict literal is never closed, and the
#     state_in_{i} / state_out_{i} assignments end in a line-continuation
#     backslash with no right-hand side;
#   * in forward(), an `else:` appears to be missing and the returned list
#     literal is never closed.
# Restore the missing lines from the original file before debugging further.
class TransformerModel(RecurrentNetwork):
    """this is an implementation of GTrXL net by Deepmind"""

    def __init__(self, observation_space: gym.spaces.Space,
                 action_space: gym.spaces.Space, num_outputs: int,
                 model_config: ModelConfigDict, name,
                 num_transformer_units: int=4, attention_dim:int=4,
                 num_heads: int=4, head_dim: int=32, position_wise_mlp_dim: int=32, 
                 memory_inference:int =50, memory_training:int=50,
                 init_gru_gate_bias: float =2.0):
        """Initializes object.
            num_transformer_units (int): The number of Transformer repeats to
                use (denoted L in [2]).
            attention_dim (int): The input and output dimensions of one
                Transformer unit.
            num_heads (int): The number of attention heads to use in parallel.
            head_dim (int): The dimension of a single(!) attention head within
                a multi-head attention unit. D
            position_wise_mlp_dim (int): The dimension of the hidden layer
                within the position-wise MLP (after the multi-head attention
                block within one Transformer unit). This is the size of the
                first of the two layers within the PositionwiseFeedforward. The
                second layer always has size=`attention_dim`.

        super().__init__(observation_space, action_space, num_outputs,
                         model_config, name)
        # NOTE(review): the num_transformer_units ... init_gru_gate_bias keyword
        # arguments above are never read. The values below come from
        # model_config["custom_model_config"] instead (KeyError if a key is
        # missing), and position_wise_mlp_dim / init_gru_gate_bias are not
        # read from either source in the visible code.
        custom_model_config = model_config["custom_model_config"]
        self.num_transformer_units = custom_model_config["num_transformer_units"]
        self.attention_dim = custom_model_config["attention_dim"]
        self.num_heads = custom_model_config["num_heads"]
        self.memory_inference = custom_model_config["memory_inference"]
        self.memory_training = custom_model_config["memory_training"]
        self.head_dim = custom_model_config["head_dim"]
        self.max_seq_len = model_config["max_seq_len"]
        self.obs_dim = observation_space.shape[0]
        input_layer = tf.keras.layers.Input(
            shape=(None, self.obs_dim), name="inputs")
        memory_ins = [
                shape=(None, self.attention_dim),
            for i in range(self.num_transformer_units)

        # Map observation dim to input/output transformer (attention) dim.
        input_layer = tf.keras.layers.Input(
            ), name="inputs")
        memory_ins = [
            for i in range(self.num_transformer_units)

        # Map observation dim to input/output transformer (attention) dim.
        E_out = tf.keras.layers.Dense(self.attention_dim)(input_layer)
        # Output, collected and concat'd to build the internal, tau-len
        # Memory units used for additional contextual information.
        memory_outs = [E_out]

        # 2) Create L Transformer blocks according to [2].
        for i in range(self.num_transformer_units):
            # RelativeMultiHeadAttention part.
            MHA_out = SkipConnection(
                name="mha_{}".format(i + 1))(
                    E_out, memory=memory_ins[i])
            # Position-wise MLP part.
            E_out = SkipConnection(
                name="pos_wise_mlp_{}".format(i + 1))(MHA_out)
            # Output of position-wise MLP == E(l-1), which is concat'd
            # to the current Mem block (M(l-1)) to yield E~(l-1), which is then
            # used by the next transformer block.
            # NOTE(review): nothing in this loop appends to memory_outs, so as
            # pasted memory_outs stays [E_out] and memory_outs[:-1] below is
            # [] -- trxl_model would then emit no memory (state) outputs.
            # Presumably a `memory_outs.append(E_out)` line was lost here.

        # Only forward()'s `self._logits is not None` branch reassigns
        # _value_out; nothing visible ever sets _logits, so as pasted that
        # branch never runs and _value_out stays None.
        self._logits = None
        self._value_out = None

        # Keras model: [observations, *memory_ins] -> [E_out, *memory_outs[:-1]].
        self.trxl_model = tf.keras.Model(
            inputs=[input_layer] + memory_ins,
            outputs=[E_out] + memory_outs[:-1])

        self.view_requirements = {
            SampleBatch.OBS: ViewRequirement(space=observation_space),
        # Setup trajectory views (`memory-inference` x past memory outs).
        for i in range(self.num_transformer_units):
            space = Box(-1.0, 1.0, shape=(self.attention_dim, ))
            self.view_requirements["state_in_{}".format(i)] = \
                    # Repeat the incoming state every max-seq-len times.
            self.view_requirements["state_out_{}".format(i)] = \
    def forward(self, input_dict, state: List[TensorType],
                seq_lens: TensorType) -> (TensorType, List[TensorType]):
        # Run the transformer stack on a flat batch of observations.
        # seq_lens must be provided: its length B (number of sequences) is
        # used below to fold the flat observation batch into [B, T, ...].
        # NOTE(review): this is the assertion in the reported traceback; there
        # the caller (SAC get_policy_output -> self.action_model(...)) passed
        # seq_lens=None. One possible cause to check: a policy that only builds
        # a seq_lens input when the model declares state inputs would find none
        # here (get_initial_state() returns [] and the state_in_* view
        # requirements are incomplete as pasted) -- verify against the custom
        # RNNSAC policy, which is not shown. Also note `assert` is stripped
        # when Python runs with -O.
        assert seq_lens is not None

        # Add the time dim to observations.
        B = tf.shape(seq_lens)[0]
        observations = input_dict[SampleBatch.OBS]

        shape = tf.shape(observations)
        # T = per-sequence length; assumes every sequence in the flat batch is
        # padded to the same length (i.e. shape[0] is divisible by B).
        T = shape[0] // B
        observations = tf.reshape(observations,
                                  tf.concat([[-1, T], shape[1:]], axis=0))

        # `state` is fed positionally into trxl_model's memory_ins inputs.
        all_out = self.trxl_model([observations] + state)

        # NOTE(review): as pasted, the last two assignments below sit inside the
        # `if` branch and overwrite `out` / `memory_outs`; an `else:` before
        # them appears to have been lost (logits path: [out, value, *memory];
        # no-logits path: [out, *memory]).
        if self._logits is not None:
            out = tf.reshape(all_out[0], [-1, self.num_outputs])
            self._value_out = all_out[1]
            memory_outs = all_out[2:]
            out = tf.reshape(all_out[0], [-1, self.attention_dim])
            memory_outs = all_out[1:]

        # Flatten each memory output to [-1, attention_dim] to use as the
        # next state. NOTE(review): the closing "]" of this list is missing.
        return out, [
            tf.reshape(m, [-1, self.attention_dim]) for m in memory_outs

    # TODO: (sven) Deprecate this once trajectory view API has fully matured.
    # Returns no initial state; presumably the memory inputs are meant to be
    # described by the state_in_{i} view requirements set up in __init__.
    # NOTE(review): any caller that sizes its state inputs from this method
    # will see zero states for this model -- worth checking in the custom
    # RNNSAC policy given the seq_lens=None failure.
    def get_initial_state(self) -> List[np.ndarray]:
        return []

    # Return the value-branch output of the last forward() call, flattened to
    # 1-D. NOTE(review): _value_out is only assigned on forward()'s logits
    # path; otherwise it is still None here and tf.reshape(None, ...) fails.
    def value_function(self) -> TensorType:
        return tf.reshape(self._value_out, [-1])


And below is the script I used to run the test:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
Created on Mon Oct 18 09:22:15 2021

@author: gael

import argparse
import os

import ray
from ray import tune
from ray.rllib.agents import ppo
from ray.tune.registry import register_env
from ray.rllib.examples.env.repeat_after_me_env import RepeatAfterMeEnv
from ray.rllib.examples.env.repeat_initial_obs_env import RepeatInitialObsEnv
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
from TransformerModel_Ray import TransformerModel
from Vanilla import TrXLNet
from ray.rllib.models import ModelCatalog
from ray.rllib.utils.test_utils import check_learning_achieved
from ray.tune import CLIReporter
from ray.rllib.agents.sac.sac_tf_policy import SACTFPolicy
from ray.rllib.policy.policy import Policy

# Register the custom model under the name used by "custom_model" in the
# config below, and register the environment the config refers to.
ModelCatalog.register_custom_model("TransformerModel", TransformerModel)
#tune.register_env("stateless_cartpole", lambda c: StatelessCartPole())
# The env creator ignores its env_config argument `c`.
tune.register_env("RepeatInitialObsEnv", lambda c: RepeatInitialObsEnv())

# NOTE(review): stray fragment -- presumably part of a lost `choices=[...]`
# list (e.g. for --env); as-is this indented line is a syntax error.
    "RepeatAfterMeEnv", "RepeatInitialObsEnv", "LookAndPush",
class Counter:
    """Keep a running numeric total.

    The total starts at zero; ``inc`` adds to it and ``get`` reads it back.
    """

    def __init__(self):
        self.count = 0  # running total

    def inc(self, n):
        """Add ``n`` to the running total (returns None)."""
        self.count += n

    def get(self):
        """Return the current running total."""
        current = self.count
        return current

# Command-line options.
# NOTE(review): this block is incomplete as pasted -- the
# `parser.add_argument(` openers (flag names, types, defaults) of several
# options are lost, leaving only their trailing help=... lines, which is an
# IndentationError as-is. Judging by the help text and the args.* attributes
# read in __main__, these were presumably --run, --framework, --as-test,
# --stop-iters, --stop-timesteps, --stop-reward, --no-tune and --local-mode;
# confirm against the original script.
parser = argparse.ArgumentParser()
    help="The RLlib-registered algorithm to use.")
# NOTE(review): the default "sRepeatInitialObsEnv" has a stray leading "s";
# the env registered above is "RepeatInitialObsEnv". args.env is only read on
# the --no-tune path (the Tune config hard-codes "env").
parser.add_argument("--env", type=str, default="sRepeatInitialObsEnv")
# --num-cpus is parsed but not read anywhere in the visible script.
parser.add_argument("--num-cpus", type=int, default=2)
    choices=["tf", "tf2", "tfe", "torch"],
    help="The DL framework specifier.")
    help="Whether this script should be run as a test: --stop-reward must "
    "be achieved within --stop-timesteps AND --stop-iters.")
    help="Number of iterations to train.")
    help="Number of timesteps to train.")
    help="Reward at which we stop training.")
    # NOTE(review): the two adjacent literals below concatenate to
    # "...Here,there..." (no space after the comma).
    help="Run without Tune using a manual train loop instead. Here,"
    "there is no TensorBoard support.")

    help="Init Ray in local mode for easier debugging.")

# Script entry point: build a SAC-style config (twin_q, tau, Q_model,
# policy_model) that uses the custom TransformerModel, then train either with
# Tune or, for PPO only, with a manual loop.
# NOTE(review): several lines of this block were lost in the paste (see the
# inline notes); it does not parse as-is.
if __name__ == "__main__":
    args = parser.parse_args()

    #c = Counter.remote()

    #tune.register_env("stateless_cartpole", lambda c: StatelessCartPole())
    #register_env("RepeatInitialObsEnv", lambda _: RepeatInitialObsEnv())

    config = {
        # NOTE(review): "env" is hard-coded here, so --env only affects the
        # --no-tune path below.
        "env": "RepeatInitialObsEnv",
        # This env_config is only used for the RepeatAfterMeEnv env.
        # NOTE(review): no "env_config" key is present as pasted, yet the
        # manual test loop below reads config["env_config"].
        "gamma": 0.99,
        "twin_q": True,
        "clip_actions": False,
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", 0)),
        #"num_envs_per_worker": 20,
        #"entropy_coeff": 0.001,
        #"num_sgd_iter": 10,
        #"vf_loss_coeff": 1e-5,
        "render_env": True,
        #"horizon": 1000,
        #"#batch_mode": "complete_episodes",
        #"#prioritized_replay": False,
        #"#buffer_size": 100000,
        #"#learning_starts": 1000,
        #"train_batch_size": 480,
        #"target_network_update_freq": 480,
        "tau": 0.3,
        # Top-level model options: only max_seq_len is set here; the custom
        # model is configured per sub-network in Q_model / policy_model.
        "model": {
            "max_seq_len": 10},
        # NOTE(review): the deeper-indented keys below look like the contents
        # of a lost "custom_model_config" sub-dict (TransformerModel.__init__
        # reads model_config["custom_model_config"]); the closing braces are
        # lost too. Neither Q_model nor policy_model lists "memory_training",
        # which TransformerModel.__init__ also reads.
        "Q_model": {
            "custom_model": "TransformerModel",
                "max_seq_len": 10,
                "num_transformer_units": 4,
                "attention_dim": 256,
                "num_heads": 4,
                "memory_inference": 100,
                "head_dim": 32,
                "position_wise_mlp_dim": 32,
    # Model options for the policy function (see `Q_model` above for details).
    # The difference to `Q_model` above is that no action concat'ing is
    # performed before the post_fcnet stack.
        "policy_model": {
            "custom_model": "TransformerModel",
                    "max_seq_len": 10,
                    "num_transformer_units": 4,
                    "attention_dim": 256,
                    "num_heads": 4,
                    "memory_inference": 100,
                    "head_dim": 32,
                    "position_wise_mlp_dim": 32,
    # Stopping criteria (the closing brace of this dict is lost as pasted).
    stop = {
        "training_iteration": args.stop_iters,
        "timesteps_total": args.stop_timesteps,
        "episode_reward_mean": args.stop_reward

    if args.no_tune:
        # manual training loop using PPO and manually keeping track of state
        # NOTE(review): the left operand (presumably args.run) is lost.
        if != "PPO":
            raise ValueError("Only support --run PPO with --no-tune.")
        # NOTE(review): the trainer is built from PPO's defaults only; the
        # `config` dict above (custom model, etc.) is not merged in.
        ppo_config = ppo.DEFAULT_CONFIG.copy()
        trainer = ppo.PPOTrainer(config=ppo_config, env=args.env)
        # run manual training loop and print results after each iteration
        for _ in range(args.stop_iters):
            result = trainer.train()
            # stop training if the target train steps or reward are reached
            if result["timesteps_total"] >= args.stop_timesteps or \
                    result["episode_reward_mean"] >= args.stop_reward:
                # NOTE(review): the body of this `if` (presumably `break`)
                # is lost.

        # Run manual test loop (only for RepeatAfterMe env).
        if args.env == "RepeatAfterMeEnv":
            print("Finished training. Running manual test/inference loop.")
            # prepare env
            env = RepeatAfterMeEnv(config["env_config"])
            obs = env.reset()
            done = False
            total_reward = 0
            # start with all zeros as state
            # NOTE(review): the next line is truncated (the key read is lost)
            # and the list below is never closed. numpy (np) is not imported
            # by this script, and the hard-coded [100, 32] state shape differs
            # from the config above (memory_inference=100, attention_dim=256)
            # -- confirm the intended per-unit state shape.
            num_transformers = config["model"][
            init_state = state = [
                np.zeros([100, 32], np.float32)
                for _ in range(num_transformers)
            # run one iteration until done
            print(f"RepeatAfterMeEnv with {config['env_config']}")
            while not done:
                action, state_out, _ = trainer.compute_single_action(
                    obs, state)
                next_obs, reward, done, _ = env.step(action)
                print(f"Obs: {obs}, Action: {action}, Reward: {reward}")
                obs = next_obs
                total_reward += reward
                # Slide each memory window: append the newest state_out row
                # and drop the oldest. (List is never closed as pasted.)
                state = [
                    np.concatenate([state[i], [state_out[i]]], axis=0)[1:]
                    for i in range(num_transformers)
            print(f"Total reward in test episode: {total_reward}")

    # Run with Tune for auto env and trainer creation and TensorBoard.
    # NOTE(review): the `else:` and the start of the call (presumably
    # tune.run(<trainer>, ...)) are lost on the next line as pasted.
        results =, config=config, stop=stop, verbose=2, progress_reporter = CLIReporter())

        if args.as_test:
            print("Checking if learning goals were achieved")
            check_learning_achieved(results, args.stop_reward)


The error is attached below. The modified RNNSAC policy and model ran fine with my recurrent model, so I doubt they are the issue, but let me know if you would like to see them as well.

(pid=6037) 2021-10-25 10:50:44,036	ERROR -- Exception raised in creation task: The actor died because of an error raised in its creation task, ray::RNNSACTrainer.__init__() (pid=6037, ip=
(pid=6037)   File "/home/gael/miniconda3/envs/pvc2/lib/python3.7/site-packages/ray/rllib/agents/", line 137, in __init__
(pid=6037)     Trainer.__init__(self, config, env, logger_creator)
(pid=6037)   File "/home/gael/miniconda3/envs/pvc2/lib/python3.7/site-packages/ray/rllib/agents/", line 611, in __init__
(pid=6037)     super().__init__(config, logger_creator)
(pid=6037)   File "/home/gael/miniconda3/envs/pvc2/lib/python3.7/site-packages/ray/tune/", line 106, in __init__
(pid=6037)     self.setup(copy.deepcopy(self.config))
(pid=6037)   File "/home/gael/miniconda3/envs/pvc2/lib/python3.7/site-packages/ray/rllib/agents/", line 147, in setup
(pid=6037)     super().setup(config)
(pid=6037)   File "/home/gael/miniconda3/envs/pvc2/lib/python3.7/site-packages/ray/rllib/agents/", line 764, in setup
(pid=6037)     self._init(self.config, self.env_creator)
(pid=6037)   File "/home/gael/miniconda3/envs/pvc2/lib/python3.7/site-packages/ray/rllib/agents/", line 176, in _init
(pid=6037)     num_workers=self.config["num_workers"])
(pid=6037)   File "/home/gael/miniconda3/envs/pvc2/lib/python3.7/site-packages/ray/rllib/agents/", line 852, in _make_workers
(pid=6037)     logdir=self.logdir)
(pid=6037)   File "/home/gael/miniconda3/envs/pvc2/lib/python3.7/site-packages/ray/rllib/evaluation/", line 111, in __init__
(pid=6037)     spaces=spaces,
(pid=6037)   File "/home/gael/miniconda3/envs/pvc2/lib/python3.7/site-packages/ray/rllib/evaluation/", line 439, in _make_worker
(pid=6037)     spaces=spaces,
(pid=6037)   File "/home/gael/miniconda3/envs/pvc2/lib/python3.7/site-packages/ray/rllib/evaluation/", line 587, in __init__
(pid=6037)     seed=seed)
(pid=6037)   File "/home/gael/miniconda3/envs/pvc2/lib/python3.7/site-packages/ray/rllib/evaluation/", line 1383, in _build_policy_map
(pid=6037)     conf, merged_conf)
(pid=6037)   File "/home/gael/miniconda3/envs/pvc2/lib/python3.7/site-packages/ray/rllib/policy/", line 134, in create_policy
(pid=6037)     observation_space, action_space, merged_config)
(pid=6037)   File "/home/gael/miniconda3/envs/pvc2/lib/python3.7/site-packages/ray/rllib/policy/", line 251, in __init__
(pid=6037)     get_batch_divisibility_req=get_batch_divisibility_req,
(pid=6037)   File "/home/gael/miniconda3/envs/pvc2/lib/python3.7/site-packages/ray/rllib/policy/", line 288, in __init__
(pid=6037)     is_training=in_dict.is_training)
(pid=6037)   File "/home/gael/miniconda3/envs/pvc2/lib/python3.7/site-packages/ray/rllib/agents/sac/", line 130, in get_distribution_inputs_and_class
(pid=6037)     model.get_policy_output(model_out, states_in["policy"], seq_lens)
(pid=6037)   File "/home/gael/miniconda3/envs/pvc2/lib/python3.7/site-packages/ray/rllib/agents/sac/", line 130, in get_policy_output
(pid=6037)     return self.action_model(model_out, state_in, seq_lens)
(pid=6037)   File "/home/gael/miniconda3/envs/pvc2/lib/python3.7/site-packages/ray/rllib/models/", line 243, in __call__
(pid=6037)     res = self.forward(restored, state or [], seq_lens)
(pid=6037)   File "/home/gael/AttentionModel/RLLIB_Transformer/", line 183, in forward
(pid=6037)     assert seq_lens is not None
(pid=6037) AssertionError
Traceback (most recent call last):