Does anyone know how I can pre-train a model with Behaviour Cloning and then load it into PPO for further training? I understand that the two algorithms are of different kinds (offline vs. online), but I plan to handle that: the BC training will use an offline dataset, and PPO will train in a proper environment.
Hey Bruno Brandao, sure. There is a BC learning test in rllib/agents/marwil/tests/test_bc.py that shows how to train a BC agent.
You could then store the BC agent's weights (trainer.get_weights()) and load them (trainer.set_weights(…)) into a new PPO agent (using the same model!). Would that work?
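A minimal, framework-agnostic sketch of that idea, assuming only that both objects expose get_weights() and set_weights() the way RLlib trainers do. ToyTrainer, transfer_weights and the variable names are hypothetical, and the {policy_id: {variable_name: values}} layout is an assumption about roughly what the real weights look like, not a guarantee.

```python
import copy
from typing import Any, Protocol


class SupportsWeights(Protocol):
    """Anything exposing get_weights()/set_weights(), such as an RLlib trainer."""

    def get_weights(self) -> Any: ...

    def set_weights(self, weights: Any) -> None: ...


def transfer_weights(src: SupportsWeights, dst: SupportsWeights) -> Any:
    """Copy src's current weights into dst and return the copy that was loaded."""
    # Deep-copy so dst never shares mutable containers with src.
    weights = copy.deepcopy(src.get_weights())
    dst.set_weights(weights)
    return weights


class ToyTrainer:
    """Hypothetical stand-in for a trainer holding {policy_id: {var_name: values}}."""

    def __init__(self, weights: Any) -> None:
        self._weights = copy.deepcopy(weights)

    def get_weights(self) -> Any:
        return copy.deepcopy(self._weights)

    def set_weights(self, weights: Any) -> None:
        self._weights = copy.deepcopy(weights)


bc = ToyTrainer({"default_policy": {"fc_1/kernel": [[0.5, -0.2]], "fc_1/bias": [0.1]}})
ppo = ToyTrainer({"default_policy": {"fc_1/kernel": [[0.0, 0.0]], "fc_1/bias": [0.0]}})

print(ppo.get_weights() == bc.get_weights())  # False: different starting weights
transfer_weights(bc, ppo)
print(ppo.get_weights() == bc.get_weights())  # True: "PPO" now holds the BC weights
```

With the real trainers from this thread, the call would presumably be transfer_weights(bcloning, ppotrainer); the helper itself does nothing RLlib-specific.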
I’ll create an example script.
Hi, thank you for responding so fast. Yes, I think get_weights() and set_weights() might work; I'll try it out and come back with what I find. An example script would be amazing: I looked around, and other people have the same question.
The solution works, and it is very simple. Here is example code; it can be run step by step in a notebook to inspect and compare the outputs.
```python
import ray

ray.init(ignore_reinit_error=True)
```
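```python
import copy

from ray.rllib.agents.marwil.bc import BCTrainer, BC_DEFAULT_CONFIG

# Work on a copy instead of mutating the module-level default config.
bc_config = copy.deepcopy(BC_DEFAULT_CONFIG)
bc_config["env"] = "CartPole-v0"
# For real BC training, also point bc_config["input"] at your offline dataset.
bc_config["model"]["vf_share_layers"] = False
bc_config["model"]["fcnet_hiddens"] = [32, 16]
bcloning = BCTrainer(bc_config)
bc_config  # display the resulting config in the notebook
```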
Run whatever BC training you need here (for example, repeated calls to `bcloning.train()`), then get the weights:
```python
bcweights = bcloning.get_weights()
bcweights  # inspect the BC weights in the notebook
```
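```python
import copy

from ray.rllib.agents.ppo import PPOTrainer, DEFAULT_CONFIG

# Copy the PPO defaults and use exactly the same model settings as for BC.
ppo_config = copy.deepcopy(DEFAULT_CONFIG)
ppo_config["env"] = "CartPole-v0"
ppo_config["model"]["vf_share_layers"] = False
ppo_config["model"]["fcnet_hiddens"] = [32, 16]
ppotrainer = PPOTrainer(ppo_config)
```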
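Get the PPO weights (`ppoweights = ppotrainer.get_weights()`), just to show that they differ from the BC weights.
Load the BC-trained weights into the PPO trainer: `ppotrainer.set_weights(bcweights)`.
Check the PPO weights again with `ppotrainer.get_weights()`; you'll see that they now match the BC weights, so PPO training can start from the pre-trained policy.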
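Comparing the printed dictionaries by eye gets tedious for larger nets. A small helper like the hypothetical weights_match below could do the check, assuming the weights are nested dicts whose leaves are NumPy arrays or array-likes (which is what the notebook output above suggests, but treat it as an assumption).

```python
import numpy as np


def weights_match(a, b) -> bool:
    """Return True if two nested weight structures have the same keys, shapes and values."""
    if isinstance(a, dict) or isinstance(b, dict):
        # Both sides must be dicts with identical keys (policy IDs, variable names).
        if not (isinstance(a, dict) and isinstance(b, dict)) or a.keys() != b.keys():
            return False
        return all(weights_match(a[k], b[k]) for k in a)
    # Leaves: np.array_equal is False when shapes differ, otherwise compares values.
    return bool(np.array_equal(np.asarray(a), np.asarray(b)))


# Toy data standing in for bcweights and the PPO weights before set_weights().
bc_w = {"default_policy": {"w": np.ones((4, 32)), "b": np.zeros(32)}}
ppo_w = {"default_policy": {"w": np.full((4, 32), 0.5), "b": np.zeros(32)}}
print(weights_match(bc_w, ppo_w))  # False
```

In the notebook this would be weights_match(bcweights, ppotrainer.get_weights()), which should return True right after set_weights().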
The main thing to pay attention to is that both models are configured identically (at least `fcnet_hiddens` and `vf_share_layers`, as above); otherwise the weights won't match either.
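To see why the configs must agree, here is a simplified, hypothetical model of the weight shapes a fully connected net with an optional separate value branch would have. The variable names and layout are illustrative assumptions, not RLlib's actual ones; the point is only that fcnet_hiddens and vf_share_layers determine which tensors exist and what shape they have.

```python
from typing import Dict, List, Sequence, Tuple

Shapes = Dict[str, Tuple[int, ...]]


def _add_mlp(shapes: Shapes, prefix: str, in_dim: int, hiddens: Sequence[int]) -> int:
    """Add kernel/bias shapes for a stack of dense layers; return the output width."""
    prev = in_dim
    for i, size in enumerate(hiddens):
        shapes[f"{prefix}/hidden_{i}/kernel"] = (prev, size)
        shapes[f"{prefix}/hidden_{i}/bias"] = (size,)
        prev = size
    return prev


def expected_shapes(obs_dim: int, num_actions: int, hiddens: Sequence[int],
                    vf_share_layers: bool) -> Shapes:
    """Hypothetical weight shapes for a policy net plus a value head."""
    shapes: Shapes = {}
    last = _add_mlp(shapes, "policy", obs_dim, hiddens)
    shapes["policy/logits/kernel"] = (last, num_actions)
    shapes["policy/logits/bias"] = (num_actions,)
    if not vf_share_layers:
        # Separate value branch with its own copy of the hidden layers.
        last = _add_mlp(shapes, "value", obs_dim, hiddens)
    # Otherwise the value head sits on top of the shared policy trunk.
    shapes["value/out/kernel"] = (last, 1)
    shapes["value/out/bias"] = (1,)
    return shapes


def shape_mismatches(src: Shapes, dst: Shapes) -> List[str]:
    """Names that are missing on one side or have different shapes."""
    return sorted(k for k in src.keys() | dst.keys() if src.get(k) != dst.get(k))


# CartPole: 4 observation values, 2 discrete actions.
bc_shapes = expected_shapes(4, 2, [32, 16], vf_share_layers=False)
print(shape_mismatches(bc_shapes, expected_shapes(4, 2, [32, 16], False)))  # []
print(shape_mismatches(bc_shapes, expected_shapes(4, 2, [64, 64], False)))
```

Any non-empty result means the BC weights cannot be loaded one-to-one into the second model.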