Hello everyone,
I wish to use RLlib to implement a version of multi-agent APPO that has a centralized critic and an imitation loss (comparing the agent's action distribution with that of an expert policy). Learning from the example, I believe I need to do two things:
(1) Extend the current AsyncPPOTorchPolicy and define a centralized-critic loss function (which becomes the policy loss):
```python
from ray.rllib.agents.ppo.appo_torch_policy import AsyncPPOTorchPolicy
from ray.rllib.agents.ppo.ppo_torch_policy import KLCoeffMixin as TorchKLCoeffMixin
from ray.rllib.policy.torch_policy import (
    EntropyCoeffSchedule as TorchEntropyCoeffSchedule, LearningRateSchedule as TorchLR)

# The postprocess_fn, loss_fn and the last two mixins are my own custom code.
CCAsyncPPOImRLTorchPolicy = AsyncPPOTorchPolicy.with_updates(
    name="CCAsyncPPOImRLTorchPolicy",
    postprocess_fn=centralized_critic_postprocessing,
    loss_fn=loss_with_central_critic_and_ImRL,
    mixins=[
        TorchLR, TorchEntropyCoeffSchedule, TorchKLCoeffMixin,
        CentralizedValueMixin, ImitationLossCoeffSchedule
    ])
```
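I assume I also need a small setup function to initialize these mixins (similar to setup_torch_mixins in the centralized critic example), passed to with_updates via before_init / before_loss_init. A rough sketch of what I have in mind is below; the imitation_coeff config keys and the signatures of my own two mixins are placeholder names from my setup, not RLlib API:

```python
def setup_torch_mixins(policy, obs_space, action_space, config):
    # Initialize the standard mixins the same way the example does.
    TorchKLCoeffMixin.__init__(policy, config)
    TorchEntropyCoeffSchedule.__init__(policy, config["entropy_coeff"],
                                       config["entropy_coeff_schedule"])
    TorchLR.__init__(policy, config["lr"], config["lr_schedule"])
    # My own mixins; the config keys below are placeholders.
    CentralizedValueMixin.__init__(policy)
    ImitationLossCoeffSchedule.__init__(policy, config["imitation_coeff"],
                                        config["imitation_coeff_schedule"])
```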
(2) Create a custom model that subclasses TorchModelV2, override its custom_loss() method, and compute the imitation loss there (see the sketch below).
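For concreteness, here is roughly what I have in mind for step (2), modeled on the custom-loss example. Everything involving expert_actions is a placeholder for my own expert-data plumbing, and I use a simple cross-entropy against expert actions (discrete action space) as a stand-in for the distribution comparison:

```python
import torch
import torch.nn.functional as F
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.annotations import override


class CCImitationModel(TorchModelV2, torch.nn.Module):
    """Sketch of a model that adds an imitation term in custom_loss()."""

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        torch.nn.Module.__init__(self)
        # Plain fully connected net as the underlying policy/value model.
        self.fcnet = FullyConnectedNetwork(obs_space, action_space, num_outputs,
                                           model_config, name + "_fcnet")

    def forward(self, input_dict, state, seq_lens):
        return self.fcnet(input_dict, state, seq_lens)

    def value_function(self):
        return self.fcnet.value_function()

    @override(TorchModelV2)
    def custom_loss(self, policy_loss, loss_inputs):
        # policy_loss is whatever the policy's loss_fn returned
        # (here that would be the centralized-critic APPO loss).
        logits, _ = self.forward({"obs": loss_inputs["obs"]}, [], None)
        # Placeholder: however I end up feeding expert actions in
        # (offline data, a frozen expert policy, ...), assume a LongTensor here.
        expert_actions = loss_inputs["expert_actions"].long()
        imitation_loss = F.cross_entropy(logits, expert_actions)
        self.imitation_loss_metric = imitation_loss.item()
        # Add the imitation term on top of every policy loss term.
        return [loss + imitation_loss for loss in policy_loss]
```

I would then register this model with ModelCatalog.register_custom_model() and point the policy's model config at it.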
From the example I saw that the custom_loss() method has an input argument called policy_loss, so I assume this is the centralized-critic loss passed in from step (1). Does that mean loss_fn is called before custom_loss()?