I asked this question on Ray’s Slack channel and now I am transcribing the thread here so it’s available to everyone. Thanks to @sven1977 and @rliaw for answering my question.
Bruno Brandao:
Does anyone know how I can pre-train a model with Behaviour Cloning and then load it into PPO for further training? I understand that they are of different kinds (offline and online), but this will be addressed. The BC training will use an offline dataset, and the PPO will use a proper environment.
Richard Liaw (ray team):
cc @Sven Mika maybe we can create an example? Seems also related to what @Michael Luo is working on.
Sven Mika:
Sure, hey Bruno Brandao. There is a BC learning test inside rllib/agents/marwil/tests/test_bc.py, where you can see how to train a BC agent.
You could then store the BC agent’s weights (trainer.get_weights()) and re-load these (trainer.set_weights(…)) into a new PPO agent (using the same model!). Would that work?
Sven Mika:
I’ll create an example script.
Bruno Brandao:
Hi, thank you for responding so fast. Yes, I think get_weights() and set_weights() might work; I’ll try it out and come back with what I find. The example script would be amazing. I looked around, and there are other people with the same question/issue.
The solution works and it is very simple. Here is an example of the code; it can be run step by step in a notebook to see the outputs and compare them.
import ray
ray.init(ignore_reinit_error=True)
from ray.rllib.agents.marwil.bc import BCTrainer, BC_DEFAULT_CONFIG
# Model settings for the BC trainer; these must match the PPO model later on.
BC_DEFAULT_CONFIG['env'] = 'CartPole-v0'
BC_DEFAULT_CONFIG['model']['vf_share_layers'] = False
BC_DEFAULT_CONFIG['model']['fcnet_hiddens'] = [32, 16]
bcloning = BCTrainer(BC_DEFAULT_CONFIG)
# Inspect the full config (shown as a cell output in a notebook).
BC_DEFAULT_CONFIG
Do whatever BC training you wish here, then get the weights.
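As a rough sketch of what that step could look like (not from the original thread; the dataset path is a placeholder, and you would normally set 'input' before building the trainer above):
# Placeholder path: recorded experiences in RLlib's JSON offline format,
# e.g. produced by another trainer with the 'output' option.
BC_DEFAULT_CONFIG['input'] = '/tmp/cartpole-out'
bcloning = BCTrainer(BC_DEFAULT_CONFIG)
# Each call to train() runs one iteration of behaviour cloning on the dataset.
for _ in range(20):
    bcloning.train()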
bcweights = bcloning.get_weights()
bcweights
from ray.rllib.agents.ppo import PPOTrainer, DEFAULT_CONFIG
# Use exactly the same env and model settings as in the BC config above.
DEFAULT_CONFIG['env'] = 'CartPole-v0'
DEFAULT_CONFIG['model']['vf_share_layers'] = False
DEFAULT_CONFIG['model']['fcnet_hiddens'] = [32, 16]
ppotrainer = PPOTrainer(DEFAULT_CONFIG)
Get the weights of the fresh PPO trainer just to see/show that they differ from the BC weights.
ppotrainer.get_weights()
Put the BC-trained weights into the PPO trainer.
ppotrainer.set_weights(bcweights)
Check the PPO weights again; you’ll see that they now match the BC weights, and the trainer can start the PPO training.
ppotrainer.get_weights()
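From here the usual PPO training loop applies; a minimal sketch (the iteration count is arbitrary):
# Continue on-policy training with PPO, starting from the BC weights.
for _ in range(10):
    result = ppotrainer.train()
    print(result['training_iteration'], result['episode_reward_mean'])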
The thing to pay the most attention to is making sure the model configuration of both trainers matches; otherwise the weights won’t match either.
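One simple way to keep them in sync (a sketch, not from the original thread) is to define the model settings once and apply the same dict to both configs:
# Share one model definition so both trainers build identical networks,
# which keeps get_weights()/set_weights() compatible.
shared_model = {'vf_share_layers': False, 'fcnet_hiddens': [32, 16]}
BC_DEFAULT_CONFIG['model'].update(shared_model)
DEFAULT_CONFIG['model'].update(shared_model)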