Hey @Astariul, strange. Yeah, it could have to do with your custom action distribution, which moves things back onto the GPU in the `dist.logp` call inside `multi_log_probs_from_logits_and_actions`. It's probably better to have your change in, then.
As background: V-trace calculations, as per the original IMPALA paper, should be done on the CPU because they are all sequential. That's why the IMPALA loss, seemingly all of a sudden, moves the data from "device" to "cpu" (no matter what "device" is) and then back to "device" after the V-trace computations.
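To make the "sequential" point concrete, here is a minimal pure-Python sketch of the V-trace target recursion for a single trajectory, following the formulas in the IMPALA paper. It is not RLlib's implementation: the function name `vtrace_sketch`, its parameters, and the reuse of the same clipped importance weights for the policy-gradient advantages are illustrative assumptions. The key part is the backward loop, where each step needs the result of step `t + 1`, so the time dimension cannot be parallelized. In a PyTorch loss, the equivalent would be to call `.cpu()` on the inputs before this loop and `.to(device)` on the results afterwards.

```python
import math
from typing import List, Tuple


def vtrace_sketch(
    log_rhos: List[float],   # log(pi(a|x) / mu(a|x)) per time step
    rewards: List[float],
    values: List[float],     # V(x_t) from the learner's value function
    bootstrap_value: float,  # V(x_T) for the state after the last step
    gamma: float = 0.99,
    clip_rho_threshold: float = 1.0,  # rho-bar in the paper
    clip_c_threshold: float = 1.0,    # c-bar in the paper
) -> Tuple[List[float], List[float]]:
    """Hypothetical helper: returns (vs, pg_advantages) for one trajectory."""
    T = len(rewards)
    rhos = [math.exp(lr) for lr in log_rhos]
    clipped_rhos = [min(clip_rho_threshold, r) for r in rhos]
    cs = [min(clip_c_threshold, r) for r in rhos]

    # Temporal-difference terms, weighted by the clipped importance ratios.
    values_tp1 = values[1:] + [bootstrap_value]
    deltas = [
        clipped_rhos[t] * (rewards[t] + gamma * values_tp1[t] - values[t])
        for t in range(T)
    ]

    # Backward recursion: vs_t - V(x_t) = delta_t + gamma * c_t * (vs_{t+1} - V(x_{t+1})).
    # Each iteration depends on the previous one, so this loop is inherently sequential.
    vs_minus_v = [0.0] * T
    acc = 0.0
    for t in reversed(range(T)):
        acc = deltas[t] + gamma * cs[t] * acc
        vs_minus_v[t] = acc
    vs = [vs_minus_v[t] + values[t] for t in range(T)]

    # Policy-gradient advantages use the V-trace target of the next step.
    vs_tp1 = vs[1:] + [bootstrap_value]
    pg_advantages = [
        clipped_rhos[t] * (rewards[t] + gamma * vs_tp1[t] - values[t])
        for t in range(T)
    ]
    return vs, pg_advantages
```

With on-policy data (all `log_rhos` equal to 0), the targets reduce to bootstrapped discounted returns, which is a quick sanity check: `vtrace_sketch([0.0] * 3, [1.0] * 3, [0.0] * 3, 0.0, gamma=0.5)` yields `vs == [1.75, 1.5, 1.0]`.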