Custom TF model with tf.keras.layers.Embedding

Hi community,

Does anyone have experience in using a custom TF model with tf.keras.layers.Embedding?
My input (i.e. tf.keras.layers.Input) to the model contains some (categorical) integer values which denote, e.g., a type or a label. I want to embed such integer values with a tf.keras.layers.Embedding layer and I know at which positions in the input the corresponding values are located.

What is a best practice to do this?
I thought of slicing the input tensor, feeding the slices into tf.keras.layers.Embedding layers, and finally concatenating everything for further processing. Does such an approach even work, and is it reasonable?

Hey @klausk55, I think it’s pretty straightforward. One of the problems, though, would be that RLlib automatically adds a one-hot preprocessor :face_with_raised_eyebrow: .
We can try adding a small example to the lib, see whether we can make this work …


The embedding model code should be something like this, but I’ll still have to figure out how to avoid the pre-processor (something we would love to get rid of completely soon anyways :slight_smile: ).

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_tf

tf1, tf, tfv = try_import_tf()

class EmbeddingModel(tf.keras.Model if tf else object):
    """Example of an embedding model that takes ints as inputs."""

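    def __init__(self,
                 vocab_size: int = 1000,
                 embedding_dim: int = 256,
                 num_outputs: int = 2):
        # NOTE: `num_outputs` (the size of the logits layer) is an assumed
        # parameter; the original layer arguments were cut off in this post.
        super().__init__()

        # Map (int) observations to an embedding vector (of size `embedding_dim`).
        # The incoming ints may have values between 0 and `vocab_size` - 1.
        self._embedding_layer = tf.keras.layers.Embedding(
            input_dim=vocab_size, output_dim=embedding_dim)

        # Postprocess embedding output with a dense layer (logits) and
        # compute values.
        self._logits = tf.keras.layers.Dense(
            num_outputs, activation=None, name="logits")
        self._values = tf.keras.layers.Dense(1, activation=None, name="values")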

    def call(self, inputs: SampleBatch):
        obs = inputs[SampleBatch.OBS]
        embedding_out = self._embedding_layer(obs)
        logits = self._logits(embedding_out)
        values = self._values(embedding_out)
        return logits, [], {SampleBatch.VF_PREDS: tf.reshape(values, [-1])}

This is already using the new native keras model support :slight_smile:

Sorry @sven1977, but where does the one-hot preprocessor problem come in here?
Do you mean that an integer input would be preprocessed into a one-hot vector?

Yes, exactly, I haven’t checked yet. I think there is a way to disable that.

Ah interesting… is the use of TFModelV2 obsolete?

Not yet (and won’t be for a long time), but yes, we would like to completely obsolete this API and just accept plain keras or torch models instead (WIP).

On preprocessors: RLlib “provides” (enforces) default preprocessors on some input spaces (basically flattens everything it sees, including one-hotting int spaces). Custom preprocessors are already deprecated, and we are working on deprecating the default (built-in) ones as well, since they confuse the hell out of users. :slight_smile:
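To make the one-hot point concrete, here is a small NumPy-only sketch of what one-hotting an int observation amounts to, and why it clashes with an embedding layer that expects the raw int. It does not call RLlib's preprocessor code; the `one_hot` helper, the table size, and the sample value are all made up for illustration.

```python
import numpy as np


def one_hot(index: int, n: int) -> np.ndarray:
    """Encode an int in [0, n) as a length-n one-hot float vector."""
    vec = np.zeros(n, dtype=np.float32)
    vec[index] = 1.0
    return vec


# Hypothetical setup: 5 possible int values, 3-dim embedding vectors.
rng = np.random.default_rng(0)
table = rng.normal(size=(5, 3)).astype(np.float32)

obs = 3                     # raw int observation
encoded = one_hot(obs, 5)   # what a one-hot step would hand to the model

# An embedding layer wants the raw int and does a row lookup ...
via_lookup = table[obs]
# ... which equals multiplying the one-hot vector with the table.
via_matmul = encoded @ table
```

So once the input has been one-hotted, a plain bias-free dense layer already plays the role of the embedding, whereas an Embedding layer would misread the 0/1 entries as ids.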

Okay, but as far as I know this only happens for Discrete and MultiDiscrete obs spaces, right?
To be honest, I have so far just thought about slicing input tensors resulting from Box spaces in the forward call :sweat_smile:
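For reference, a minimal sketch of that slice, embed, and concatenate idea as a plain Keras model (no RLlib wiring). The observation layout, the column-to-vocabulary mapping, and the layer sizes are assumptions chosen for the example, and `SliceEmbedModel` is a hypothetical name.

```python
import numpy as np
import tensorflow as tf


class SliceEmbedModel(tf.keras.Model):
    """Embeds selected int-valued columns of a flat float observation."""

    def __init__(self, cat_vocab_sizes, embedding_dim=8, hidden_units=32):
        super().__init__()
        # `cat_vocab_sizes` maps column index -> number of distinct ids
        # (an assumed, user-supplied layout).
        self._cat_columns = tuple(sorted(cat_vocab_sizes))
        self._embeddings = [
            tf.keras.layers.Embedding(cat_vocab_sizes[col], embedding_dim)
            for col in self._cat_columns
        ]
        self._hidden = tf.keras.layers.Dense(hidden_units, activation="relu")

    def call(self, obs):
        parts = []
        for col, layer in zip(self._cat_columns, self._embeddings):
            # Slice one categorical column and cast float -> int ids.
            ids = tf.cast(obs[:, col], tf.int32)
            parts.append(layer(ids))  # shape: (batch, embedding_dim)
        # Pass the remaining (continuous) columns through unchanged.
        rest = [i for i in range(obs.shape[-1]) if i not in self._cat_columns]
        parts.append(tf.gather(obs, rest, axis=-1))
        return self._hidden(tf.concat(parts, axis=-1))


# Columns 0 and 1 hold ids (10 and 4 possible values), columns 2..5 are floats.
model = SliceEmbedModel({0: 10, 1: 4})
batch = np.array([[3, 1, 0.5, -0.2, 0.0, 1.0],
                  [9, 0, 0.1, 0.3, -1.0, 2.0]], dtype=np.float32)
out = model(tf.constant(batch))
```

Here `out` has one row per observation and `hidden_units` columns. The ids must stay inside each column's vocabulary range, so the Box bounds should match the sizes passed to the model.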

Ah, yes, makes perfect sense! Didn’t think about int Boxes :smiley: