Get value function values from IMPALA

Hi, I’m trying to modify the IMPALA architecture so that it does not use the default NN policy but a custom (simpler) one: a softmax over some handmade features.

Because this policy is very simple and I can update it without PyTorch, the easiest way I’ve found to do it (tell me if I’m wrong) is to compute this new policy in a custom callback, inside the on_learn_on_batch method, where I have almost everything I need to update the custom policy: actions, rewards, and observations.
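For concreteness, here is a minimal, framework-free sketch of what such a "simple" policy could look like: a softmax over handmade features. The feature map, `theta`, and all names below are illustrative assumptions, not RLlib API.

```python
import math

def softmax_policy(obs, actions, theta, features):
    """Return pi(a|obs) = softmax over a of theta . phi(obs, a)."""
    logits = [sum(t * f for t, f in zip(theta, features(obs, a))) for a in actions]
    m = max(logits)                       # subtract max for numerical stability
    exps = [math.exp(l - m) for l in logits]
    z = sum(exps)
    return [e / z for e in exps]

# Toy handmade features: one-hot action indicator plus the (scalar) observation.
def toy_features(obs, a):
    return [1.0 if a == i else 0.0 for i in range(3)] + [obs]

probs = softmax_policy(0.5, [0, 1, 2], [0.2, 0.0, -0.2, 0.1], toy_features)
```

Since the parameters are just a small list of floats, updating them inside a callback (outside the torch graph) is straightforward.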

The only thing I’m missing is the actual output of the value-function branch for each of these observations.

In short, I plan to ignore the policy branch of IMPALA’s neural network, but use its value-function branch to compute a simpler policy (no NN). Diagram below:

From this:

To this:

And update it after N iterations, as in the IMPALA learner:
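A hedged sketch of what that periodic update could look like: a REINFORCE-style gradient step in which the value-function output v(s) from IMPALA’s critic branch serves as a baseline. `theta`, `alpha`, the feature map, and the batch layout are all illustrative assumptions, not RLlib code.

```python
import math

def features(obs, a):          # handmade features: one-hot over 2 actions
    return [1.0, 0.0] if a == 0 else [0.0, 1.0]

def pi(obs, theta):            # softmax policy over the two actions
    logits = [sum(t * f for t, f in zip(theta, features(obs, a))) for a in (0, 1)]
    m = max(logits)
    e = [math.exp(l - m) for l in logits]
    z = sum(e)
    return [x / z for x in e]

def update(theta, batch, alpha=0.1):
    """theta += alpha * mean over batch of (r - v) * grad log pi(a|s)."""
    grad = [0.0] * len(theta)
    for obs, a, r, v in batch:                     # paired samples + VF baseline
        probs = pi(obs, theta)
        phi_a = features(obs, a)
        # grad log pi(a|s) = phi(s, a) - E_b~pi [phi(s, b)]
        exp_phi = [probs[0] * f0 + probs[1] * f1
                   for f0, f1 in zip(features(obs, 0), features(obs, 1))]
        for i in range(len(theta)):
            grad[i] += (r - v) * (phi_a[i] - exp_phi[i])
    return [t + alpha * g / len(batch) for t, g in zip(theta, grad)]

# Action 0 earned a reward above the baseline, so its weight should go up.
theta = update([0.0, 0.0], [(None, 0, 1.0, 0.2)])
```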

Is this approach correct? And if so,

Are these VF values saved anywhere?

Or do I need to save them manually at some point in the training process (for example, inside on_episode_step in my custom callback)?

The value of the current state is indeed saved. In fact, there is a public method, value_function(), in ModelV2 that returns it.

I understand that the value function is used inside the loss computation of the IMPALA trainer, but whether these values can be accessed from outside it at the time a batch of observations is being processed is not clear to me.

I would like to use the value_function() method to get these values manually, but I couldn’t find a way to save them at inference time: they are computed automatically after forward() is called, and I could not recompute them afterwards.

Edit: one thing I didn’t mention explicitly is that I need the value-function values of the observations in the batch being learned from, paired in order (just like obs, action, and reward, which are sampled with that pairing in time preserved).
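For readers unfamiliar with the forward()/value_function() interplay mentioned above, here is a stripped-down, framework-free illustration of the caching pattern (this is not actual RLlib code; the toy linear branches are assumptions):

```python
class TinyModel:
    """Mimics the ModelV2 pattern: forward() caches the critic output,
    value_function() returns it for the most recent batch."""

    def __init__(self, w_policy, w_value):
        self.w_policy = w_policy
        self.w_value = w_value
        self._value_out = None            # filled in by forward()

    def forward(self, obs_batch):
        # Policy branch: logits per observation (toy linear model).
        logits = [[o * w for w in self.w_policy] for o in obs_batch]
        # Value branch: cached so value_function() can return it later,
        # paired in order with the observations of this batch.
        self._value_out = [o * self.w_value for o in obs_batch]
        return logits

    def value_function(self):
        assert self._value_out is not None, "call forward() first"
        return self._value_out

model = TinyModel(w_policy=[1.0, -1.0], w_value=0.5)
model.forward([1.0, 2.0, 3.0])
values = model.value_function()   # one value per observation, same order
```

The key point is that the values are only valid for the batch passed to the last forward() call, which is why grabbing them at the wrong time loses the obs/value pairing.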

I think you may be interested in the policy._value() method that is added when policies use the ValueNetworkMixin. You can find the code for it here:

Thanks @mannyv. I found another way around it, but your info was useful for testing alternatives.

Another approach, which I think is more general and cleaner, is to make a custom trainable using the IMPALA one as a base, and compute the update of the simple policy I’m trying to implement there.

That way, I have access to the value-function values as used in the original training process, without running additional passes through the network.

The only drawback is how to pass the computed parameters of the simple policy to the workers. A solution I found was to save them to a PyTorch file while learning, and load that file inside the workers’ network at inference so that they always use the most up-to-date parameters.
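A stdlib-only sketch of that "pass parameters through a file" idea (the post uses a PyTorch file via torch.save/torch.load; json is used here purely to keep the example dependency-free, and the path is an illustrative assumption):

```python
import json
import os
import tempfile

# Hypothetical location shared by the trainable (writer) and workers (readers).
PARAMS_PATH = os.path.join(tempfile.gettempdir(), "simple_policy_params.json")

def save_params(theta, path=PARAMS_PATH):
    # Write to a temp file, then atomically rename it into place, so a
    # worker reading concurrently never sees a half-written file.
    tmp = path + ".tmp"
    with open(tmp, "w") as f:
        json.dump(theta, f)
    os.replace(tmp, path)

def load_params(path=PARAMS_PATH, default=None):
    try:
        with open(path) as f:
            return json.load(f)
    except FileNotFoundError:
        return default            # worker starts before first save

save_params([0.2, -0.1, 0.05])
theta = load_params()
```

The atomic os.replace is the one design choice worth keeping even with torch.save: each worker then sees either the old parameters or the new ones, never a partial write.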

Maybe there is a better way to pass these parameters around, like saving them in the trainable, accessing them from the simulator environment, and passing them to the network as observations. Not sure if this is possible, but I will try.

Thank you for your replies again.