Best way to have custom value state + LSTM

Hi,

I’m doing some training using PPO, and I would like the value function to have additional states that the policy doesn’t have.

By default, the FullyConnectedNetwork looks like this (the 14 should be a 7 here):

[image: default FullyConnectedNetwork architecture]

I slightly modified it by splitting the input layer, so that I can give additional states to the value function's input. I have a state space of size 14 and split it in two, passing the first 7 to policy_obs and the last 7 to value_obs.
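In case it helps, here's a minimal sketch of what that split looks like (the class and layer names are illustrative, not my exact code; it assumes a flat 14-dim obs split 7/7, with the value branch seeing only its own half):

```python
import numpy as np

from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils.framework import try_import_tf

tf1, tf, tfv = try_import_tf()


class SplitInputModel(TFModelV2):
    """Policy head sees the first half of the obs, value head the second."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)
        obs_dim = int(np.prod(obs_space.shape))  # 14 in my case
        half = obs_dim // 2  # first 7 dims -> policy, last 7 -> value

        inputs = tf.keras.layers.Input(shape=(obs_dim,), name="observations")
        policy_obs = inputs[:, :half]
        value_obs = inputs[:, half:]

        # Policy branch.
        pi_hidden = tf.keras.layers.Dense(
            256, activation="tanh", name="fc_policy_1")(policy_obs)
        logits = tf.keras.layers.Dense(
            num_outputs, activation=None, name="fc_out")(pi_hidden)

        # Value branch, fed only the extra (value-only) states. You could
        # instead concatenate both halves here if the VF should see all 14.
        vf_hidden = tf.keras.layers.Dense(
            256, activation="tanh", name="fc_value_1")(value_obs)
        value_out = tf.keras.layers.Dense(
            1, activation=None, name="value_out")(vf_hidden)

        self.base_model = tf.keras.Model(inputs, [logits, value_out])
        self.register_variables(self.base_model.variables)

    def forward(self, input_dict, state, seq_lens):
        logits, self._value_out = self.base_model(input_dict["obs_flat"])
        return logits, state

    def value_function(self):
        return tf.reshape(self._value_out, [-1])
```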

This seems to work fine; however, I'd like to add an LSTM on top of it, and the use_lstm wrapper doesn't work by default (see the error below). I was wondering if there is an easier way to do this (giving additional states to the VF) that I didn't find? That would save me from having to add an LSTM manually to my custom net. Or did I miss something with use_lstm?

ValueError: Input 0 of layer fc_value_1 is incompatible with the layer: expected axis -1 of input shape to have value 14 but received input with shape [None, 7]

Thanks!

Hi @nathanlct,

The value function looks different when using the LSTM wrapper than when not.

When you use the LSTM wrapper, there is only one set of layers going into the LSTM model, and the value layers use the final state coming from the LSTM as their input. So instead of two heads going into the LSTM, you have two heads coming out of it. You can see images I created of the two cases in this post: What is the intended architecture of PPO vf_share_layers=False when using an LSTM.

I think this means you are going to have to add your own LSTM. Luckily, this is pretty straightforward with RLlib. You do have to think about how and where you want to do this separation, though:

- Are you going to add a layer before the LSTM and then feed the LSTM a concatenation of a policy and a value embedding layer?
- Are you going to treat the extra states as pass-through inputs and feed them in after the LSTM?
- Are you going to have two LSTMs?
- Will only your policy use an LSTM, with the value function being just FC layers?
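For example, here is a rough sketch of the pass-through option, modeled on RLlib's rnn_model.py example (the SplitInputLSTM name, the 7/7 split, and the cell_size are all illustrative, not a definitive implementation):

```python
import numpy as np

from ray.rllib.models.tf.recurrent_net import RecurrentNetwork
from ray.rllib.utils.framework import try_import_tf

tf1, tf, tfv = try_import_tf()


class SplitInputLSTM(RecurrentNetwork):
    """The LSTM only sees the policy half of the obs; the value-only
    half bypasses the LSTM and is concatenated into the value head."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, cell_size=64):
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)
        self.cell_size = cell_size
        obs_dim = int(np.prod(obs_space.shape))  # e.g. 14
        half = obs_dim // 2  # 7 policy dims / 7 value dims

        # Inputs already carry a time axis inside forward_rnn.
        inputs = tf.keras.layers.Input(shape=(None, obs_dim), name="inputs")
        state_in_h = tf.keras.layers.Input(shape=(cell_size,), name="h")
        state_in_c = tf.keras.layers.Input(shape=(cell_size,), name="c")
        seq_in = tf.keras.layers.Input(shape=(), name="seq_in", dtype=tf.int32)

        policy_obs = inputs[:, :, :half]
        value_obs = inputs[:, :, half:]

        lstm_out, state_h, state_c = tf.keras.layers.LSTM(
            cell_size, return_sequences=True, return_state=True,
            name="lstm")(
                inputs=policy_obs,
                mask=tf.sequence_mask(seq_in),
                initial_state=[state_in_h, state_in_c])

        logits = tf.keras.layers.Dense(
            num_outputs, activation=None, name="logits")(lstm_out)

        # Value head: LSTM features plus the pass-through value obs.
        vf_in = tf.keras.layers.Concatenate(axis=-1)([lstm_out, value_obs])
        values = tf.keras.layers.Dense(
            1, activation=None, name="values")(vf_in)

        self.rnn_model = tf.keras.Model(
            inputs=[inputs, seq_in, state_in_h, state_in_c],
            outputs=[logits, values, state_h, state_c])
        self.register_variables(self.rnn_model.variables)

    def forward_rnn(self, inputs, state, seq_lens):
        logits, self._value_out, h, c = self.rnn_model(
            [inputs, seq_lens] + state)
        return logits, [h, c]

    def get_initial_state(self):
        return [
            np.zeros(self.cell_size, np.float32),
            np.zeros(self.cell_size, np.float32),
        ]

    def value_function(self):
        return tf.reshape(self._value_out, [-1])
```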


It’s quite straightforward, but you will have to implement your own model. See ray/rnn_model.py at 4795048f1b3779658e8b0ffaa05b1eb61914bc60 · ray-project/ray · GitHub
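Once you have a recurrent model class like the SplitInputLSTM sketch above, plugging it into PPO looks roughly like this (the registered name is just an example; custom_model and max_seq_len are standard RLlib model options):

```python
import ray
from ray import tune
from ray.rllib.models import ModelCatalog

# Register the custom model under a name usable in the config.
ModelCatalog.register_custom_model("split_input_lstm", SplitInputLSTM)

ray.init()
tune.run(
    "PPO",
    config={
        "env": "CartPole-v0",  # placeholder; use your own 14-dim-obs env
        "model": {
            "custom_model": "split_input_lstm",
            # Don't also set use_lstm: the model handles recurrence itself.
            "max_seq_len": 20,
        },
    },
)
```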
