Hello, all. While working on a project, I ran into an out-of-memory error raised in `general_advantage_estimation.py`. After looking at the code, I found that on line 96 the entire batch is fed through the encoder and value head at once. I also noticed that gradients are computed in the `__call__` method, even though (as far as I can tell) neither of its outputs, `ADVANTAGES` and `VALUE_TARGETS`, needs them.
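For reference, here is a minimal pure-Python sketch of GAE (not RLlib's implementation) that illustrates why neither output should need gradients: both are built only from value predictions and rewards, and they are typically treated as constants in the loss (advantages as weights, value targets as regression targets). The name `compute_gae` and the argument layout (one extra bootstrap value at the end of `values`) are assumptions for illustration.

```python
from typing import List, Sequence, Tuple


def compute_gae(
    rewards: Sequence[float],
    values: Sequence[float],
    dones: Sequence[bool],
    gamma: float = 0.99,
    lambda_: float = 0.95,
) -> Tuple[List[float], List[float]]:
    """Return (advantages, value_targets) for a single trajectory.

    Assumed layout: `values` holds V(s_0)..V(s_{T-1}) plus one bootstrap
    value for the state after the last step, i.e. len(rewards) + 1 entries.
    """
    if len(values) != len(rewards) + 1:
        raise ValueError("values must have len(rewards) + 1 entries")
    advantages = [0.0] * len(rewards)
    running = 0.0
    # Walk backwards: A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
    for t in reversed(range(len(rewards))):
        nonterminal = 0.0 if dones[t] else 1.0
        # TD error: delta_t = r_t + gamma * (1 - done_t) * V(s_{t+1}) - V(s_t)
        delta = rewards[t] + gamma * nonterminal * values[t + 1] - values[t]
        running = delta + gamma * lambda_ * nonterminal * running
        advantages[t] = running
    # Value targets (lambda-returns) are the advantages plus the baseline.
    value_targets = [a + v for a, v in zip(advantages, values[:-1])]
    return advantages, value_targets


adv, targets = compute_gae(
    [1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0], [False, False, True],
    gamma=0.5, lambda_=1.0,
)
print(adv)  # [1.75, 1.5, 1.0]
```

Only plain numbers flow through this computation, so keeping an autograd graph for the value predictions used here would just hold activations in memory without ever being backpropagated through.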
Is there a reason we compute gradients here and pass all of the observations to the value head in one big batch?
I modified the file to disable gradient computation and to feed the value head in smaller batches. I've tested these changes and everything seems to work. I'd be happy to clean up my code and submit a PR if that would be useful.
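A minimal, framework-agnostic sketch of the two changes described above (not the actual patch): the value function is evaluated in fixed-size chunks, so peak memory for intermediate activations scales with `chunk_size` rather than the whole batch, and the loop runs inside a no-grad context. `compute_values_in_chunks`, `value_fn`, and `chunk_size` are hypothetical names; with PyTorch one would pass `no_grad=torch.no_grad`.

```python
import contextlib
from typing import Callable, ContextManager, List, Sequence


def compute_values_in_chunks(
    value_fn: Callable[[Sequence], List[float]],
    observations: Sequence,
    chunk_size: int,
    no_grad: Callable[[], ContextManager] = contextlib.nullcontext,
) -> List[float]:
    """Apply a (hypothetical) value_fn to observations in slices of at most chunk_size.

    `no_grad` defaults to a no-op context so this runs without any framework;
    with PyTorch, passing torch.no_grad stops autograd from recording a graph.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    values: List[float] = []
    with no_grad():
        for start in range(0, len(observations), chunk_size):
            chunk = observations[start:start + chunk_size]
            # Only one chunk's activations are alive at a time.
            values.extend(value_fn(chunk))
    return values


# Usage with a stand-in value function that records each chunk size.
calls = []

def fake_value_fn(chunk):
    calls.append(len(chunk))
    return [2.0 * x for x in chunk]

vals = compute_values_in_chunks(fake_value_fn, list(range(10)), chunk_size=4)
print(calls)  # [4, 4, 2]
```

The results match a single full-batch call as long as the value function treats each observation independently; a model that normalizes across the batch would need extra care.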