Hi @mannyv, thanks for your quick response!
I have an idea, and I'm just about to run my custom model to see if it works.
Would it be OK if I just concat my value_hidden_states with my action_hidden_states and return them together from each forward_rnn() call?
More concretely:
# Two separate LSTMs with two separate branches: initial (h, c) for each.
@override(ModelV2)
def get_initial_state(self):
    return [
        self.actor_layers[-1]._model[0].weight.new(1, self.lstm_state_size).zero_().squeeze(0),
        self.actor_layers[-1]._model[0].weight.new(1, self.lstm_state_size).zero_().squeeze(0),
        self.value_layers[-1]._model[0].weight.new(1, self.lstm_state_size).zero_().squeeze(0),
        self.value_layers[-1]._model[0].weight.new(1, self.lstm_state_size).zero_().squeeze(0)
    ]
@override(ModelV2)
def value_function(self):
    assert self._values is not None, "must call forward() first"
    return torch.reshape(self.value_branch(self._values), [-1])
@override(TorchRNN)
def forward_rnn(self, inputs, state, seq_lens):
    # Actor branch: its own LSTM, fed state[0] / state[1].
    self._features, [h1, c1] = self.actor_lstm(
        self.actor_layers(inputs),
        [torch.unsqueeze(state[0], 0), torch.unsqueeze(state[1], 0)])
    action_out = self.action_branch(self._features)
    # Value branch: a second LSTM, fed state[2] / state[3].
    self._values, [h2, c2] = self.value_lstm(
        self.value_layers(inputs),
        [torch.unsqueeze(state[2], 0), torch.unsqueeze(state[3], 0)])
    # Return both LSTMs' hidden states together in one state list.
    return action_out, [
        torch.squeeze(h1, 0), torch.squeeze(c1, 0),
        torch.squeeze(h2, 0), torch.squeeze(c2, 0)
    ]
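In case it matters, here is roughly how the pieces referenced above are wired up; this is a simplified sketch of my __init__ (the class name, fc_size and lstm_state_size are placeholders for what I actually use, and SlimFC is RLlib's ray.rllib.models.torch.misc.SlimFC), mirroring RLlib's TorchRNN example:

import numpy as np
import torch
from torch import nn
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork as TorchRNN
from ray.rllib.utils.annotations import override


class TwoLSTMModel(TorchRNN, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, fc_size=64, lstm_state_size=256):
        nn.Module.__init__(self)
        super().__init__(obs_space, action_space, num_outputs, model_config, name)

        self.obs_size = int(np.product(obs_space.shape))
        self.fc_size = fc_size
        self.lstm_state_size = lstm_state_size

        # Actor branch: FC stack -> its own LSTM -> action logits.
        self.actor_layers = nn.Sequential(
            SlimFC(self.obs_size, self.fc_size, activation_fn=nn.Tanh))
        self.actor_lstm = nn.LSTM(
            self.fc_size, self.lstm_state_size, batch_first=True)
        self.action_branch = SlimFC(self.lstm_state_size, num_outputs)

        # Value branch: separate FC stack -> separate LSTM -> scalar value.
        self.value_layers = nn.Sequential(
            SlimFC(self.obs_size, self.fc_size, activation_fn=nn.Tanh))
        self.value_lstm = nn.LSTM(
            self.fc_size, self.lstm_state_size, batch_first=True)
        self.value_branch = SlimFC(self.lstm_state_size, 1)

        self._features = None
        self._values = None

    # get_initial_state(), forward_rnn() and value_function() from above
    # go here.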
If this doesn't work out, should I change my model's view_requirements or anything else?
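For example, would something along these lines (added at the end of __init__) be the right direction? I'm only guessing at the API here, using ViewRequirement from ray.rllib.policy.view_requirement, in case RLlib doesn't infer the two extra state tensors from get_initial_state() on its own:

import numpy as np
from gym.spaces import Box
from ray.rllib.policy.view_requirement import ViewRequirement

# Guess: explicitly declare view requirements for the value LSTM's extra
# state tensors (indices 2 and 3). state_in_{i} at time t would be read
# from state_out_{i} at t-1, hence shift=-1.
state_space = Box(-1.0, 1.0, shape=(self.lstm_state_size,), dtype=np.float32)
for i in (2, 3):
    self.view_requirements[f"state_in_{i}"] = ViewRequirement(
        data_col=f"state_out_{i}", shift=-1, space=state_space)
    self.view_requirements[f"state_out_{i}"] = ViewRequirement(space=state_space)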