Partial freeze and partial train

Dear RLLib team,

Thanks for your great work on RLlib, we all enjoy it very much! However, we have run into a problem realizing a feature, and we would appreciate your advice on the recommended way to do this in RLlib.

Problem Description:
We want to add one more layer of transfer learning on top of an already trained policy. The trained policy should stay frozen during training, and only the new layer should be trained.

Our Naive Solution:
Our plan is to use your custom training workflow features. In a custom training workflow, we would select the part of the model's parameters to freeze and the other parts to train.

Thanks for your help! Please let us know whether our solution is correct and follows the RLlib philosophy, or whether you already have a much easier solution for this.

Hi @yiwc ,

If you implement your model yourself, i.e. with the ModelV2 API, you can simply put a tf.stop_gradient() in your forward pass function.
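To illustrate the idea outside of RLlib: here is a minimal sketch with two plain Keras layers standing in for the pre-trained policy and the new head (the layer names are our own, not RLlib's). Wrapping the frozen part's output in `tf.stop_gradient()` cuts the backward pass at that point.

```python
import tensorflow as tf

base = tf.keras.layers.Dense(8, name="pretrained_base")  # stands in for the frozen policy
head = tf.keras.layers.Dense(2, name="new_head")         # the new layer to train

x = tf.random.normal([4, 3])
with tf.GradientTape() as tape:
    hidden = tf.stop_gradient(base(x))  # no gradients flow back into `base`
    loss = tf.reduce_sum(head(hidden) ** 2)

grads = tape.gradient(loss, base.trainable_variables + head.trainable_variables)
n_base = len(base.trainable_variables)
# grads[:n_base] come back as None (base is cut off); grads[n_base:] are real tensors.
```

In an RLlib ModelV2 subclass, the same `tf.stop_gradient()` call would go around the frozen sub-model's output inside `forward()`.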

Otherwise, you can update your policy with a new apply_gradients_fn that only applies the gradients to your one layer and leaves the others alone.
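The idea behind such a custom apply-gradients function can be sketched framework-agnostically, here in plain PyTorch (the helper name and the prefix-matching convention are our own, not RLlib API): zero out the gradients of everything except the layer you want to train, then step the optimizer as usual.

```python
import torch

torch.manual_seed(0)

model = torch.nn.Sequential(
    torch.nn.Linear(4, 8),   # index "0": stands in for the pre-trained part
    torch.nn.Linear(8, 2),   # index "1": the new layer we actually train
)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

def apply_selected_gradients(model, optimizer, trainable_prefix="1"):
    """Zero the gradients of every parameter whose name does not start
    with `trainable_prefix`, then apply the remaining gradients."""
    for name, param in model.named_parameters():
        if param.grad is not None and not name.startswith(trainable_prefix):
            param.grad.zero_()
    optimizer.step()

before_frozen = model[0].weight.detach().clone()
before_trained = model[1].weight.detach().clone()

loss = model(torch.randn(3, 4)).sum()
loss.backward()
apply_selected_gradients(model, optimizer)
# model[0].weight is unchanged; model[1].weight has been updated.
```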

If you have questions on how to do this, I will be happy to answer them.


Hi @yiwc,

If you are using torch, you could write a function that freezes the layers of interest by setting requires_grad=False on those layers' parameters.
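A small helper along those lines could look like this (the prefix-based layer selection is just one possible convention, not an RLlib API):

```python
from collections import OrderedDict
import torch

def freeze_layers(model, frozen_prefixes):
    """Set requires_grad=False on every parameter whose name starts with
    one of the given prefixes; everything else stays trainable."""
    for name, param in model.named_parameters():
        if any(name.startswith(p) for p in frozen_prefixes):
            param.requires_grad = False

model = torch.nn.Sequential(OrderedDict([
    ("body", torch.nn.Linear(4, 8)),  # stands in for the pre-trained policy
    ("head", torch.nn.Linear(8, 2)),  # the new layer to train
]))
freeze_layers(model, ["body"])

# Build the optimizer over the still-trainable parameters only.
optimizer = torch.optim.Adam(
    [p for p in model.parameters() if p.requires_grad], lr=1e-3)
```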

If you are using tf and keras, you can set layer.trainable = False.
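The Keras equivalent is a one-liner, set before compiling (the layer names below are our own placeholders):

```python
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8, name="pretrained_base"),  # stands in for the frozen policy
    tf.keras.layers.Dense(2, name="new_head"),         # the new layer to train
])
model.get_layer("pretrained_base").trainable = False
model.compile(optimizer="adam", loss="mse")
# Only the head's kernel and bias remain in model.trainable_weights.
```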

With either framework, you could then create a new trainer that applies this function, similar to the method described in this post:

Hi @arturn @mannyv ,

We appreciate your immediate reply!

Yes, we will try apply_gradients_fn and see how far we can get.

Thanks again, and have a good day!

Hi @arturn,

Thanks for your advice. We are now trying to fine-tune a model starting from a loaded, already-trained model. Where do you think we should put the model-loading code?

We thought of a few possible solutions:

  1. Put it in the custom train execution plan, i.e. load the pre-trained model before training starts:
# some brief pseudocode just for the idea
def execution_plan(workers, config):
    load_pretrained_model(workers)  # hypothetical helper: load weights before training
    ...
  2. Load the pre-trained model before the tune function starts:
# some brief pseudocode just for the idea
load_pretrained_model(trainer)  # hypothetical helper
tune.run(trainer, ...)

We would appreciate your advice on whether either of these is a recommended solution~

Hi @yiwc ,

Glad we could help.
If you want to load a complete model of a previously trained policy, the easiest way is to call the restore method of your Trainer. From the docs:

agent = ppo.PPOTrainer(config=config, env=env_class)
agent.restore(checkpoint_path)  # path to a previously saved checkpoint

Does this work for you? There are of course other ways and more elaborate solutions.
I am sure @mannyv has more to offer :slight_smile:

