Sparse recurrent state

I’d like to try implementing sparse recurrent states in rllib. Since recurrent states are zero-padded to the longest sequence, representing them using sparse matrices (scipy.sparse.coo_matrix) should greatly reduce memory usage and increase speed. Would someone be able to point me to where recurrent states are handled in the code?

EDIT: I think the easiest way to do this might be to use two variable-length state vectors for the indices and values (COO format); however, rllib does not support a variable state shape.

import torch
from torch import nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2


class SparseRNNModel(TorchModelV2, nn.Module):
  ...
  def get_initial_state(self):
    # Empty, zero-length tensors representing a sparse tensor in COO format:
    # a (ndim x nnz) integer index tensor and an nnz-length value tensor
    # (a 2-D state is assumed here)
    idxs = torch.zeros((2, 0), dtype=torch.long)
    values = torch.zeros(0)
    return [idxs, values]

  def forward(self, input_dict, state, seq_lens):
    idxs, values = state
    sparse_state = torch.sparse_coo_tensor(idxs, values)
    # Forward thru a sparse version of an LSTM
    # Operations on zero'd blocks (e.g. multiply, tanh, etc.) are not executed
    output, sparse_state = self.sparse_lstm(input_dict["obs_flat"], sparse_state)
    # sparse_state has changed shape, as idxs and values have grown in size;
    # rllib can't handle the state changing size
    sparse_state = sparse_state.coalesce()  # so .indices()/.values() are well-defined
    state = [sparse_state.indices(), sparse_state.values()]
    # (logits would come from the model's output head, omitted in this sketch)
    return logits, state
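
For reference, here is roughly the COO round trip the sketch above relies on, in plain PyTorch (no rllib involved):

import torch

# A zero-padded recurrent state: only a few entries are non-zero
dense_state = torch.zeros(4, 8)
dense_state[0, 1] = 0.5
dense_state[2, 3] = -1.2

# Decompose into the two variable-length COO components
sparse_state = dense_state.to_sparse().coalesce()
idxs, values = sparse_state.indices(), sparse_state.values()  # shapes (2, nnz) and (nnz,)

# Rebuild the sparse tensor (and the dense one, if needed) from those components
rebuilt = torch.sparse_coo_tensor(idxs, values, dense_state.shape)
assert torch.equal(rebuilt.to_dense(), dense_state)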

Hi @smorad,

I agree that it is time to deal with the padded states. I personally think that it would be more straightforward to just remove the padding of the states.

Is there some benefit you see to switching to a sparse representation beyond just dealing with the padded zeros?

I would figure that the non-padded component of the states is unlikely to be sparse, right? If that is the case, then I would think that just removing the padding would be a more straightforward way to deal with the issue.

My recurrent state is embedded in a graph with something like ~10% edge density. So a sparse representation should reduce my memory usage by 90%. But to be honest, I can see a lot of potential use cases where the state variable would benefit from not being converted into a dense numpy array, such as a map representation in a navigation problem, a memory object in an SDNC, a tensor of LongTensors (rllib converts all state tensors to torch.float), a dict, etc. Maybe it makes sense to serialise the state instead of converting it into a numpy.ndarray.
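
For a rough sense of the numbers, here is a toy 2D stand-in for that kind of state (sizes picked arbitrarily):

import torch

# Toy stand-in for a graph-structured state: ~10% of entries are non-zero
n = 1000
mask = torch.rand(n, n) < 0.10
dense = torch.randn(n, n) * mask
sparse = dense.to_sparse()

print(dense.numel(), sparse.values().numel())  # ~10x fewer stored values
dense_bytes = dense.numel() * dense.element_size()
sparse_bytes = (sparse.values().numel() * sparse.values().element_size()
                + sparse.indices().numel() * sparse.indices().element_size())
print(dense_bytes, sparse_bytes)  # actual saving depends on the int64 index overhead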

With regards to removing padding: I’m not sure there are any “great” solutions. AFAIK it is not easy to represent/operate on variable-sized tensors, hence the need for padding. You could try something like torch.NestedTensor, which allows for variable-size dense tensors (e.g. shape = (B, T, feat) where T is a varying number of timesteps in batch B). However, this does not have torch GPU or autograd support, and is missing some basic operations. A numpy masked array requires 2x the memory, as it needs to hold the mask in memory as well. Sparse matrices have their own issues, namely reduced support for more advanced operations. But they are more feature-complete than NestedTensor, use much less memory than MaskedArray, and are relatively computationally efficient.
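
To make the padding point concrete, here is a toy zero-padded batch (shapes made up) and what a sparse view of it actually needs to store:

import torch

# Batch of 3 sequences with lengths 2, 5 and 1 and a 4-dim state each,
# zero-padded to the longest sequence (T = 5)
seq_lens = [2, 5, 1]
B, T, feat = len(seq_lens), max(seq_lens), 4
padded = torch.zeros(B, T, feat)
for b, length in enumerate(seq_lens):
  padded[b, :length] = torch.randn(length, feat)

# Only 8 of the 15 (B * T) timesteps carry data; the rest is padding
sparse = padded.to_sparse()
print(padded.numel(), sparse.values().numel())  # 60 dense entries vs. 32 stored values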

@smorad good points. I guess I was considering your suggestion from the perspective of models and algorithms that currently exist in rllib but I see your point from the perspective of custom/future agents or models that will be created.

I think sparse matrices are perhaps the best way to deal with padding, even in the currently-existing rllib algorithms. Rllib is constantly shipping SampleBatches over the network or between processes, so I think it makes sense to compress them if possible. Sparse matrices can always be quickly reconstructed into dense matrices if needed.
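
For example (just a sketch, array sizes made up), the dense/sparse round trip with scipy is cheap:

import numpy as np
from scipy.sparse import coo_matrix

# A mostly-padding 2D state array as it might appear in a SampleBatch
state = np.zeros((256, 64), dtype=np.float32)
state[:16] = np.random.randn(16, 64)

# Compress to COO before shipping it between processes...
compressed = coo_matrix(state)

# ...and reconstruct the dense array on the receiving side
restored = compressed.toarray()
assert np.array_equal(restored, state)
print(state.nbytes, compressed.data.nbytes + compressed.row.nbytes + compressed.col.nbytes)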