@mannyv @Lars_Simon_Zehnder

I overrode the `on_learn_on_batch` method:

```
def on_learn_on_batch(self, *,
                      policy,
                      train_batch,
                      result,
                      **kwargs) -> None:
    grads, fetches = policy.compute_gradients(train_batch)
    for name, param in policy.model.named_parameters():
        if "time2vec" in name:  # Skip time2vec, we don't use it here
            continue
        _isfinite = torch.isfinite(param.grad).all().item()
        if not _isfinite:
            print(f"grad for {name} not finite\n{param.grad}")
            print(f"{name}/grad/max={param.grad.max()}")
            print(f"{name}/grad/min={param.grad.min()}")
            print(f"{name}/grad/isfinite={_isfinite}")
            raise Exception(f"grad for {name} is not finite")
```
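A limitation of checking `param.grad` after `compute_gradients` is that it only tells you the gradient ended up non-finite, not where in the backward pass the NaN first appeared. One way to narrow that down, sketched here on a hypothetical stand-in model rather than the actual RLlib policy, is to attach `register_hook` to each parameter so the check fires during the backward pass itself:

```python
import torch
import torch.nn as nn

# Hypothetical small model standing in for policy.model.
model = nn.Linear(23, 16)

def make_hook(name):
    # The hook runs while autograd computes the gradient for this parameter,
    # so a non-finite value is caught at the tensor where it first shows up.
    def hook(grad):
        if not torch.isfinite(grad).all():
            raise RuntimeError(f"non-finite grad flowing into {name}")
    return hook

for name, param in model.named_parameters():
    param.register_hook(make_hook(name))

# Finite input: backward completes and no hook raises.
out = model(torch.randn(4, 23)).sum()
out.backward()
```

With hooks on every parameter, the first raised exception names the layer closest to the source of the NaN, instead of whichever parameter the outer loop happened to inspect first.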

And this gives me this information:

```
(PPO pid=94) grad for lin1.weight not finite
(PPO pid=94) tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
(PPO pid=94) [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
(PPO pid=94) [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
(PPO pid=94) [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
(PPO pid=94) [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
(PPO pid=94) [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
(PPO pid=94) [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
(PPO pid=94) [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
(PPO pid=94) [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
(PPO pid=94) [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
(PPO pid=94) [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
(PPO pid=94) [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
(PPO pid=94) [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
(PPO pid=94) [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
(PPO pid=94) [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
(PPO pid=94) [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]])
(PPO pid=94) lin1.weight/grad/max=nan
(PPO pid=94) lin1.weight/grad/min=nan
(PPO pid=94) lin1.weight/grad/isfinite=False
```

The question is: what the heck is happening here? Looks like I have to dig into the backpropagation process now to see what is happening in there…

But first I'll check whether the weights are initialized correctly.
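That check is a simple loop over the parameters themselves (as opposed to their gradients). A minimal sketch, using a toy model in place of the actual policy model:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in; the real check would iterate policy.model.
model = nn.Sequential(nn.Linear(23, 16), nn.Tanh(), nn.Linear(16, 4))

for name, param in model.named_parameters():
    finite = torch.isfinite(param).all().item()
    print(f"{name}/isfinite={finite}")
    assert finite, f"weights for {name} are not finite"
```

If this passes before the first training step but the gradients still come back as NaN, the weights themselves are not the culprit.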

UPDATE:

The weights get initialized correctly, so the problem looks like it is in the backpropagation. I'll copy the code from `compute_gradients` and `_multi_gpu_parallel_grad_calc` into my `on_learn_on_batch` function to add some custom debug output.
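Before copying RLlib internals, a lighter-weight option for inspecting the backward pass is PyTorch's built-in anomaly detection, which makes autograd raise at the first backward function that produces a NaN and reports which forward op created it. A self-contained illustration with a deliberately NaN-producing graph:

```python
import torch

# With anomaly mode on, backward() raises at the first op whose gradient
# contains NaN, and the traceback points at the offending forward op.
torch.autograd.set_detect_anomaly(True)

x = torch.tensor([1.0, 0.0], requires_grad=True)
y = torch.log(x)        # log(0) = -inf in the forward pass
loss = (y * 0.0).sum()  # 0 * -inf = nan

try:
    loss.backward()
except RuntimeError as e:
    print("anomaly detected:", e)
```

Here the NaN first appears in the backward of `log` (its gradient is `1/x`, and `0/0` is NaN), which anomaly mode pinpoints directly. Note that anomaly detection slows training down noticeably, so it's a debugging switch, not something to leave on.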