You can add an L2 regularization term to the loss with the custom loss API as follows:
@override(ModelV2)
def custom_loss(self, policy_loss, loss_inputs):
    """Add an L2 weight penalty to the policy loss.

    The penalty is ``l2_lambda`` times the sum of ``torch.norm`` over all
    of the model's parameters. It is recomputed and stored on
    ``self.l2_loss`` on every call, and is only added to the returned
    loss when ``self.hascustomloss`` is truthy.

    Args:
        policy_loss: A single loss or a list of losses.
        loss_inputs: Not used by this implementation.

    Returns:
        ``policy_loss`` with the penalty added (to every element when it
        is a list), or ``policy_loss`` unchanged when
        ``self.hascustomloss`` is falsy.

    Raises:
        AssertionError: If the computed penalty does not require grad
            (e.g. the model has no trainable parameters).
    """
    l2_lambda = 0.01

    # Accumulate out-of-place. The seed is a 0-dim CPU tensor; an in-place
    # ``+=`` would force the result to stay on the CPU and fail with a
    # device mismatch when the parameters live on a GPU. With ``a + b`` the
    # CPU scalar is promoted and the running sum follows the parameters.
    l2_reg = torch.tensor(0.)
    for param in self.parameters():
        l2_reg = l2_reg + torch.norm(param)

    self.l2_loss = l2_lambda * l2_reg
    assert self.l2_loss.requires_grad, "l2 loss no gradient"

    # Regularization can be switched off (e.g. from a config flag); the
    # penalty is still stored above so it can be reported.
    if not self.hascustomloss:
        return policy_loss

    if isinstance(policy_loss, list):
        return [single_loss + self.l2_loss for single_loss in policy_loss]
    return policy_loss + self.l2_loss
def metrics(self):
    """Return the current L2 penalty as a metrics dict.

    Reads ``self.l2_loss`` and converts it to a plain Python number via
    ``.item()``.

    Returns:
        A dict with the single key ``"weight_loss"``.
    """
    weight_loss = self.l2_loss.item()
    report = dict(weight_loss=weight_loss)
    # Also echo the values to the command line: with Torch models they
    # are reportedly not forwarded to the logger.
    print(report)
    return report