Using a custom neural network in RLlib

How severe does this issue affect your experience of using Ray?

  • High: It blocks me to complete my task.

Hi all,

I have a custom LSTM neural network written for an older version of “Ray”, and I want to use it as the policy network for my RL agents. Here is the code for my custom neural network:

import numpy as np
import tensorflow as tf

from gym.spaces import Box, Dict
from ray.rllib.models.tf.recurrent_net import RecurrentNetwork
from ray.rllib.policy.rnn_sequencing import add_time_dimension

# _MASK_NAME, _WORLD_MAP_NAME, _WORLD_IDX_MAP_NAME and apply_logit_mask
# are defined elsewhere in my tf_models.py.


class KerasConvLSTM(RecurrentNetwork):
    """
    The model used in the paper "The AI Economist: Optimal Economic Policy
    Design via Two-level Deep Reinforcement Learning"
    (https://arxiv.org/abs/2108.02755)
    We combine convolutional, fully connected, and recurrent layers to process
    spatial, non-spatial, and historical information, respectively.
    For recurrent components, each agent maintains its own hidden state.
    """

    custom_name = "keras_conv_lstm"

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        super().__init__(obs_space, action_space, num_outputs, model_config, name)

        input_emb_vocab = self.model_config["custom_model_config"]["input_emb_vocab"]
        emb_dim = self.model_config["custom_model_config"]["idx_emb_dim"]
        num_conv = self.model_config["custom_model_config"]["num_conv"]
        num_fc = self.model_config["custom_model_config"]["num_fc"]
        fc_dim = self.model_config["custom_model_config"]["fc_dim"]
        cell_size = self.model_config["custom_model_config"]["lstm_cell_size"]
        generic_name = self.model_config["custom_model_config"].get("generic_name", None)

        self.cell_size = cell_size

        if hasattr(obs_space, "original_space"):
            obs_space = obs_space.original_space

        if not isinstance(obs_space, Dict):
            if isinstance(obs_space, Box):
                raise TypeError(
                    "({}) Observation space should be a gym Dict."
                    " Is a Box of shape {}".format(name, obs_space.shape)
                )
            raise TypeError(
                "({}) Observation space should be a gym Dict."
                " Is {} instead.".format(name, type(obs_space))
            )

        # Define input layers
        self._input_keys = []
        non_conv_input_keys = []
        input_dict = {}
        conv_shape_r = None
        conv_shape_c = None
        conv_map_channels = None
        conv_idx_channels = None
        found_world_map = False
        found_world_idx = False
        
        for k, v in obs_space.spaces.items():
            shape = (None,) + v.shape
            input_dict[k] = tf.keras.layers.Input(shape=shape, name=k)
            self._input_keys.append(k)
            if k == _MASK_NAME:
                pass
            elif k == _WORLD_MAP_NAME:
                conv_shape_r, conv_shape_c, conv_map_channels = (
                    v.shape[1],
                    v.shape[2],
                    v.shape[0],
                )
                found_world_map = True
            elif k == _WORLD_IDX_MAP_NAME:
                conv_idx_channels = v.shape[0] * emb_dim
                found_world_idx = True
            else:
                non_conv_input_keys.append(k)

        # Cell state and hidden state for the policy and value function networks.
        state_in_h_p = tf.keras.layers.Input(shape=(cell_size,), name="h_pol")
        state_in_c_p = tf.keras.layers.Input(shape=(cell_size,), name="c_pol")
        state_in_h_v = tf.keras.layers.Input(shape=(cell_size,), name="h_val")
        state_in_c_v = tf.keras.layers.Input(shape=(cell_size,), name="c_val")
        seq_in = tf.keras.layers.Input(shape=(), name="seq_in")

        # Determine which of the inputs are treated as non-conv inputs
        if generic_name is None:
            non_conv_inputs = tf.keras.layers.concatenate(
                [input_dict[k] for k in non_conv_input_keys]
            )
        elif isinstance(generic_name, (tuple, list)):
            non_conv_inputs = tf.keras.layers.concatenate(
                [input_dict[k] for k in generic_name]
            )
        elif isinstance(generic_name, str):
            non_conv_inputs = input_dict[generic_name]
        else:
            raise TypeError

        if found_world_map:
            assert found_world_idx
            use_conv = True
            conv_shape = (
                conv_shape_r,
                conv_shape_c,
                conv_map_channels + conv_idx_channels,
            )

            conv_input_map = tf.keras.layers.Permute((1, 3, 4, 2))(
                input_dict[_WORLD_MAP_NAME]
            )
            conv_input_idx = tf.keras.layers.Permute((1, 3, 4, 2))(
                input_dict[_WORLD_IDX_MAP_NAME]
            )

        else:
            assert not found_world_idx
            use_conv = False
            conv_shape = None
            conv_input_map = None
            conv_input_idx = None

        logits, values, state_h_p, state_c_p, state_h_v, state_c_v = (
            None,
            None,
            None,
            None,
            None,
            None,
        )

        # Define the policy and value function models
        for tag in ["_pol", "_val"]:
            if tag == "_pol":
                state_in = [state_in_h_p, state_in_c_p]
            elif tag == "_val":
                state_in = [state_in_h_v, state_in_c_v]
            else:
                raise NotImplementedError

            # Apply convolution to the spatial inputs
            if use_conv:
                map_embedding = tf.keras.layers.Embedding(
                    input_emb_vocab, emb_dim, name="embedding" + tag
                )
                conv_idx_embedding = tf.keras.layers.Reshape(
                    (-1, conv_shape_r, conv_shape_c, conv_idx_channels)
                )(map_embedding(conv_input_idx))

                conv_input = tf.keras.layers.concatenate(
                    [conv_input_map, conv_idx_embedding]
                )

                conv_model = tf.keras.models.Sequential(name="conv_model" + tag)
                assert conv_shape
                conv_model.add(
                    tf.keras.layers.Conv2D(
                        16,
                        (3, 3),
                        strides=2,
                        activation="relu",
                        input_shape=conv_shape,
                        name="conv2D_1" + tag,
                    )
                )

                for i in range(num_conv - 1):
                    conv_model.add(
                        tf.keras.layers.Conv2D(
                            32,
                            (3, 3),
                            strides=2,
                            activation="relu",
                            name="conv2D_{}{}".format(i + 2, tag),
                        )
                    )

                conv_model.add(tf.keras.layers.Flatten())

                conv_td = tf.keras.layers.TimeDistributed(conv_model)(conv_input)

                # Combine the conv output with the non-conv inputs
                dense = tf.keras.layers.concatenate([conv_td, non_conv_inputs])

            # No spatial inputs provided -- skip any conv steps
            else:
                dense = non_conv_inputs

            # Preprocess observation with hidden layers and send to LSTM cell
            for i in range(num_fc):
                layer = tf.keras.layers.Dense(
                    fc_dim, activation=tf.nn.relu, name="dense{}".format(i + 1) + tag
                )
                dense = layer(dense)

            dense = tf.keras.layers.LayerNormalization(name="layer_norm" + tag)(dense)

            lstm_out, state_h, state_c = tf.keras.layers.LSTM(
                cell_size, return_sequences=True, return_state=True, name="lstm" + tag
            )(inputs=dense, mask=tf.sequence_mask(seq_in), initial_state=state_in)

            # Project LSTM output to logits or value
            output = tf.keras.layers.Dense(
                self.num_outputs if tag == "_pol" else 1,
                activation=tf.keras.activations.linear,
                name="logits" if tag == "_pol" else "value",
            )(lstm_out)

            if tag == "_pol":
                state_h_p, state_c_p = state_h, state_c
                logits = apply_logit_mask(output, input_dict[_MASK_NAME])
            elif tag == "_val":
                state_h_v, state_c_v = state_h, state_c
                values = output
            else:
                raise NotImplementedError

        self.input_dict = input_dict

        # This will be set in the forward_rnn() call below
        self._value_out = None

        for out in [logits, values, state_h_p, state_c_p, state_h_v, state_c_v]:
            assert out is not None

        # Create the RNN model
        self.rnn_model = tf.keras.Model(
            inputs=self._extract_input_list(input_dict)
            + [seq_in, state_in_h_p, state_in_c_p, state_in_h_v, state_in_c_v],
            outputs=[logits, values, state_h_p, state_c_p, state_h_v, state_c_v],
        )
        # self.register_variables(self.rnn_model.variables)
        # self.rnn_model.summary()

    def _extract_input_list(self, dictionary):
        return [dictionary[k] for k in self._input_keys]

    def forward(self, input_dict, state, seq_lens):
        """Adds time dimension to batch before sending inputs to forward_rnn().

        You should implement forward_rnn() in your subclass."""
        
        padded_inputs = input_dict["obs_flat"]
        
        max_seq_len = tf.shape(padded_inputs)[0] // tf.shape(seq_lens)[0]
                
        output, new_state = self.forward_rnn(
            add_time_dimension(padded_inputs, max_seq_lens=max_seq_len, framework="tf"),
            state,
            seq_lens,
        )

        return tf.reshape(output, [-1, self.num_outputs]), new_state

    def forward_rnn(self, inputs, state, seq_lens):
        model_out, self._value_out, h_p, c_p, h_v, c_v = self.rnn_model(
            inputs + [seq_lens] + state
        )
        return model_out, [h_p, c_p, h_v, c_v]

    def get_initial_state(self):
        return [
            np.zeros(self.cell_size, np.float32),
            np.zeros(self.cell_size, np.float32),
            np.zeros(self.cell_size, np.float32),
            np.zeros(self.cell_size, np.float32),
        ]

    def value_function(self):
        return tf.reshape(self._value_out, [-1])
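For reference, the reshape that add_time_dimension performs in forward() can be illustrated with a minimal NumPy sketch (this mirrors only the non-time-major "tf" path; RLlib's real helper handles more cases, and the toy shapes below are placeholders, not the actual environment's):

```python
import numpy as np

# Minimal sketch of what add_time_dimension does for the "tf" framework:
# fold a padded [B*T, D] batch back into [B, T, D].
padded_inputs = np.arange(24, dtype=np.float32).reshape(6, 4)  # B*T=6, D=4
seq_lens = np.array([3, 3])                                    # 2 sequences
max_seq_len = padded_inputs.shape[0] // seq_lens.shape[0]      # -> 3
with_time = padded_inputs.reshape(
    seq_lens.shape[0], max_seq_len, padded_inputs.shape[1]
)
print(with_time.shape)  # (2, 3, 4)
```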

I register my custom neural network and feed it in as the policy network of my RL agents as follows:

from ray.rllib.models.catalog import ModelCatalog
from rllib.tf_models import KerasConvLSTM

ModelCatalog.register_custom_model(KerasConvLSTM.custom_name, KerasConvLSTM)

policies = {
    "a": (
        None,  # uses default policy
        env_obj.observation_space,
        env_obj.action_space,
        {'clip_param': 0.3,
         'entropy_coeff': 0.025,
         'entropy_coeff_schedule': None,
         'gamma': 0.998,
         'grad_clip': 10.0,
         'kl_coeff': 0.0,
         'kl_target': 0.01,
         'lambda': 0.98,
         'lr': 0.0003,
         'lr_schedule': None,
         'model': {'custom_model': 'keras_conv_lstm',
                   'custom_model_config': {'fc_dim': 128,
                                           'idx_emb_dim': 4,
                                           'input_emb_vocab': 100,
                                           'lstm_cell_size': 128,
                                           'num_conv': 2,
                                           'num_fc': 2},
                   'max_seq_len': 25},
         'use_gae': True,
         'vf_clip_param': 50.0,
         'vf_loss_coeff': 0.05,
         'vf_share_layers': False}  # define a custom agent policy configuration.
    ),
    "p": (
        None,  # uses default policy
        env_obj.observation_space_pl,
        env_obj.action_space_pl,
        {'clip_param': 0.3,
         'entropy_coeff': 0.125,
         'entropy_coeff_schedule': [[0, 2.0], [50000000, 0.125]],
         'gamma': 0.998,
         'grad_clip': 10.0,
         'kl_coeff': 0.0,
         'kl_target': 0.01,
         'lambda': 0.98,
         'lr': 0.0001,
         'lr_schedule': None,
         'model': {'custom_model': 'keras_conv_lstm',
                   'custom_model_config': {'fc_dim': 128,
                                           'idx_emb_dim': 4,
                                           'input_emb_vocab': 100,
                                           'lstm_cell_size': 128,
                                           'num_conv': 2,
                                           'num_fc': 2},
                   'max_seq_len': 25},
         'use_gae': True,
         'vf_clip_param': 50.0,
         'vf_loss_coeff': 0.05,
         'vf_share_layers': False}  # define a custom planner policy configuration.
    )
}

policy_mapping_fun = lambda i: "a" if str(i).isdigit() else "p"

policies_to_train = ["a", "p"]
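For what it's worth, the mapping function above routes numeric agent ids to the "a" policy and anything else (the planner, id "p") to "p"; a quick sanity check in plain Python:

```python
# Same mapping as above: numeric agent ids -> "a", anything else -> "p".
policy_mapping_fun = lambda i: "a" if str(i).isdigit() else "p"

print(policy_mapping_fun(0))    # "a"  (a mobile agent)
print(policy_mapping_fun("3"))  # "a"
print(policy_mapping_fun("p"))  # "p"  (the planner)
```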

When I incorporate this custom neural network into my agents and try to create a PPOTrainer as follows:

trainer = PPOTrainer(env=RLlibEnvWrapper, config=trainer_config)

It gives me the following error:

File "/Users/asataryd/miniforge3/envs/modified-ai-economist/lib/python3.10/site-packages/ray/rllib/models/modelv2.py", line 259, in __call__
    res = self.forward(restored, state or [], seq_lens)
  File "/Users/asataryd/Documents/ENI-Projects/Aslan/In-silico-Experimental-Economics-through-MARL-Autocurricula/Githubs/modified-ai-economist-master/tutorials/rllib/tf_models.py", line 311, in forward
    output, new_state = self.forward_rnn(
  File "/Users/asataryd/Documents/ENI-Projects/Aslan/In-silico-Experimental-Economics-through-MARL-Autocurricula/Githubs/modified-ai-economist-master/tutorials/rllib/tf_models.py", line 332, in forward_rnn
    inputs + [seq_lens] + state
  File "/Users/asataryd/miniforge3/envs/modified-ai-economist/lib/python3.10/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/Users/asataryd/miniforge3/envs/modified-ai-economist/lib/python3.10/site-packages/tensorflow/python/framework/ops.py", line 1963, in _create_c_op
    raise ValueError(e.message)

ValueError: Dimensions must be equal, but are 1682 and 128 for '{{node a_wk1/add_2}} = AddV2[T=DT_FLOAT](a_wk1/add_1, a_wk1/add_2/y)' with input shapes: [?,?,1682], [4,?,128].

I was wondering if anyone could help me pinpoint the reason for this error. In particular, could you check the implementation of the “forward” and “forward_rnn” functions of the custom neural network? This is my first project using “RLlib”. Also, regarding debugging: since I am running my code in a Jupyter notebook, I used to rely on the “pdb” library, but in this case it doesn’t work. Is there another good way to debug Python code from inside a Jupyter notebook?

Many many thanks in advance!

Hi there, I’m having the same problem. I think it’s a preprocessor issue (I’m working on it, really :man_facepalming: )

Btw if you want to improve your code a little bit:

from ray.tune.registry import register_env
register_env("env_name", lambda _: RLlibEnvWrapper(config))
trainer = PPOTrainer(env="env_name", config=trainer_config)

with this the project should scale better :slight_smile:

Thanks @sa1g! I will check this!

I’ve been working on separating the AI-Economist neural network from RLlib and I’m having the same problem. It looks like we are facing a preprocessor issue → we need to pad the raw data:

  • select the obs of a single agent
  • convert it from numpy.array to tensor & pad (or vice versa)
  • feed it to the network

I think the problem is here, I’m working on it.
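The padding steps above could be sketched like this (pure NumPy, with placeholder shapes — not the actual AI-Economist observation sizes):

```python
import numpy as np

# Hedged sketch of the steps above; obs_dim and max_seq_len here are
# placeholders, not the real environment's values.
def pad_agent_obs(agent_obs, max_seq_len):
    """Zero-pad a single agent's [T, obs_dim] observations to [max_seq_len, obs_dim]."""
    t, obs_dim = agent_obs.shape
    padded = np.zeros((max_seq_len, obs_dim), dtype=np.float32)
    padded[:t] = agent_obs
    return padded

obs = np.ones((3, 5), dtype=np.float32)      # one agent, 3 timesteps, obs_dim=5
padded = pad_agent_obs(obs, max_seq_len=25)  # ready to feed to the network
print(padded.shape)  # (25, 5)
```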

p.s. uncomment self.rnn_model.summary() in KerasConvLSTM so you can see the model structure (it helps with debugging)
p.p.s. use eager mode to see what the real problem is :stuck_out_tongue:
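On the eager-mode tip: recent RLlib versions let you ask for eager TF2 execution through the trainer config. Treat this as a sketch — the exact keys may differ across Ray releases, so check the docs for your version:

```python
# Sketch, assuming Ray 1.x-style config keys -- check your release's docs.
eager_debug_overrides = {
    "framework": "tf2",      # use the TF2 code path instead of static-graph tf
    "eager_tracing": False,  # keep ops fully eager so errors come with Python tracebacks
}
trainer_config.update(eager_debug_overrides)  # apply before building the PPOTrainer
```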

Let me know if you find a solution, I’ll do the same :slight_smile:

Thanks @sa1g for your comments! I will check these and let you know if I can solve the issue!
