Loading pre-trained BC policy weights for tuning with hyper-parameter optimization

Hi, I'm relatively new to RL, but thanks to RLlib, I'm getting a bit familiar with it.

I am doing a small project on "CartPole-v1" where I pre-train a policy via BC (behavior cloning).
Using the pre-trained weights, I want to initialize a PPO agent and continue learning with the hyper-parameter tuning setup below.

But it seems the pre-trained weights cannot be used to initialize the policy network.

import random

from ray import air, train, tune
from ray.tune.schedulers import PopulationBasedTraining

hyperparam_mutations = {
    "lambda": lambda: random.uniform(0.5, 1.0),
    "clip_param": lambda: random.uniform(0.01, 0.5),
    "lr": lambda: random.uniform(1e-6, 1e-4),
    "train_batch_size": lambda: random.randint(200, 300),
    "kl_coeff": lambda: random.uniform(0.01, 0.5),
    "entropy_coeff": lambda: random.uniform(0, 0.2),
}

pbt = PopulationBasedTraining(
    time_attr="training_iteration",
    perturbation_interval=50,
    resample_probability=0.25,
    # Specifies the mutations of these hyperparams
    hyperparam_mutations=hyperparam_mutations,
)

stopping_criteria = {"training_iteration": 2}

model_config = dict(
    fcnet_hiddens=[512, 512, 256, 256],
    fcnet_activation="tanh",
)

checkpoint_path = "C:/Users/kjd/Desktop/BC/pre_trained_checkpoint"

tuner = tune.Tuner(
    "PPO",
    tune_config=tune.TuneConfig(
        metric="episode_reward_mean",
        mode="max",
        scheduler=pbt,
        num_samples=2,
    ),
    param_space={
        "env": "CartPole-v1",
        "disable_env_checking": True,
        "kl_coeff": 0.5,
        "num_workers": 8,
        "num_cpus": 1,  # number of CPUs to use per trial
        "num_gpus": 0,  # number of GPUs to use per trial
        "model": model_config,
        "clip_param": 0.2,
        "lr": 1e-4,
        "train_batch_size": tune.choice([200, 300]),
        "entropy_coeff": 0.01,
        "restore": checkpoint_path,
    },
    run_config=train.RunConfig(
        stop=stopping_criteria,
        name="BC_test_cartpole",
        checkpoint_config=air.CheckpointConfig(
            checkpoint_at_end=True,  # save a checkpoint at the end of training
        ),
    ),
)

results = tuner.fit()
best_result = results.get_best_result()

Does ray.tune support weight initialization from a given checkpoint?

I did what you are asking about a while ago. I think it was on Ray 2.4 or so, so things may have changed since then with the new API. Instead of creating the tuner object with your current statement, try tuner = Tuner.restore() before calling tuner.fit(). It should bring in all the config params you specified from the checkpoint.
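
Roughly, a minimal sketch of what that looks like, assuming "PPO" as the trainable and reusing your experiment directory path (the exact Tuner.restore() signature has changed across Ray versions, so double-check against your release):

from ray import tune

# A minimal sketch: Tuner.restore() re-creates a Tuner from a previously
# saved Tune experiment directory, including the param_space and scheduler
# state, so fit() continues where the experiment left off.
tuner = tune.Tuner.restore(
    path="C:/Users/kjd/Desktop/BC/pre_trained_checkpoint",  # experiment dir (assumption)
    trainable="PPO",
)
results = tuner.fit()
best_result = results.get_best_result()

Note that this restores a whole Tune experiment rather than injecting weights into a fresh one, which is why it replaces the tune.Tuner(...) constructor call instead of going into param_space.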