Best way to have custom value state + LSTM

Hi @mannyv, thanks for your quick response!

I have an idea, and I'm just about to run my custom model to see whether it works.
Would it be OK if I just concatenate my value LSTM's hidden states with the actor LSTM's hidden states and return them together from each forward_rnn call?

More concretely:

# Two separate LSTMs with two separate branches...

    @override(ModelV2)
    def get_initial_state(self):
        # Zeroed initial hidden/cell states for both LSTMs: [h1, c1, h2, c2].
        return [
            self.actor_layers[-1]._model[0].weight.new(1, self.lstm_state_size).zero_().squeeze(0),
            self.actor_layers[-1]._model[0].weight.new(1, self.lstm_state_size).zero_().squeeze(0),
            self.value_layers[-1]._model[0].weight.new(1, self.lstm_state_size).zero_().squeeze(0),
            self.value_layers[-1]._model[0].weight.new(1, self.lstm_state_size).zero_().squeeze(0),
        ]

    @override(ModelV2)
    def value_function(self):
        assert self._values is not None, "must call forward() first"
        return torch.reshape(self.value_branch(self._values), [-1])

    @override(TorchRNN)
    def forward_rnn(self, inputs, state, seq_lens):
        # Actor branch: its own encoder + LSTM; state[0:2] are (h1, c1).
        self._features, [h1, c1] = self.actor_lstm(
            self.actor_layers(inputs),
            [torch.unsqueeze(state[0], 0), torch.unsqueeze(state[1], 0)])
        action_out = self.action_branch(self._features)

        # Value branch: its own encoder + LSTM; state[2:4] are (h2, c2).
        self._values, [h2, c2] = self.value_lstm(
            self.value_layers(inputs),
            [torch.unsqueeze(state[2], 0), torch.unsqueeze(state[3], 0)])

        # Return the states of both LSTMs so RLlib carries all four forward.
        return action_out, [
            torch.squeeze(h1, 0), torch.squeeze(c1, 0),
            torch.squeeze(h2, 0), torch.squeeze(c2, 0),
        ]
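
For completeness, the attributes referenced above (actor_layers, actor_lstm, action_branch, value_layers, value_lstm, value_branch) are wired up along these lines. This is only a minimal sketch following the pattern of RLlib's torch RNN example; the class name, SlimFC encoders, and the fc/LSTM sizes are stand-ins for my actual config:

    import numpy as np
    import torch.nn as nn

    from ray.rllib.models.modelv2 import ModelV2
    from ray.rllib.models.torch.misc import SlimFC
    from ray.rllib.models.torch.recurrent_net import RecurrentNetwork as TorchRNN
    from ray.rllib.utils.annotations import override


    class TwoLSTMModel(TorchRNN, nn.Module):
        """Sketch: actor and value branches, each with its own encoder + LSTM.

        The get_initial_state / value_function / forward_rnn methods shown
        above live on this class.
        """

        def __init__(self, obs_space, action_space, num_outputs, model_config,
                     name, fc_size=64, lstm_state_size=256):
            nn.Module.__init__(self)
            super().__init__(obs_space, action_space, num_outputs,
                             model_config, name)
            self.obs_size = int(np.prod(obs_space.shape))
            self.lstm_state_size = lstm_state_size

            # Separate pre-LSTM encoders; SlimFC is what makes the
            # ._model[0].weight access in get_initial_state() work.
            self.actor_layers = nn.Sequential(
                SlimFC(self.obs_size, fc_size, activation_fn="tanh"))
            self.value_layers = nn.Sequential(
                SlimFC(self.obs_size, fc_size, activation_fn="tanh"))

            # Two independent LSTMs, one per branch.
            self.actor_lstm = nn.LSTM(fc_size, lstm_state_size,
                                      batch_first=True)
            self.value_lstm = nn.LSTM(fc_size, lstm_state_size,
                                      batch_first=True)

            # Output heads.
            self.action_branch = nn.Linear(lstm_state_size, num_outputs)
            self.value_branch = nn.Linear(lstm_state_size, 1)

            self._features = None
            self._values = None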

If this doesn't work out, should I change my view_requirements or anything else?
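
Before touching view_requirements, this is the kind of standalone smoke test I plan to run first, just to confirm all four state tensors round-trip with the right shapes. Everything here (the spaces, sizes, and max_seq_len) is made up purely to exercise the shape plumbing, building on the TwoLSTMModel sketch above:

    import gym
    import torch

    # Made-up spaces/config, only to exercise the state plumbing outside RLlib.
    obs_space = gym.spaces.Box(-1.0, 1.0, shape=(8,))
    action_space = gym.spaces.Discrete(4)
    model = TwoLSTMModel(obs_space, action_space, num_outputs=4,
                         model_config={"max_seq_len": 20}, name="two_lstm")

    B, T = 2, 5
    init_state = model.get_initial_state()
    state = [s.unsqueeze(0).repeat(B, 1) for s in init_state]  # batch the states
    logits, state_out = model.forward_rnn(
        torch.zeros(B, T, model.obs_size), state, seq_lens=None)

    # All four states (h1, c1, h2, c2) should come back, one [B, H] tensor each.
    assert len(state_out) == len(init_state) == 4
    assert all(s.shape == (B, model.lstm_state_size) for s in state_out)
    print(model.value_function().shape)  # torch.Size([B * T])

My (unverified) understanding is that RLlib derives the state_in_*/state_out_* view requirements from whatever get_initial_state() returns, so with four state tensors I'm hoping no manual view_requirements changes are needed.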