Should there be config support for get_initial_state in ModelV2?

Hello all!

I am trying to implement a custom model that uses a memory module akin to those in an NTM / DNC (for example, loudinthecloud/pytorch-ntm on GitHub, a PyTorch implementation of Neural Turing Machines). Note how the memory state is handled:

import numpy as np
import torch
import torch.nn as nn


class NTMMemory(nn.Module):
    """Memory bank for NTM."""
    def __init__(self, N, M):
        """Initialize the NTM Memory matrix.
        The memory's dimensions are (batch_size x N x M).
        Each batch has its own memory matrix.
        :param N: Number of rows in the memory.
        :param M: Number of columns/features in the memory.
        """
        super(NTMMemory, self).__init__()

        self.N = N
        self.M = M

        # The memory bias allows the heads to learn how to initially address
        # memory locations by content
        self.register_buffer('mem_bias', torch.empty(N, M))

        # Initialize memory bias
        stdev = 1 / (np.sqrt(N + M))
        nn.init.uniform_(self.mem_bias, -stdev, stdev)

    def reset(self, batch_size):
        """Initialize memory from bias, for start-of-sequence."""
        self.batch_size = batch_size
        self.memory = self.mem_bias.clone().repeat(batch_size, 1, 1)

So the module has a single initial memory state, mem_bias, and to run the model on a batch we need to copy it batch_size times. Every sequence being processed starts from the same initial state, but the controller operating on it reads and writes different data depending on what’s in that sequence, so each sequence requires its own memory state.
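A minimal PyTorch sketch of the batching step described above, under two assumptions: `LearnedInitMemory` is a hypothetical class, not part of pytorch-ntm or RLlib, and it stores the initial memory as an `nn.Parameter`. That second choice matters if the initial state is meant to be learned: a tensor added with `register_buffer`, as in the snippet above, is not returned by `.parameters()`, so an optimizer never updates it. The sketch also relies on `repeat()` allocating fresh storage, so writing to one sequence's memory leaves the others untouched.

```python
import numpy as np
import torch
import torch.nn as nn


class LearnedInitMemory(nn.Module):
    """Hypothetical NTMMemory variant whose initial memory is trainable."""

    def __init__(self, N, M):
        super().__init__()
        self.N = N
        self.M = M
        stdev = 1 / np.sqrt(N + M)
        # nn.Parameter (unlike register_buffer) is returned by .parameters(),
        # so an optimizer will update it during training.
        self.mem_bias = nn.Parameter(torch.empty(N, M).uniform_(-stdev, stdev))

    def initial_state(self):
        # One sequence's starting memory: shape (N, M), no batch dimension.
        return self.mem_bias.clone()

    def batched_initial_state(self, batch_size):
        # repeat() allocates new storage, giving each sequence its own copy
        # that can be written to independently: shape (batch_size, N, M).
        return self.mem_bias.repeat(batch_size, 1, 1)


mem = LearnedInitMemory(N=8, M=4)
memory = mem.batched_initial_state(batch_size=3)  # shape (3, 8, 4)
```

Using `expand()` instead of `repeat()` would avoid the copy, but it returns a view that shares storage across the batch, which is the wrong behavior when each sequence's controller writes to its own memory.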

I’ve been looking into how custom models are handled in RLlib, specifically by dissecting this example: rnn_model.py in the ray-project/ray repository on GitHub (master branch). On line 109, the hidden state is initialized:

    @override(ModelV2)
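    def get_initial_state(self):
        # One zero vector per LSTM state (h and c), each of shape
        # (lstm_state_size,) with no batch dimension. Using .new() on
        # fc1.weight gives the zeros the same dtype and device as the model.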
        h = [
            self.fc1.weight.new(1, self.lstm_state_size).zero_().squeeze(0),
            self.fc1.weight.new(1, self.lstm_state_size).zero_().squeeze(0),
        ]
        return h
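To make the shapes in that example concrete, here is a standalone reproduction of the same expression, using an `nn.Linear` as a stand-in for the example model's `self.fc1` (its sizes are arbitrary assumptions). Each returned state is a 1-D tensor of length `lstm_state_size` with no batch dimension; `fc1.weight.new(...)` is only there so the zeros share the weights' dtype and device, and `new_zeros` is an equivalent, more direct spelling.

```python
import torch
import torch.nn as nn

lstm_state_size = 256     # same role as self.lstm_state_size
fc1 = nn.Linear(10, 64)   # stand-in for self.fc1; sizes are arbitrary

# The expression from the example: allocate a (1, lstm_state_size) tensor with
# fc1.weight's dtype/device, fill it with zeros, then drop the leading dim.
h_example = fc1.weight.new(1, lstm_state_size).zero_().squeeze(0)

# Equivalent and more direct: zeros of shape (lstm_state_size,).
h_direct = fc1.weight.new_zeros((lstm_state_size,))

print(h_example.shape)                   # torch.Size([256])
print(torch.equal(h_example, h_direct))  # True
```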

The public API for ModelV2 exposes the get_initial_state() method for initializing recurrent models, but it doesn’t take a config or any kwargs, so I can’t get the batch size I need to initialize the memory module.
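One possible way around the missing batch size, sketched under an assumption worth checking against the RLlib source and its DNC example: as I understand ModelV2, get_initial_state() describes the state of a single sequence (the LSTM example above returns 1-D tensors with no batch dimension), and RLlib adds the batch dimension itself when it builds batches. If that holds, the memory model never needs batch_size at that point. `get_initial_state_sketch`, `stack_states`, and `unflatten_memory` are hypothetical helpers; `stack_states` only imitates the batching step and is not an RLlib API. Flattening the (N, M) memory to a 1-D vector is a conservative choice that matches the example's 1-D state shape.

```python
import torch

N, M = 8, 4  # hypothetical memory size: rows, columns


def get_initial_state_sketch(mem_bias):
    # State for ONE sequence: the initial memory flattened to shape (N * M,),
    # with no batch dimension, mirroring the LSTM example's 1-D states.
    return [mem_bias.detach().clone().reshape(-1)]


def stack_states(per_sequence_states):
    # Imitates the batching step: stack each state slot across sequences,
    # giving one tensor of shape (B, N * M) per slot. Not an RLlib API.
    return [torch.stack(slot) for slot in zip(*per_sequence_states)]


def unflatten_memory(flat_memory):
    # What a forward pass could do: recover (B, N, M) from (B, N * M).
    return flat_memory.reshape(flat_memory.shape[0], N, M)


mem_bias = torch.rand(N, M)
states = [get_initial_state_sketch(mem_bias) for _ in range(3)]  # 3 sequences
(batched,) = stack_states(states)
memory = unflatten_memory(batched)
print(memory.shape)  # torch.Size([3, 8, 4])
```

In a real model, the forward pass would then read and write its own (B, N, M) copy and return the updated memory, flattened again, as the next state.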

Should there be config support for get_initial_state in ModelV2?

And / or

Am I going about this the right way?

Thank you!
Gus

Hi @Michael_Gussert,

Welcome to the forum.

There is an example implementation of a DNC in RLlib that might be useful for you.

Check out these two links:

Thank you! I will take a look 🙂