[Question] Policy distillation methods

Hi everyone,

I have a question about how to implement policy distillation in RLlib. Similar to the scheme presented in Divide-and-Conquer RL, I have multiple agents, each learning a distinct partition of the context space in a CMDP (or a task in the multi-task RL setting). I want to distill these local learners into a central policy while the local learners are being trained.

Please note that the central policy would not be interacting with the environment; it is updated by an imitation learning pass on the samples collected by the local learners.

I have seen a similar discussion here, but that idea only works for one teacher policy.

Any suggestions would help.

Hi @msnardakani ,

Here is a solution off the top of my head that you could try:

  • Create a multi-agent env where each agent is dealing with only one part of the state space. Let’s call them teacher1, …, teacherN.
  • In your specification you have to define N+1 policies: one for each teacher and one for the student, i.e. {"t1", ..., "tN", "s"}.
  • All policies are trainable.
  • Create a custom policy for agent “s” and, in its postprocess_trajectory, pull in the observations and actions from all the teacher agents so that this policy trains with an imitation-learning (IL) loss on them for the distillation part (a sketch of this setup follows below).
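Here is a minimal, untested sketch of what that could look like on the old (pre-RLModules) torch stack. The env name "partitioned_cmdp_env", N_TEACHERS, and the way the teacher data is stored for the IL loss are all placeholders you would replace with your own pieces:

```python
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.ppo.ppo_torch_policy import PPOTorchPolicy
from ray.rllib.policy.policy import PolicySpec


class StudentDistillationPolicy(PPOTorchPolicy):
    """Student policy "s" that learns purely by imitating the teacher batches."""

    def postprocess_trajectory(self, sample_batch, other_agent_batches=None, episode=None):
        # Collect (obs, action) pairs from every teacher batch so that the
        # student's loss can be a pure imitation-learning loss on them.
        # NOTE: the structure of `other_agent_batches` differs slightly across
        # RLlib versions; the batch is always the last element of each tuple.
        if other_agent_batches:
            for agent_id, teacher_tuple in other_agent_batches.items():
                teacher_batch = teacher_tuple[-1]
                # ... append teacher_batch["obs"] / teacher_batch["actions"]
                #     to sample_batch (or a buffer) for the IL loss ...
        return super().postprocess_trajectory(sample_batch, other_agent_batches, episode)


N_TEACHERS = 3  # placeholder

policies = {f"t{i + 1}": PolicySpec() for i in range(N_TEACHERS)}
policies["s"] = PolicySpec(policy_class=StudentDistillationPolicy)

config = (
    PPOConfig()
    .environment(env="partitioned_cmdp_env")  # your multi-agent env
    .multi_agent(
        policies=policies,
        # Agent ids are assumed to coincide with policy ids here.
        policy_mapping_fn=lambda agent_id, *args, **kwargs: agent_id,
    )
)
algo = config.build()
```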

One thing that I am concerned with is how this approach differs from just scaling up environment sampling by using more than one rollout worker and having each rollout worker deal with a distinct part of the state space. You can slice the state space based on the worker index (a sketch follows below). The data collected from each part of the state space can then be consolidated to train a single policy, which should in principle have the same effect as imitation-based distillation. Unless the goal is to imitate not only the final actions but also the intermediate activations; then I assume distillation can be more powerful because it gives you more training signal. If imitating activations is the goal, then the solution I provided above won’t work, because you don’t have access to the other agents’ policies to construct the obs → activation mapping for the central student agent.
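A sketch of the worker-index idea, as I imagine it could look: the env reads its worker_index from RLlib's EnvContext and restricts context/task sampling to that worker's partition. "PartitionedContextEnv" and the partitioning logic are placeholders.

```python
import gymnasium as gym  # use `import gym` on older RLlib versions
from ray.rllib.env.env_context import EnvContext


class PartitionedContextEnv(gym.Env):
    """Each rollout worker only samples contexts from its own partition."""

    def __init__(self, config: EnvContext):
        # worker_index runs from 1..num_workers (0 is the local/driver worker).
        self.partition_id = config.worker_index
        self.num_partitions = max(config.num_workers, 1)
        # ... restrict the context / task distribution to this partition ...
```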


Thank you for the answer. Actually, I started with the same sketch in mind. Here are a couple of challenges and potential directions:

  • Ideally, we want to be able to apply certain constraints from the distilled policy to the local policies as well. For instance, similar to what we have in PPO, a likelihood ratio would be computed between the central policy and the local policies. A possible solution would be to also augment the outputs of the central policy in the postprocess_trajectory of t1, …, tN; for this, the distilled model would have to be shared with the local policies.

  • A shared network has also been shown to be effective in this scheme: basically, the local policy output is a Boltzmann distribution over a linear combination of the distilled network and a task-specific network (a sketch follows this list). This would also be another reason to make the central policy untrainable and do all of its updates while the local policies are being updated.
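To illustrate both points, here is a sketch of how I read the shared-network idea (Distral-style; my own interpretation, not a tested implementation): each local policy's logits are a linear combination of a shared distilled network and a task-specific head, and the action distribution is the Boltzmann (softmax) distribution over those logits. OBS_DIM / NUM_ACTIONS are placeholders; in RLlib this would be wrapped in a custom model class.

```python
import torch
import torch.nn as nn

OBS_DIM, NUM_ACTIONS = 8, 4  # placeholders

# One shared distilled network instance, reused (and updated) by every local policy.
SHARED_DISTILLED_NET = nn.Sequential(
    nn.Linear(OBS_DIM, 64), nn.ReLU(), nn.Linear(64, NUM_ACTIONS)
)


class LocalTaskPolicyNet(nn.Module):
    def __init__(self, alpha: float = 1.0, beta: float = 1.0):
        super().__init__()
        self.task_net = nn.Sequential(
            nn.Linear(OBS_DIM, 64), nn.ReLU(), nn.Linear(64, NUM_ACTIONS)
        )
        self.distilled_net = SHARED_DISTILLED_NET  # shared, receives gradients too
        self.alpha, self.beta = alpha, beta

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        # The local policy is softmax(alpha * distilled_logits + beta * task_logits),
        # i.e. a Boltzmann distribution over the combined logits. The distilled
        # logits can also be reused for a likelihood-ratio / KL term in the loss.
        return self.alpha * self.distilled_net(obs) + self.beta * self.task_net(obs)
```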

Regarding the concern at the end: assigning the state-space partitions to workers is only equivalent when the state space is augmented with the task signal or the context variable of the CMDP, whereas in general that information is supposed to be fed only to the distilled policy.

My implementation is composed of a multi-agent environment and a shared central model. Currently, I am verifying my results and will update this thread.


I see. So each agent is solving its own task, and during distillation the central policy is conditioned on the task so that it imitates the actions of the agents that were trained on that task.
Your direction also makes sense. However, if the central distilled policy is copied into all the local policies, you need to somehow share the gradient back-propagation across them. How are you doing that in RLlib? I don’t know if we support that as of today. We are working on a new design for neural network modules as well as the trainer to address this issue.

I started with the shared_weights_model example. Similar to that example, I have defined the shared model globally and used it in the local policies, so ideally every policy both accesses the distilled model and updates it.
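A sketch along the lines of RLlib's shared_weights_model example: the distilled trunk is created once at module import time and every policy's custom model holds a reference to it, so all local policies read from and back-propagate into the same weights. Layer sizes and the model name are placeholders.

```python
import numpy as np
import torch.nn as nn
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2

# Created once at import time; shared by every model instance below.
GLOBAL_DISTILLED_LAYER = nn.Sequential(nn.Linear(64, 64), nn.ReLU())


class SharedDistilledModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)
        self.encoder = nn.Sequential(
            nn.Linear(int(np.prod(obs_space.shape)), 64), nn.ReLU()
        )
        self.distilled = GLOBAL_DISTILLED_LAYER  # the shared central model
        self.task_head = nn.Linear(64, num_outputs)
        self.value_head = nn.Linear(64, 1)
        self._features = None

    def forward(self, input_dict, state, seq_lens):
        # All policies push gradients into GLOBAL_DISTILLED_LAYER through here.
        self._features = self.distilled(self.encoder(input_dict["obs"].float()))
        return self.task_head(self._features), state

    def value_function(self):
        return self.value_head(self._features).squeeze(-1)


ModelCatalog.register_custom_model("shared_distilled_model", SharedDistilledModel)
```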


Very interesting, I always thought this was impossible in RLlib, but good to know that there is actually a way to do it. There is a downside to this approach though: the scope of the global variable is limited to the process that is running these models, so if there are parallel processes updating the model (e.g. multi-GPU) then this method would probably not work. I am glad to say that we are currently working on a new API design for models and multi-GPU training (called RLModules and RLTrainers) that, as a bonus, addresses this exact headache of sharing NNs between agents. Please keep this thread updated as I am very curious how and whether you end up with a solution. It would be awesome if you could contribute an example of this to RLlib once you have an e2e solution.