I’d like to try implementing sparse recurrent states in RLlib. Since recurrent states are zero-padded to the longest sequence in the batch, representing them with sparse matrices (e.g. `scipy.sparse.coo_matrix`) should greatly reduce memory usage and increase speed. Could someone point me to where recurrent states are handled in the code?
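To illustrate the memory argument, here's a small sketch (the batch and hidden sizes are made up for illustration): a zero-padded batch where only a few rows hold real state stores far fewer bytes in COO form than as a dense array.

```python
import numpy as np
from scipy.sparse import coo_matrix

# Hypothetical padded state batch: 32 sequences x hidden size 256,
# but only the first 4 rows contain real (nonzero) state.
dense = np.zeros((32, 256), dtype=np.float32)
dense[:4] = np.random.randn(4, 256).astype(np.float32)

sparse = coo_matrix(dense)
dense_bytes = dense.nbytes
sparse_bytes = sparse.data.nbytes + sparse.row.nbytes + sparse.col.nbytes
print(dense_bytes, sparse_bytes)  # COO stores only the nonzero entries + coordinates
```

The savings scale with the padding ratio, so batches with a few long sequences and many short ones benefit the most.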
EDIT: I think the easiest way to do this might be to use two variable-length state vectors for the indices and values (COO format); however, RLlib does not support a variable state shape.
```python
class SparseRNNModel(TorchModelV2):
    ...
    def get_initial_state(self):
        # Empty, zero-length tensors representing a sparse tensor in COO format
        idxs = torch.zeros(0)
        values = torch.zeros(0)
        return [idxs, values]

    def forward(self, input_dict, state, seq_lens):
        idxs, values = state
        sparse_state = torch.sparse_coo_tensor(idxs, values)
        # Forward thru a sparse version of an LSTM:
        # operations on zeroed blocks (e.g. multiply, tanh, etc.) are not executed
        output, sparse_state = self.sparse_lstm(input_dict["obs_flat"], sparse_state)
        # sparse_state has changed shape, as idxs and values have grown in size
        # -- RLlib can't handle the state changing size
        state = [sparse_state.indices(), sparse_state.values()]
        return logits, state
```
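The shape problem can be seen without any RLlib machinery. Below is a NumPy stand-in (all names illustrative) for the COO pack/unpack round-trip the model state would need: the round-trip itself works, but as soon as more entries of the hidden state become nonzero, the flat index/value vectors change length, which is exactly the variable state shape RLlib forbids.

```python
import numpy as np

def to_coo(dense):
    """Pack a dense array into COO-format (idxs, values) vectors."""
    idxs = np.argwhere(dense != 0)   # (nnz, ndim) coordinates
    values = dense[dense != 0]       # (nnz,) nonzero entries
    return idxs, values

def from_coo(idxs, values, shape):
    """Rebuild the dense array from its COO representation."""
    dense = np.zeros(shape, dtype=values.dtype)
    dense[tuple(idxs.T)] = values
    return dense

h = np.zeros((8, 4), dtype=np.float32)
h[0] = 1.0                           # one "real" row of state
idxs, values = to_coo(h)
assert np.array_equal(from_coo(idxs, values, h.shape), h)  # round-trip is lossless

# After a forward pass, more entries become nonzero, so the state
# vectors returned from forward() would have a different length.
h[1] = 2.0
idxs2, values2 = to_coo(h)
print(len(values), len(values2))     # lengths differ between steps
```

This is why the approach seems to require either padding the index/value vectors to a fixed maximum nnz, or support in RLlib for variable-shaped states.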