Get value function values from IMPALA

Hi, I’m trying to modify the IMPALA architecture so that it does not use the default NN policy but a custom (simpler) one: a softmax over some handmade features.

Because this policy is very simple and I can update it without PyTorch, the easiest way I’ve found to do it (tell me if I’m wrong) is to compute this new policy in a custom callback, inside the on_learn_on_batch method, where I have almost everything I need to update the custom policy: actions, rewards, and observations.
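For concreteness, here is a minimal, framework-free sketch of what such a "simple" policy could look like: a softmax over handmade features. The feature map, `theta`, and all names below are illustrative assumptions, not RLlib API.

```python
import math

def softmax_policy(obs, actions, theta, features):
    """Return pi(a|obs) = softmax over a of theta . phi(obs, a)."""
    logits = [sum(t * f for t, f in zip(theta, features(obs, a))) for a in actions]
    m = max(logits)                       # subtract max for numerical stability
    exps = [math.exp(l - m) for l in logits]
    z = sum(exps)
    return [e / z for e in exps]

# Toy handmade features: one-hot action indicator plus the (scalar) observation.
def toy_features(obs, a):
    return [1.0 if a == i else 0.0 for i in range(3)] + [obs]

probs = softmax_policy(0.5, [0, 1, 2], [0.2, 0.0, -0.2, 0.1], toy_features)
```

Since the parameters are just a small list of floats, updating them inside a callback (outside the torch graph) is straightforward.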

The only thing I’m missing is the actual output of the value-function branch for each of these observations.

In short, I plan to ignore the policy branch of IMPALA’s neural network, but use its value-function branch to compute a simpler policy (no NN). Diagram below:

From this:

To this:

And update it after N iterations, as in the IMPALA learner:
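A hedged sketch of what that periodic update could look like: a REINFORCE-style gradient step in which the value-function output v(s) from IMPALA’s critic branch serves as a baseline. `theta`, `alpha`, the feature map, and the batch layout are all illustrative assumptions, not RLlib code.

```python
import math

def features(obs, a):          # handmade features: one-hot over 2 actions
    return [1.0, 0.0] if a == 0 else [0.0, 1.0]

def pi(obs, theta):            # softmax policy over the two actions
    logits = [sum(t * f for t, f in zip(theta, features(obs, a))) for a in (0, 1)]
    m = max(logits)
    e = [math.exp(l - m) for l in logits]
    z = sum(e)
    return [x / z for x in e]

def update(theta, batch, alpha=0.1):
    """theta += alpha * mean over batch of (r - v) * grad log pi(a|s)."""
    grad = [0.0] * len(theta)
    for obs, a, r, v in batch:                     # paired samples + VF baseline
        probs = pi(obs, theta)
        phi_a = features(obs, a)
        # grad log pi(a|s) = phi(s, a) - E_b~pi [phi(s, b)]
        exp_phi = [probs[0] * f0 + probs[1] * f1
                   for f0, f1 in zip(features(obs, 0), features(obs, 1))]
        for i in range(len(theta)):
            grad[i] += (r - v) * (phi_a[i] - exp_phi[i])
    return [t + alpha * g / len(batch) for t, g in zip(theta, grad)]

# Action 0 earned a reward above the baseline, so its weight should go up.
theta = update([0.0, 0.0], [(None, 0, 1.0, 0.2)])
```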

Is this approach correct? And if so,

Are these VF values saved anywhere?

Or do I need to save them manually at some point in the training process (for example, inside on_episode_step in my custom callback)?

The value of the current state is indeed saved. In fact, there is a public method, value_function(), in ModelV2 that returns it.

I understand that the value function is used inside the loss computation of the IMPALA trainer, but whether these values can be accessed from outside it at the time a batch of observations is being processed is not clear to me.

I would like to use the value_function() method to get these values manually, but I couldn’t find a way to save them at inference time: they are computed automatically after forward() is called, and I could not recompute them afterwards.

Edit: one thing I didn’t mention explicitly is that I need the value-function values of the observations in the batch being learned from, paired in order (just like obs, action, and reward, which are sampled with that pairing in time preserved).
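For readers unfamiliar with the forward()/value_function() interplay mentioned above, here is a stripped-down, framework-free illustration of the caching pattern (this is not actual RLlib code; the toy linear branches are assumptions):

```python
class TinyModel:
    """Mimics the ModelV2 pattern: forward() caches the critic output,
    value_function() returns it for the most recent batch."""

    def __init__(self, w_policy, w_value):
        self.w_policy = w_policy
        self.w_value = w_value
        self._value_out = None            # filled in by forward()

    def forward(self, obs_batch):
        # Policy branch: logits per observation (toy linear model).
        logits = [[o * w for w in self.w_policy] for o in obs_batch]
        # Value branch: cached so value_function() can return it later,
        # paired in order with the observations of this batch.
        self._value_out = [o * self.w_value for o in obs_batch]
        return logits

    def value_function(self):
        assert self._value_out is not None, "call forward() first"
        return self._value_out

model = TinyModel(w_policy=[1.0, -1.0], w_value=0.5)
model.forward([1.0, 2.0, 3.0])
values = model.value_function()   # one value per observation, same order
```

The key point is that the values are only valid for the batch passed to the last forward() call, which is why grabbing them at the wrong time loses the obs/value pairing.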

I think you may be interested in the policy._value() method that is added when policies use the ValueNetworkMixin. You can find the code for it here:

Thanks @mannyv. I found another way around it, but your info was useful for testing alternatives.

Another approach, which I think is more general and cleaner, is to make a custom trainable using the IMPALA one as a base, and compute the update of the simple policy I’m trying to implement there.

That way, I have access to the value-function values as used in the original training process, without running additional passes through the network.

The only drawback is how to pass the computed parameters of the simple policy to the workers. A solution I found was to save them to a PyTorch file while learning, and load that file inside the workers’ network at inference so that they always use the most up-to-date parameters.
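A stdlib-only sketch of that "pass parameters through a file" idea (the post uses a PyTorch file via torch.save/torch.load; json is used here purely to keep the example dependency-free, and the path is an illustrative assumption):

```python
import json
import os
import tempfile

# Hypothetical location shared by the trainable (writer) and workers (readers).
PARAMS_PATH = os.path.join(tempfile.gettempdir(), "simple_policy_params.json")

def save_params(theta, path=PARAMS_PATH):
    # Write to a temp file, then atomically rename it into place, so a
    # worker reading concurrently never sees a half-written file.
    tmp = path + ".tmp"
    with open(tmp, "w") as f:
        json.dump(theta, f)
    os.replace(tmp, path)

def load_params(path=PARAMS_PATH, default=None):
    try:
        with open(path) as f:
            return json.load(f)
    except FileNotFoundError:
        return default            # worker starts before first save

save_params([0.2, -0.1, 0.05])
theta = load_params()
```

The atomic os.replace is the one design choice worth keeping even with torch.save: each worker then sees either the old parameters or the new ones, never a partial write.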

Maybe there is a better way to pass these parameters around, like saving them in the trainable, accessing them from the simulator environment, and passing them to the network as observations. Not sure if this is possible, but I will try.

Thank you for your replies again.