RNN L2 weight regularization

You can add an L2 weight penalty via the custom-loss API as follows:

    @override(ModelV2)
    def custom_loss(self, policy_loss,
                    loss_inputs):
        """Add an L2 weight penalty to the policy loss.

        Args:
            policy_loss: A single loss tensor or a list of per-tower loss
                tensors, as supplied by RLlib's custom-loss API.
            loss_inputs: The batch the losses were computed on (unused here).

        Returns:
            ``policy_loss`` with ``self.l2_loss`` added to each entry when
            ``self.hascustomloss`` is truthy, otherwise the unchanged input.
        """
        l2_lambda = 0.01  # regularization strength

        # Only penalize trainable parameters; frozen ones contribute no
        # gradient and would make the requires_grad check below fail when
        # every parameter is frozen.
        params = [p for p in self.parameters() if p.requires_grad]

        # Start the accumulator on the same device as the parameters so the
        # sum does not mix CPU and GPU tensors.  Use out-of-place addition:
        # in-place += on a fresh non-grad leaf is fragile w.r.t. autograd.
        # NOTE(review): torch.norm(p) is the un-squared L2 norm; classic
        # L2 / weight-decay regularization sums p.pow(2) — confirm intent.
        l2_reg = torch.zeros((), device=params[0].device if params else None)
        for param in params:
            l2_reg = l2_reg + torch.norm(param)
        self.l2_loss = l2_lambda * l2_reg

        # Explicit raise instead of assert: asserts vanish under `python -O`.
        if not self.l2_loss.requires_grad:
            raise RuntimeError("l2 loss no gradient")

        custom_loss = self.l2_loss

        # Depending on input, add the loss.  In case you want to regularize
        # only based on a config flag, ...
        if self.hascustomloss:
            if isinstance(policy_loss, list):
                # One loss tensor per tower: add the penalty to each.
                return [single_loss + custom_loss
                        for single_loss in policy_loss]
            return policy_loss + custom_loss

        return policy_loss

    def metrics(self):
        """Report the current L2 penalty as a training metric.

        Returns:
            Dict mapping metric name to a plain Python float.
        """
        # custom_loss() may not have run yet (e.g. metrics polled before the
        # first loss computation); report 0.0 instead of raising
        # AttributeError in that case.
        l2_loss = getattr(self, "l2_loss", None)
        metrics = {
            "weight_loss": l2_loss.item() if l2_loss is not None else 0.0,
        }
        # Also print to the command line; with Torch models these metrics are
        # somehow not reported to the logger.
        print(metrics)
        return metrics