I am trying to implement a custom model that uses a memory module akin to those in an NTM / DNC (an example: GitHub - loudinthecloud/pytorch-ntm: Neural Turing Machines (NTM) - PyTorch Implementation). Note how the memory state M is handled:
```python
class NTMMemory(nn.Module):
    """Memory bank for NTM."""

    def __init__(self, N, M):
        """Initialize the NTM Memory matrix.

        The memory's dimensions are (batch_size x N x M).
        Each batch has its own memory matrix.

        :param N: Number of rows in the memory.
        :param M: Number of columns/features in the memory.
        """
        super(NTMMemory, self).__init__()
        self.N = N
        self.M = M

        # The memory bias allows the heads to learn how to initially address
        # memory locations by content
        self.register_buffer('mem_bias', torch.Tensor(N, M))

        # Initialize memory bias
        stdev = 1 / (np.sqrt(N + M))
        nn.init.uniform_(self.mem_bias, -stdev, stdev)

    def reset(self, batch_size):
        """Initialize memory from bias, for start-of-sequence."""
        self.batch_size = batch_size
        self.memory = self.mem_bias.clone().repeat(batch_size, 1, 1)
```
So the module has a learned initial memory state, and to run the model on a batch we need to copy that state batch_size times. Every sequence being processed starts from the same learned initial state, but the controller operating on it will read and write different data depending on what's in that sequence, so each sequence requires its own memory state.
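To make the shape handling concrete, here is a minimal NumPy sketch of what `reset()` does (NumPy stands in for torch so it runs standalone; the N, M, and batch_size values are arbitrary examples of mine, not from the repo):

```python
import numpy as np

# Arbitrary example dimensions.
N, M, batch_size = 4, 3, 5

# One learned (N, M) bias, shared by every sequence.
mem_bias = np.random.uniform(-1.0, 1.0, size=(N, M))

# reset(batch_size) tiles the bias to (batch_size, N, M): every sequence
# starts from the same learned state but gets its own copy to mutate.
memory = np.tile(mem_bias, (batch_size, 1, 1))

print(memory.shape)                           # (5, 4, 3)
print(np.array_equal(memory[0], memory[-1]))  # True: same start state
```

The point is just that the batch dimension is created at reset time by repeating a single learned tensor, which is why the module needs to know the batch size up front.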
I've been looking into how custom models are handled in RLlib, specifically by dissecting this example: ray/rnn_model.py at master · ray-project/ray · GitHub. On line 109 the hidden state is initialized:
```python
@override(ModelV2)
def get_initial_state(self):
    h = [
        self.fc1.weight.new(1, self.lstm_state_size).zero_().squeeze(0),
        self.fc1.weight.new(1, self.lstm_state_size).zero_().squeeze(0),
    ]
    return h
```
The public API for ModelV2 exposes the get_initial_state function to initialize recurrent models, but it doesn't take any config or kwargs, so I can't get the batch size needed to initialize the memory module.
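For what it's worth, the direction I'm tempted to try (my own sketch, not RLlib code; it assumes RLlib stacks the per-sequence tensors returned by get_initial_state into a batch, the way it handles the LSTM zero-states above) is to return the learned bias itself, without a batch dimension, and let the framework add the batch axis. A NumPy stand-in for the idea:

```python
import numpy as np

# NumPy stand-in (torch and ModelV2 omitted so this runs standalone).
# mem_bias plays the role of the learned nn buffer.
N, M = 4, 3
mem_bias = np.random.uniform(-1.0, 1.0, size=(N, M))

def get_initial_state():
    # Per-sequence state only: note there is no batch dimension here,
    # matching the LSTM example, which returns (lstm_state_size,) tensors.
    return [mem_bias.copy()]

# What the framework would (conceptually) do when assembling a batch of
# 5 sequences: stack the per-sequence states along a new batch axis.
batch_size = 5
state_batch = [np.stack([get_initial_state()[0]
                         for _ in range(batch_size)])]

print(state_batch[0].shape)  # (5, 4, 3)
```

If that assumption about the batching behavior is right, the module would never need to see the batch size itself, but I haven't confirmed this is how RLlib treats the state tensors.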
Should there be config support for get_initial_state in ModelV2?

And/or: am I going about this the right way?