Call order for `loss_fn` and `custom_loss`

Hello everyone,

I want to use RLlib to implement a version of multi-agent APPO that has a centralized critic and an imitation loss (comparing the agent's action distribution with that of the expert policy). Based on the examples, I believe I need to do two things:
(1) Extend the current AsyncPPOTorchPolicy and define a centralized critic loss function (which becomes the policy loss); see the snippet and loss sketch below:

# (AsyncPPOTorchPolicy import path is from my RLlib version; the postprocessing
# fn, the loss fn, and the two extra mixins are defined elsewhere in my script.)
from ray.rllib.agents.ppo.appo_torch_policy import AsyncPPOTorchPolicy

CCAsyncPPOImRLTorchPolicy = AsyncPPOTorchPolicy.with_updates(
    name="CCAsyncPPOImRLTorchPolicy",
    postprocess_fn=centralized_critic_postprocessing,
    loss_fn=loss_with_central_critic_and_ImRL,
    mixins=[
        TorchLR, TorchEntropyCoeffSchedule, TorchKLCoeffMixin,
        CentralizedValueMixin, ImitationLossCoeffSchedule
    ])
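
For context, my loss_with_central_critic_and_ImRL follows the pattern in the centralized critic example: swap the model's value_function for the central critic while the base APPO surrogate loss is computed, then restore it. A rough sketch only; the appo_surrogate_loss import name, the central_value_function method, and the opponent batch keys are assumptions about my own setup and RLlib version:

# Sketch only: "appo_surrogate_loss" is my assumption for the base APPO loss
# function's name (it may differ between RLlib versions), and
# central_value_function / the opponent batch keys come from my own model.
from ray.rllib.agents.ppo.appo_torch_policy import appo_surrogate_loss


def loss_with_central_critic_and_ImRL(policy, model, dist_class, train_batch):
    # Route value estimates through the centralized critic while the base
    # APPO surrogate loss is computed, then restore the original vf.
    vf_saved = model.value_function
    model.value_function = lambda: policy.model.central_value_function(
        train_batch["obs"], train_batch["opponent_obs"],
        train_batch["opponent_action"])
    loss = appo_surrogate_loss(policy, model, dist_class, train_batch)
    model.value_function = vf_saved
    # The imitation term itself is added in the model's custom_loss() (step 2).
    return loss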

(2) Create a custom model that subclasses TorchModelV2 and overrides the custom_loss() method, computing the imitation loss there, as sketched below.
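
Roughly what I have in mind for the model (a minimal sketch assuming a discrete action space; the "expert_obs"/"expert_actions" batch keys and the fixed 0.1 imitation coefficient are placeholders for my own setup):

import torch
from torch import nn

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork
from ray.rllib.models.torch.torch_action_dist import TorchCategorical


class CentralizedCriticImitationModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        # Plain FC net for the agent's own policy/value; the centralized
        # critic head (not shown) would additionally take opponent obs/actions.
        self.base = FullyConnectedNetwork(obs_space, action_space, num_outputs,
                                          model_config, name + "_base")

    def forward(self, input_dict, state, seq_lens):
        return self.base(input_dict, state, seq_lens)

    def value_function(self):
        return self.base.value_function()

    def custom_loss(self, policy_loss, loss_inputs):
        # policy_loss is the loss handed in by RLlib -- presumably the output
        # of loss_fn (the centralized critic loss), which is exactly what I
        # want to confirm. "expert_obs"/"expert_actions" are placeholder keys
        # for expert data I add to the batch during postprocessing.
        logits, _ = self.base({"obs": loss_inputs["expert_obs"]}, [], None)
        action_dist = TorchCategorical(logits, None)
        imitation_loss = -torch.mean(
            action_dist.logp(loss_inputs["expert_actions"]))
        self.imitation_loss_metric = imitation_loss.item()
        # Add the imitation term to whatever policy loss was passed in.
        if isinstance(policy_loss, (list, tuple)):
            return [loss + 0.1 * imitation_loss for loss in policy_loss]
        return policy_loss + 0.1 * imitation_loss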

Since the custom_loss() method in the example takes an input argument called policy_loss, I assume this is the loss returned by loss_fn (here, the centralized critic loss) being passed in, which would mean loss_fn is called before custom_loss(). Is that the correct call order?

Hey @sven1977 @rliaw @michaelzhiluo, any thoughts here?