Retraining a loaded checkpoint with a different config

  • High: It blocks me from completing my task.


How can you reload a checkpoint and start retraining with an updated config? A typical use-case would be wanting to increase the batch size mid-training.

In Ray 2.0, using a Tuner allows you to restore the trial, but I can’t see a way to update the config once restored.


Hey @davidADSP,

Could you share some more about your use-case? In particular, I’m wondering what signals you’d look for that would make you want to update the config mid-training (and if you could just define them to update directly in the Trainable).

I don’t think there’s currently a way to update the config when restoring an Experiment.

cc @justinvyu @kai

Hi Matthew - thanks for the reply.

For example, you might want to adjust the batch size if you saw performance was flattening off in a training run.

I suppose it’s a more general question about how to start a Tune fit process from a partially trained model. I think there’s a common use-case where you have a decent set of weights and want to see if you can squeeze more performance out of the model by continuing to train it with new hyperparameters. For example, using Tune with RLlib, I wouldn’t want to always have to start training from scratch when I already have a policy network that is exhibiting useful behaviour (trained in a prior run).

I think this was possible in Ray 1.x (using the restore argument), but I’m not 100% sure - I’m fairly new to Ray still! Could the same behaviour be achieved here? I know there’s a restore method, but I think that restores everything (including the config) - whereas I would just like to restore the trained network checkpoint as a starting point and choose a different config.




Hi David,

No, it’s not possible to use the restore method, because it resumes the run with no changes to the config. However, you can load the checkpoint you want to restore from and save the weights. For TensorFlow this can be done like this (Ray 2.0.x):


from ray.rllib.algorithms import impala

tune_config = {"env": "XXX-v0",  # put your env here
               "model": {"custom_model": "LightModel"},
               "framework": "tf2",
               "eager_tracing": True}

default_config = impala.DEFAULT_CONFIG.copy()

config = default_config | tune_config

trainer = impala.Impala(config=config)
trainer.restore("path/to/your/checkpoint")  # load the checkpoint you want the weights from

ckp_num = "1000"  # or whatever…

weights = trainer.get_policy().get_weights()

trainer.get_policy().model.base_model_actor.save_weights("light_model_actor_" + ckp_num + "_ckp.h5")
trainer.get_policy().model.base_model_critic.save_weights("light_model_critic_" + ckp_num + "_ckp.h5")
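A side note on the `config = default_config | tune_config` merge above: the `|` dict-merge operator needs Python 3.9+, and the right-hand dict wins on key conflicts. On older Pythons an unpacking merge does the same thing (values below are illustrative):

```python
default_config = {"framework": "tf", "lr": 5e-4, "train_batch_size": 500}
tune_config = {"framework": "tf2", "eager_tracing": True}

# Python 3.9+: right-hand dict overrides on key conflicts.
merged = default_config | tune_config

# Equivalent on older Pythons:
merged_legacy = {**default_config, **tune_config}

assert merged == merged_legacy
assert merged["framework"] == "tf2"  # tune_config wins on the conflict
```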

Then subclass the algorithm you want to use and define the custom model - it might just be the default model:

import os

class Impalaalgo(impala.Impala):  # inspired by another issue, but I can't remember the number
    def __init__(self, config, **kwargs):
        super(Impalaalgo, self).__init__(config, **kwargs)
        """Needs the full path here!"""
        _cwd = os.path.dirname(os.path.abspath(__file__))
        actor_weights = _cwd + "/light_model_actor_1000_ckp.h5"
        critic_weights = _cwd + "/light_model_critic_1000_ckp.h5"
        self.get_policy().model.import_from_h5([actor_weights, critic_weights])
        self.workers.sync_weights()  # Important!!!

    def reset_config(self, new_config):
        """to enable reuse of actors"""
        self.config = new_config
        return True
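The subclass trick above generalizes: wrap any trainer class so that its constructor loads pretrained weights before training starts. Here is a framework-agnostic sketch of the pattern - `BaseTrainer` and `load_weights` are stand-ins for illustration, not Ray APIs:

```python
def with_pretrained(trainer_cls, weights_path):
    """Return a subclass of trainer_cls that loads weights on construction."""
    class PreloadedTrainer(trainer_cls):
        def __init__(self, config, **kwargs):
            super().__init__(config, **kwargs)
            # Stand-in for import_from_h5(...) + sync_weights() in the RLlib case.
            self.load_weights(weights_path)
    return PreloadedTrainer


# Toy stand-in for an algorithm class, just to show the flow.
class BaseTrainer:
    def __init__(self, config, **kwargs):
        self.config = config
        self.weights = None

    def load_weights(self, path):
        self.weights = "loaded from " + path


Preloaded = with_pretrained(BaseTrainer, "light_model_actor_1000_ckp.h5")
trainer = Preloaded({"lr": 1e-4})  # weights are loaded before any training step
```

The resulting class can then be handed to Tune like any other Trainable, since Tune instantiates the class itself.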

And then start a new run with another config like this:

tuner = tune.Tuner(
    Impalaalgo,
    run_config=air.RunConfig(name="Impalaalgo_run2", …),
    param_space=my_new_config, …
)

I hope this brings you forward.

BR Jorgen


Note that the checkpoint numbers reloaded in the subclass should match the ones saved above (1000 in this example, not 180), and mind the indentation of the class - it’s significant in Python.

Hi Jorgen,

Many thanks - I’m not sure this is going to help in my case, as I’m using Torch with multi-agent PPO, so some of the functions (e.g. get_policy()) return None for me.

I know I can get the agent restored using this code:

agent = ppo.PPO(config=config)

But I guess this also restores the param_config? It feels like the parameters that define the model (e.g. ['model']) are intertwined with those that are purely about the training process (e.g. ['train_batch_size']).
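One way to work around that intertwining is to carry the model-defining keys over from the old config and only override the pure training keys. A plain-dict sketch - the particular split into `model_keys` here is illustrative, so check your algorithm’s config for the real grouping:

```python
old_config = {
    "model": {"fcnet_hiddens": [256, 256]},
    "framework": "torch",
    "train_batch_size": 4000,
    "lr": 5e-5,
}

# Keys that define the network itself vs. keys that only steer training.
model_keys = {"model", "framework"}

# Keep the model definition, swap in new training hyperparameters.
new_config = {k: v for k, v in old_config.items() if k in model_keys}
new_config.update({"train_batch_size": 8000, "lr": 1e-5})
```

The resulting dict could then be passed as the param_space of a fresh run while the network architecture stays fixed.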

In general, I guess I’d love to be able to write this:

tuner = ray.tune.Tuner(
          param_space=config.to_dict(), ...

where the param_space is just talking about how I want the agent to continue being trained, rather than re-defining the agent model itself.

This currently isn’t possible as agent is an instantiated PPO object rather than a Trainable class. I see what you’re doing with the code above - essentially creating a new class that has those weights loaded by default, but I can’t seem to get it to work for my specific use-case.

It would be really helpful to have a way to pass an existing agent into the Tuner, rather than a Trainable class, hiding away the logic that converts the trained object into a fresh Trainable class.

Last thing - if I didn’t use the Tuner, is what I’m looking to do possible with tune.Experiment(restore...) where I pass in a checkpoint file? I’m not clear on what the difference is between tune.Experiment and tune.run_experiments.

Thanks again for your help!


Hi @davidADSP

In terms of multiagent I suggest you take a look at this issue:

This should enable you to get something out of get_policy() by specifying which policy you want.

Well, I don’t think what you would love is possible (tuner = ray.tune.Tuner(agent, …)). But there are ways around it: change the config and then start a new run with the new config from a previously saved tuner checkpoint, without losing anything.

I’m not proficient in Torch, but I think Torch models have similar functions for saving and loading weights as in TensorFlow. This topic is sparse and fragmented in the official Ray documentation in general. It’s a pity, because it deters potential new users from getting started with this great RL framework. I’ve made a toy example on my GitHub that does what you want (tuning twice with different configs) and takes it all the way to production without the need to carry the overhead of Ray at the end, relying only on the ML framework (TensorFlow :thinking:). At the end of the day, that is probably what most users need.

Note that during the second run, “lr” is changed as well as “training_iteration”. lr=0.0 is there to prove that the weights loaded initially are the previously trained ones saved in the best checkpoint, as they do not change when lr=0.0.
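The lr=0.0 trick is a neat sanity check: with a zero learning rate a gradient step cannot move the weights, so any observed change would mean the checkpoint wasn’t actually loaded. In miniature, with a plain SGD update:

```python
def sgd_step(weights, grads, lr):
    """One plain SGD update: w <- w - lr * g."""
    return [w - lr * g for w, g in zip(weights, grads)]

loaded = [0.5, -1.2, 3.0]  # pretend these came from the checkpoint
grads = [0.1, -0.4, 2.0]

after = sgd_step(loaded, grads, lr=0.0)
assert after == loaded  # zero learning rate => weights must be untouched
```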

Take a look here:

In terms of your last point, I’m not sure, but my feeling is no, sorry.