Population based training (PBT) with checkpoint restore

Hi,

I’ve been working on curriculum training where each level is initiated manually, using the best checkpoint from the previous level as a starting point (this way I can reset things like LR annealing and the noise schedule, and change other HPs at each level). I recently got this working with custom callbacks, deriving my own class from RLlib’s DefaultCallbacks and overriding its on_algorithm_init() method to load the desired checkpoint (thanks to help from the Ray team!). I am also using this with PBT for hyperparam search.
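
For reference, here is roughly the shape of that callback. This is a simplified sketch, not my exact code; the checkpoint path is a placeholder, and it assumes a single-agent setup restoring weights from a policy-level checkpoint:

```python
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.policy.policy import Policy

# Placeholder: best policy checkpoint from the previous curriculum level.
PREV_LEVEL_CHECKPOINT = "/path/to/prev_level_policy_checkpoint"


class CurriculumInitCallbacks(DefaultCallbacks):
    def on_algorithm_init(self, *, algorithm, **kwargs) -> None:
        # Pull the trained weights out of the previous level's checkpoint and
        # push them into the freshly built algorithm, leaving LR annealing,
        # noise schedule, and other HPs at their new per-level settings.
        restored_policy = Policy.from_checkpoint(PREV_LEVEL_CHECKPOINT)
        algorithm.set_weights({"default_policy": restored_policy.get_weights()})
```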

This all works great for the first training level, where training starts from a randomly initialized NN. The problem is that PBT seems to be doing something weird at the perturbation intervals: it calls the on_algorithm_init() method every time one of the trials hits the perturb interval. In higher curriculum levels that begin from a checkpoint, the effect is that every perturb cycle reloads the baseline checkpoint into the NN, and the next few iterations start training it all over again. So it seems that each perturb cycle erases all the learning that happened in the previous cycle.

A simple workaround that occurs to me is to have my on_algorithm_init() load the checkpoint only on the first pass (within each trial). However, I don’t know what that if-test would look like, i.e., how the callback could tell that it’s being called because a new trial is starting versus as the result of a perturb cycle. Also, I’m now confused about how Tune/PBT works when no checkpoint is loaded. It is apparently initializing a new algorithm with a new NN model at each perturb cycle, so wouldn’t that model get newly initialized weights too, instead of carrying over the ones already trained?

Thanks for any insights.

Why are you manually restoring your checkpoints if you are using PBT?
PBT is already a population-based method: it will pause trials and pick the most promising offspring to continue training. When a trial is restored for continued training, it will try to restore the latest checkpoint from when it was last stopped.

PBT can be used to search for optimal HParam settings if the trials are independent of each other.
I feel like you are trying to wrap a curriculum stack, which already has a bunch of custom start/stop logic, inside another PBT search that also tries to start/stop trials.

Things probably won’t work properly like this.

@gjoliver I may indeed be approaching this problem incorrectly. Here is my thought process. I am training an automated vehicle (simulated) and find that a single, big training session isn’t successful in learning all the skills it needs. So I wanted to use curriculum learning to feed it one skill at a time. At the end of each curriculum level I get a checkpoint that represents a partially trained agent. I want to use that checkpoint as the starting point for the next level to build upon. But my thought is to train each level separately, one at a time, since I have no presumption that the same HPs that worked well in one level will also be good for the next level. In particular, I feel that noise injection probably needs to start relatively large in each level, then taper off as it progresses. Will PBT reset this (and the step counter that controls its reduction) at each perturbation?

I have also discovered that PBT can do wonders for HP search, so I figured I would run each training job (curriculum level) as a PBT job. Thus, my curriculum learning is a fully manual process; nothing is really being wrapped. It seems that PBT manages intermediate checkpoints just fine when it starts from a newly initialized model, but starting from a checkpoint seems to confuse it. In browsing the PBT code, it is not obvious to me where it might invoke my callback override, but that seems to be what’s getting in the way.

Is there a way to make PBT do everything I need, or maybe some other approach that avoids such a conflict? Many thanks in advance.

I see, so you are structuring your job as a sequence of PBT runs,
and manually kicking off these PBT runs as a curriculum. That makes a lot more sense :slight_smile:

There are probably ways you can hack this. For example, I’m not sure, but you may be able to look at algorithm._counters in your callback and see whether it is a fresh run.
Also, @justinvyu mentioned to me that if you are working with a single-node cluster, you can look at the trial folder to see all the perturbations that happened up to a specific run.

I also want to give some random feedback about the high-level approach. Teaching an agent skills one by one usually requires some special reward or penalty design to make sure the agent doesn’t suffer catastrophic forgetting while trying to acquire new skills.
It may be easier to train for all the skills at once and just spend more compute on it.
Just some random thoughts.

@gjoliver thanks for the feedback. I’m not sure I can get to algorithm._counters since I don’t instantiate an Algorithm object anywhere, but I’ll poke around and see what creativity I can come up with. I am working with a single node, so I’ll also look into the perturb files idea.

Nice!
In the callback, you do have the algorithm object.

@gjoliver I’ve tried looking at algorithm._counters, but it doesn’t contain any info. Whether running basic random HP search or PBT, my print statement in on_algorithm_init() shows algorithm._counters = defaultdict(<class 'int'>, {}) at all points in the run history. I do see in the Algorithm code that the counters get populated, but only in the training_step() method and other places that run after init() and setup(). Since the on_algorithm_init() callback is called from setup(), there is no counter info at that time. So that’s not gonna work, but thank you for the idea.

@gjoliver FYI, I have worked around the issues and now have a PBT schedule able to start from a policy checkpoint (just pulling out the NN weights) and continue learning for many perturbation cycles. It also has the ability to automatically increment curriculum tasks if desired. The key was to build a component that passes global timestep counters across all threads (after a perturb, the counters get reset because new Algorithm objects are instantiated).

I tried building a Singleton object to hold that info, but even that doesn’t work across worker threads, each of which builds its own master Singleton. So I built a pseudo-singleton that stores the counters in the file system and allows multiple instances to read/write them (thread-safe). Ugh, you say! What a performance hit! Yeah, well, it’s only one tiny file, and it achieves my desired ends. I haven’t noticed a significant change in the wall-clock pace of my training, so it’s good enough for now. I’m sure there’s a way to use Ray actors to pass that info via RPC, but I didn’t want to take the time to figure all that out. :smiley:
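
For anyone curious, the pseudo-singleton is conceptually something like the sketch below. This is a bare-bones illustration, not my production code; the file location and counter names are arbitrary, and the atomic-replace write is “good enough” rather than strictly race-proof:

```python
import json
import os
import tempfile


class GlobalCounters:
    """Pseudo-singleton: every instance reads/writes the same tiny JSON file,
    so counters survive the Algorithm re-instantiation that follows a perturb."""

    def __init__(self, path="/tmp/pbt_global_counters.json"):
        self._path = path

    def load(self) -> dict:
        try:
            with open(self._path) as f:
                return json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            return {}

    def add(self, key: str, delta: int) -> int:
        counters = self.load()
        counters[key] = counters.get(key, 0) + delta
        # Write to a temp file and atomically swap it in, so concurrent readers
        # never see a partially written file.
        fd, tmp = tempfile.mkstemp(dir=os.path.dirname(self._path))
        with os.fdopen(fd, "w") as f:
            json.dump(counters, f)
        os.replace(tmp, self._path)
        return counters[key]


# Example: accumulate a global timestep count across perturb cycles.
counters = GlobalCounters()
total_steps = counters.add("timesteps_total", 4000)
```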

Very very nice!! It’s fun working around all these problems, isn’t it :slight_smile:
Just as a super quick tip, you can use a global named actor (Named Actors — Ray 2.3.0) to hold all the cluster-wide global state.
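Roughly like this (a quick, untested sketch; the actor name and counter API are just examples):

```python
import ray


@ray.remote
class GlobalState:
    """Cluster-wide store for counters that should survive Algorithm restarts."""

    def __init__(self):
        self._counters = {}

    def add(self, key, delta):
        self._counters[key] = self._counters.get(key, 0) + delta
        return self._counters[key]

    def get(self, key):
        return self._counters.get(key, 0)


# get_if_exists=True returns the already-running actor with this name if there
# is one, so every trial/worker ends up talking to the same instance.
state = GlobalState.options(
    name="global_state", lifetime="detached", get_if_exists=True
).remote()
print(ray.get(state.add.remote("timesteps_total", 4000)))
```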
Keep us posted on your curriculum learning agent, now that the real fun with reinforcement learning begins :slight_smile: