Tune as part of curriculum training


I have been frustrated training an agent for a challenging environment. I want to use curriculum training, but have had to kludge the whole curriculum together into a single Tuner run, which quickly becomes unwieldy. Maybe I’m thinking about the whole use of HP tuning incorrectly, but here’s what seems logical to me:

  1. Specify a new model structure & initialize (good HPs are unknown)
  2. Use Tune to search the HP space over “lesson 1” of the curriculum
  3. Save the best model from step 2 (presumably a checkpoint) as the basis for lesson 2
  4. Use Tune to search the HP space over “lesson 2” of the curriculum
  5. Save the best model from step 4 as the basis for more challenging tasks/lessons
    I have no reason to believe that the exact same HPs will be ideal on all lessons. In fact, I would expect, at a minimum, that learning rate would need to be adjusted for some lessons.
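To make the sequence concrete, here is a sketch of the loop I have in mind (`tune_one_lesson` is a made-up stand-in for a real Tuner run that searches HPs for one lesson and returns its best checkpoint):

```python
def run_curriculum(lessons, tune_one_lesson):
    """Run one HP search per lesson, seeding each with the previous winner."""
    best_checkpoint = None  # lesson 1 starts from a freshly initialized model
    for lesson in lessons:
        # Each lesson gets its own HP search, starting from the prior best model.
        best_checkpoint = tune_one_lesson(lesson, start_from=best_checkpoint)
    return best_checkpoint

# Toy demonstration with a fake tuner standing in for an actual Tuner.fit():
def fake_tune(lesson, start_from):
    return f"ckpt-after-{lesson}"

print(run_curriculum(["lesson1", "lesson2", "lesson3"], fake_tune))
# prints "ckpt-after-lesson3"
```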

The problem is that Tune only seems to store bulky checkpoints that record everything about the current tuning session, so that none of the HPs, environment, or anything else can be altered for a future round of tuning. I feel like the above sequence would be a normal way of doing business, and therefore flexible and transparent checkpoint handling would be a top priority. Since it isn't, I must conclude this is not the way most people think about training an agent. What is a better approach?

On the other hand, if this does sound like a solid approach, I would appreciate some guidance in using the checkpoints accordingly. It feels like Ray.Train provides the kind of checkpoint I want, which is nothing more than the raw NN weights. But I can’t figure out how to extract these from a Tuner checkpoint or how to inject partially trained weights into a new Tuner job.

At the very least, if making this work is possible by any means, it would certainly be nice to see it documented in the Tune user guide somewhere. The Ray docs generally tend to show only elementary examples, and could stand to include some more real-world complexities.

Thanks in advance for any advice!

Hi @starkj,

A few clarification questions:

  1. Are you training RL agents? Are you using Tune + RLlib?
  2. Are your tasks configurable within a single environment? Or do you need to swap environments out?
  3. What is your HP search strategy for each round? Random/grid search? Or is it okay to have something like a bayesopt searcher? Do you want each round to be independent from the last based on this assumption?

I have no reason to believe that the exact same HPs will be ideal on all lessons.

I will have a few suggestions depending on your setup!

@justinvyu yes this is for RL agents using Tune + RLlib. For now, it’s just a single agent, but eventually I’ll want to handle multiple agents with different policies.

As for environments, I would prefer to swap environments for successive tasks, but so far have been forced to try gradually adding complexity to a single one as part of a single, long trial. It’s getting kinda ugly.

I have been using a random HP search, but I am certainly open to suggestions on another approach, just haven’t dug that deep into it yet.


Got it, thanks for clarifying!

I have a few suggestions:

Suggestion 1: Just use a single env that swaps tasks

This is pretty much what you’re doing now, and we have an RLlib example here: Advanced Python APIs — Ray 2.3.0

Suggestion 2: Use the PBT scheduler + a single env that swaps tasks

The part of your use case that’s a bit special is the “all agents should use the best weights and resample HPs” on each round of the curriculum. This is similar to what Population Based Training does – the low-performing trials will copy over hyperparameters and model weights of the best trials and then perturb the hyperparams for exploration.

Combined with the single env that swaps tasks at a certain number of timesteps, this would get close to what you’re looking for.

See these two guides for PBT example usage:

Suggestion 3: Implement your own “CurriculumScheduler”

Finally, to get the exact behavior you’re looking for, implementing your own TrialScheduler is probably the best bet.

You can pause trials every N timesteps, then distribute the best model weights to each trial + sample hyperparameters for the next round. If this is what you’re looking for, I can give some more tips to start out.

I may have some more suggestions involving multiple Tune experiments, but still fleshing out some details there.

Could you also give some more feedback on your frustrations with using AIR checkpoints?

It feels like Ray.Train provides the kind of checkpoint I want, which is nothing more than the raw NN weights.

What are the Ray Train checkpoints you’re mentioning that only contain NN weights? Ray Train should use AIR checkpoints, too.

Tune only seems to store bulky checkpoints that record everything about the current tuning session, so that none of the HPs, environment, or anything else can be altered for a future round of tuning.

Tune checkpoints basically dump model weights and other state into a directory, then package it as a Checkpoint object. What info is currently hard to access in the checkpoint?

Thank you!

@justinvyu This looks like some great guidance. I will eagerly start working through your suggestions. I just discovered the PBT capability in Ray yesterday, and it looked like a good avenue. Ray’s capability far outweighs the documentation’s ability to convey it all :smiley: so it’s difficult to get one’s hands around it. The docs tend to show only extremely simple or abbreviated examples, so something like this just doesn’t get touched in any coherent way that I’ve found.

1 Like

I probably just misunderstand the usage. But when I run a Tune job and get, say, 10 trials, I can look at the reward plots (or whatever criteria) and find the best-performing trial. Then I want to take that trial’s best checkpoint and use it (the policy weights) as the starting point for another Tune run, with a different env config and possibly different HPs. However, because that checkpoint has the env and all the HPs baked in, it’s not clear to me if/how I can alter any of them before calling tuner.fit() again. And I don’t see a way to pull the raw policy weights out of that checkpoint. However, if I make a checkpoint with Train, the directory has totally different contents, and the policy itself is right there, apparently easy to pull out and reuse.

1 Like

@justinvyu I have implemented your first two suggestions, above, and am happy with the result, as far as it goes. Two limitations I still see:
A) HPs that can be scheduled (e.g. noise level, LR) can’t easily be scheduled per task; instead there is just one long schedule that spans all tasks (specified in time steps, which I can’t always predict well). So I don’t know how to restart these schedules with a new task.

B) If I run into a problem getting the agent to learn, say, task 3, but it nails tasks 0, 1 and 2 well, I don’t want to have to wait for my Tune job to keep churning through those first tasks in order to work on the problem at hand. I’d like to just start a new job with the task 2 solution as the baseline agent. It seems that any fully automated curriculum learning assumes one run that solves all the tasks and produces a winner. Very optimistic :slight_smile:

So, a couple more questions:

Q1. I’m not sure what you’re getting at with suggestion #3. Would this be for me to subclass PopulationBasedTraining scheduler and add capability to modify a “task” config param that tells the environment when to advance to the next difficulty level?

Q2. I would still very much like to grab a checkpoint, pull out the model weights, and then use those weights as the starting point for a new Tune run. So my thought is to just Tune task 0 until it achieves the reward desired. Then take the best checkpoint from that and run a Tune session using task 1 difficulty; in this way noise schedule, LR schedule, etc, could be reset and customized just for that task. Then take the winning checkpoint from task 1 and begin a new Tune job for task 2, and so on.

Q3. Sanity check - I definitely feel like I’m asking for something unusual above. Why is that? I feel like this would be a common need.

Thanks again.

Q2 is what I was originally planning to provide as a suggestion, but it’s not so straightforward to start an RLlib experiment from a checkpoint.

With regular Tune functions, you can do something like:

from ray.air import Checkpoint
from ray.tune import Tuner

def train_fn(config):
    ckpt = config["start_from_checkpoint"]
    state_dict = ckpt.to_dict()["model_state_dict"]
    # ... continue training from these weights ...

tuner = Tuner(
    train_fn,
    param_space={"start_from_checkpoint": Checkpoint.from_directory("/path/to/prev/checkpoint")},
)

However, RLlib doesn’t seem to offer such an API.
cc @arturn Do you know of any APIs that would enable resuming from a checkpoint? Ex: for fine-tuning a policy?

Thanks @justinvyu . I figured it is not straightforward. I’m up for modifying some Ray code and making a contribution out of it. But it would be helpful if someone could point me in the general right direction. I’ll look in the Algorithm class for something analogous to the train_fn() example you give above.

One approach can be seen here.

Does that suit your needs?
There is no train_fn in RLlib that would be a good place to put this logic IMO.
If you handle your training in a Python script anyways however, I’d think that the above script gives you a good starting point.
Lmk how it goes!

@arturn thanks! This example seems to be very close to what I need. I will play with it a while and report back.

@arturn your examples all use PPOConfig.build() and then ppo.train(). I really want to do this in a Tuner job. The only documented way that I’ve found to set up a Tuner with PPO is something like
tuner = Tuner("PPO", param_space = config.to_dict()...)
These examples never show the config’s build() method being used, only passing in the str name of the algo instead of an algorithm object. How can your examples be modified to pass into Tuner? Maybe this is where I need to make a custom mod to something in the Algorithm stack?

I think one way to do this is by callbacks.
Notice that in the example Artur gave you, what we essentially do is to load a pre-trained policy checkpoint into the algorithm, and use that as the baseline for further Training.

You can actually do that with a callback like on_algorithm_init():

You have access to the algorithm, and you just need to create a policy from the checkpoint you want to resume, and do algorithm.add_policy(policy=<your reloaded policy>):

Give this a try. Your project sounds really exciting actually :slight_smile:

@gjoliver this feels really close to working! However, I hit a brick wall. When my Callbacks code attempts to load a previous checkpoint, the whole Ray run just hangs. I don’t believe it hangs at the checkpoint-loading statement itself, but wherever it is, the thing stops somewhere in TrialRunner.step(). I have sprinkled several print statements throughout the Ray code to pin it down; this is as far as I’ve traced it so far. Funny thing is, my on_algorithm_init() is never being called (at least its print statements aren’t reflected in stdout, but maybe there’s an I/O buffering lag?). Here’s my code:

from typing import Dict
from ray.rllib.algorithms import Algorithm
from ray.rllib.algorithms.callbacks import DefaultCallbacks

class CdaCallbacks (DefaultCallbacks):
    """This class provides utility callbacks that RLlib training algorithms invoke at various points."""

    _checkpoint_path = None #static class variable that will be seen by Algorithm's instance

    def __init__(self,
                 legacy_callbacks_dict: Dict[str, callable] = None
                ) -> None:
        super().__init__(legacy_callbacks_dict)
        print("///// CdaCallbacks __init__ entered. Stored path = ", CdaCallbacks._checkpoint_path)

    def set_path(self,
                 path   : str,
                ) -> None:
        """Provides an OO approach to setting the static checkpoint path location."""

        CdaCallbacks._checkpoint_path = path
        print("///// CdaCallbacks.set_path confirming that path has been stored as ", CdaCallbacks._checkpoint_path)

    def on_algorithm_init(self, *,
                          algorithm:    "Algorithm", #TODO: does this need to be "PPO"?
                         ) -> None:
        """Called when a new algorithm instance had finished its setup() but before training begins.
            We will use it to load NN weights from a previous checkpoint.  No kwargs are passed in,
            so we have to resort to some tricks to retrieve the deisred checkpoint name.  The RLlib
            algorithm object creates its own object of this class, so we get info into that object
            via the class variable, _checkpoint_path.

            ASSUMES that the NN structure in the checkpoint is identical to the current run config,
            and belongs to the one and only policy, named "default_policy".

        print("///// CdaCallbacks.on_algorithm_init: checkpoint path = ", CdaCallbacks._checkpoint_path)
        # Here is an old dir from 12/17/22. It only contains one file, named checkpoint-600, so the format seems incompatible.
        ckpt = "/home/starkj/ray_results/cda0-solo/PPO_SimpleHighwayRampWrapper_53a0c_00002_2_stddev=0.6529,seed=10003_2022-12-17_10-54-12/checkpoint_000600"

        # Here is a new checkpoint made on 3/10/23 (on the Git branch tune-checkpointing, commit 4ae6)
        ckpt = "/home/starkj/ray_results/cda0/PPO_SimpleHighwayRampWrapper_05f87_00003_3_clip_param=0.1785,entropy_coeff=0.0051,stddev=0.4683,kl_coeff=0.4298,lr=0.0001_2023-03-10_11-17-45/checkpoint_000270"
        initial_weights = algorithm.get_weights(["default_policy"])
        self._print_sample_weights("Newly created model", initial_weights)

        ### When this line and below is uncommented, then Ray hangs!
        temp_ppo = Algorithm.from_checkpoint(ckpt)
        print("      checkpoint loaded.")
        saved_weights = temp_ppo.get_weights()
        self._print_sample_weights("Restored from checkpoint", saved_weights)
        algorithm.set_weights(saved_weights)
        verif_weights = algorithm.get_weights(["default_policy"])
        self._print_sample_weights("Verified now in algo to be trained", verif_weights)

A secondary problem is that my main program that creates the Tuner object injects the checkpoint path into the CdaCallbacks class variable so that the info can be passed to the Algorithm’s object. But it doesn’t get passed, even though the Python language says it should. Maybe there’s a namespace conflict that I’m not aware of? Anyway, I can work around this by hard-coding the checkpoint path into my CdaCallbacks class - it’s ugly, but it works.

Any ideas? I could show you my output log with all the print statements if that’s of interest. This is Ray 2.3.0. Thanks.

Update: I figured out how to avoid Python stdout buffering by running python -u, and now see that the problem is indeed somewhere near where Algorithm.from_state() is trying to create the new PPO algorithm object. Continuing to follow the trail…

@gjoliver , @arturn My debugging has led me to this point, and I am now stuck. My tune program only calls for 1 rollout worker, 1 evaluation worker, 1 environment per worker, trying to keep it simple. I have 16 cpus, so may get multiple simultaneous trials.

By inserting print statements in several places within the Ray 2.3.0 code, I’ve been able to trace the problem down to a call to ray.wait() that never comes back, made from FaultTolerantActorManager.__fetch_result(), which results from a call to WorkerSet.add_workers() from Algorithm.setup(). The output log (both stdout & stderr) shows that the call to my custom callbacks on_algorithm_init() method never completes. Its statement
temp_ppo = Algorithm.from_checkpoint(ckpt)
goes down a rabbit hole that ends in the infinite timeout call mentioned above.

You can look at my code and the log file at GitHub - TonysCousin/cda0 at ray-hang. A couple of notes to ease your navigation:

  • The main program is cda0_tune.py
  • The custom callback is in cda_callbacks.py - note that this hard-codes the checkpoint file (passing a name into the CLI arg list doesn’t do anything)
  • A log of the current run is in RAY_HANGING_LOG.txt
  • The checkpoint it is trying to read was built just a couple days ago with this same code, but starting without a checkpoint as input. This checkpoint represents a well-trained agent (for its first task).
  • In the CdaCallbacks code I have also substituted the existing “Algorithm” (line 30 and line 54) with “PPO”, but the behavior is no different.

Thank you very much for whatever insight you can provide!


Thanks for putting in the effort.
I’ve been playing with your script a little and it works on my end.
I can save a checkpoint and load it later.
Without the checkpoint, the script does not hang at __fetch_result?

@arturn, not sure what you mean by “without the checkpoint”. If I set ckpt = None, then errors are generated early in the run and the trial manager is shut down. If I comment out the call to PPO.from_checkpoint(ckpt), then it runs fine, but of course that misses the whole point of loading the data.

I just pushed the checkpoint sub-dir that I’m testing with to Github. If you care to look at that, you can now do so.

I have finally figured out which of the many implementations of wait() is the one being used here (ray._private.worker.wait()), and was able to follow the rabbit hole a little deeper. This function hangs when it calls worker.core_worker.wait(). core_worker is a ray._raylet.CoreWorker, which is implemented in Cython, a totally foreign thing to me! So I’m now officially stuck.


Update: just for kicks, I edited the calls to WorkerSet.foreach_worker() made from WorkerSet.add_workers() and from WorkerSet._get_spaces_from_remote_worker() to add timeout_seconds = 1.0. This allowed a lot more code to execute, but it still died, in _get_spaces_from_remote_worker(), because the variable remote_spaces was empty.

I can understand the logic of scanning all possible workers/actors to gather data from whatever is out there, but it seems, in this case, to assume that remote workers exist when they don’t. Everything I’m doing is on the same machine. I guess remote workers can be on the local machine, but I have only specified 1 rollout worker and 1 evaluation worker. I don’t fully understand all the interactions here, but it seems that maybe the code is trying to gather info where it shouldn’t be (?). What do you think?