@mannyv @Lars_Simon_Zehnder
I overrode the on_learn_on_batch method:
def on_learn_on_batch(self, *,
                      policy,
                      train_batch,
                      result,
                      **kwargs) -> None:
    grads, fetches = policy.compute_gradients(train_batch)
    for name, param in policy.model.named_parameters():
        if "time2vec" in name:  # Skip time2vec; we don't use it here.
            continue
        if param.grad is None:  # Parameter did not receive a gradient.
            continue
        _isfinite = torch.isfinite(param.grad).all().item()
        if not _isfinite:
            print(f"grad for {name} not finite\n{param.grad}")
            print(f"{name}/grad/max={param.grad.max()}")
            print(f"{name}/grad/min={param.grad.min()}")
            print(f"{name}/grad/isfinite={_isfinite}")
            raise Exception(f"grad for {name} is not finite")
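In isolation, the finiteness check behaves like this (a toy nn.Linear standing in for policy.model, not RLlib code; the nan multiplication just simulates a poisoned loss):

```python
import torch
import torch.nn as nn

# Toy stand-in for the policy model (hypothetical, not RLlib code).
model = nn.Linear(4, 2)
x = torch.randn(3, 4)

# Multiplying the output by nan poisons every gradient flowing back
# through the model, mimicking the situation in the log below.
loss = (model(x) * float("nan")).sum()
loss.backward()

for name, param in model.named_parameters():
    _isfinite = torch.isfinite(param.grad).all().item()
    print(f"{name}/grad/isfinite={_isfinite}")
```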
This gives me the following output:
(PPO pid=94) grad for lin1.weight not finite
(PPO pid=94) tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
(PPO pid=94)         [... 14 more identical all-nan rows ...],
(PPO pid=94) [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]])
(PPO pid=94) lin1.weight/grad/max=nan
(PPO pid=94) lin1.weight/grad/min=nan
(PPO pid=94) lin1.weight/grad/isfinite=False
The question is: what the heck is happening here? It looks like I'll have to dig into the backpropagation process to see what is happening in there…
But first I'll check whether the weights are initialized correctly.
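A quick sanity check for the initialization could look like this (the 16x23 layer here is just a stand-in matching the shape of the lin1.weight gradient dump above, not the actual model):

```python
import torch
import torch.nn as nn

# Stand-in for lin1, shaped like the 16x23 gradient tensor in the log above.
lin1 = nn.Linear(23, 16)

# Collect any parameters containing nan/inf right after initialization.
bad = [name for name, p in lin1.named_parameters()
       if not torch.isfinite(p).all()]
print(bad)  # an empty list means all weights are finite
```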
UPDATE:
The weights get initialized correctly, so the problem appears to be in the backpropagation. I'll copy the whole code from compute_gradients and _multi_gpu_parallel_grad_calc into my on_learn_on_batch function to add some custom debug output.
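Before copying that code, PyTorch's built-in anomaly detection might save some of the work: it raises as soon as any backward function returns nan and names the offending op. A minimal sketch (sqrt of a negative input as an artificial nan source):

```python
import torch

# detect_anomaly adds checks to the backward pass that raise a RuntimeError
# the moment a backward function produces nan, naming the op in the message.
x = torch.tensor([-1.0], requires_grad=True)
with torch.autograd.detect_anomaly():
    y = torch.sqrt(x)  # forward already yields nan for negative input
    try:
        y.backward()
    except RuntimeError as err:
        print("caught:", err)
```

Note that anomaly detection slows training down considerably, so it is only suitable for debugging runs.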