Hello all!
I am trying to implement a custom model that uses a memory module akin to those in an NTM / DNC (for example, loudinthecloud/pytorch-ntm, a PyTorch implementation of Neural Turing Machines on GitHub). Note how the memory state (`self.memory`) is handled:
```python
import numpy as np
import torch
import torch.nn as nn


class NTMMemory(nn.Module):
    """Memory bank for NTM."""

    def __init__(self, N, M):
        """Initialize the NTM Memory matrix.

        The memory's dimensions are (batch_size x N x M).
        Each batch has its own memory matrix.

        :param N: Number of rows in the memory.
        :param M: Number of columns/features in the memory.
        """
        super(NTMMemory, self).__init__()

        self.N = N
        self.M = M

        # The memory bias allows the heads to learn how to initially address
        # memory locations by content
        self.register_buffer('mem_bias', torch.Tensor(N, M))

        # Initialize memory bias
        stdev = 1 / (np.sqrt(N + M))
        nn.init.uniform_(self.mem_bias, -stdev, stdev)

    def reset(self, batch_size):
        """Initialize memory from bias, for start-of-sequence."""
        self.batch_size = batch_size
        # One copy of the (N x M) bias per sequence -> (batch_size x N x M)
        self.memory = self.mem_bias.clone().repeat(batch_size, 1, 1)
```
So the module has an initial hidden state, `mem_bias` (intended to be learned; note that as written it is registered with `register_buffer`, so the optimizer never updates it, and it would need to be an `nn.Parameter` to actually be trained). To run the model on a batch, we need to copy that state `batch_size` times. Every sequence being processed starts from the same initial state, but the controller reads and writes different data depending on what’s in that sequence, so each sequence needs its own memory state.
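A quick check of what `reset()` does, as a minimal sketch assuming PyTorch (the sizes and the stand-in `mem_bias` below are made up for illustration): `repeat()` materializes `batch_size` separate copies, so all sequences start out identical, but a write to one sequence’s memory affects neither the other sequences nor the bias.

```python
import torch

N, M, batch_size = 4, 3, 2

# Stand-in for NTMMemory.mem_bias (N x M), initialized the same way.
mem_bias = torch.empty(N, M)
torch.nn.init.uniform_(mem_bias, -0.1, 0.1)

# Same operation as NTMMemory.reset(): one copy of the bias per sequence.
memory = mem_bias.clone().repeat(batch_size, 1, 1)
print(memory.shape)  # torch.Size([2, 4, 3])

# Every sequence starts from the same initial state...
assert torch.equal(memory[0], memory[1])

# ...but a write made while processing sequence 0 does not leak into
# sequence 1 or back into the bias, because repeat() copies the data.
memory[0, 0] += 1.0
assert not torch.equal(memory[0], memory[1])
assert torch.equal(memory[1], mem_bias)
```

Using `expand(batch_size, N, M)` instead would avoid the copy, but the expanded view shares storage with `mem_bias`, so it would not give each sequence its own writable memory.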
I’ve been looking into how custom models are handled in RLlib, specifically by dissecting the `rnn_model.py` example from the ray-project/ray repository on GitHub (master branch). On line 109 of that file, the hidden state is initialized:
```python
@override(ModelV2)
def get_initial_state(self):
    h = [
        self.fc1.weight.new(1, self.lstm_state_size).zero_().squeeze(0),
        self.fc1.weight.new(1, self.lstm_state_size).zero_().squeeze(0),
    ]
    return h
```
The public `ModelV2` API exposes `get_initial_state()` to initialize the state of recurrent models, but it takes no config or kwargs, so I can’t get the batch size I need to initialize the memory module.
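For comparison, here is a small sketch (assuming PyTorch; `fc1`, the sizes, and `get_initial_state_ntm` are hypothetical stand-ins) of what the example’s `get_initial_state` actually returns. Because of the `squeeze(0)`, each state tensor has no batch dimension, which suggests the method describes the state of a single sequence and the batch dimension is added elsewhere. The `torch.stack` step below only imitates that batching under this assumption; it is not RLlib code.

```python
import torch
import torch.nn as nn

lstm_state_size = 8
fc1 = nn.Linear(5, 16)  # hypothetical stand-in for self.fc1 in the example

# What the example's get_initial_state() builds: each tensor has shape
# (lstm_state_size,), i.e. one sequence's state with no batch dimension.
h = [
    fc1.weight.new(1, lstm_state_size).zero_().squeeze(0),
    fc1.weight.new(1, lstm_state_size).zero_().squeeze(0),
]
print([t.shape for t in h])  # [torch.Size([8]), torch.Size([8])]

# Hypothetical analogue for the NTM memory: return the (N x M) bias
# itself, without repeating it batch_size times.
N, M = 4, 3
mem_bias = torch.zeros(N, M)

def get_initial_state_ntm():
    return [mem_bias.clone()]

# Imitation (NOT RLlib code) of batching the initial states of
# batch_size independent sequences along a new leading dimension.
batch_size = 6
per_sequence = [get_initial_state_ntm() for _ in range(batch_size)]
batched = [torch.stack(slot) for slot in zip(*per_sequence)]
print(batched[0].shape)  # torch.Size([6, 4, 3])
```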
Should `get_initial_state` in `ModelV2` accept a config (for example, so it can see the batch size)?
And / or
Am I going about this the right way?
Thank you!
Gus