Best practice for multi-stage training workflow

How severely does this issue affect your experience of using Ray?

  • Medium: It contributes significant difficulty to completing my task, but I can work around it.

I often have multi-agent workflows that involve multiple stages, for instance

  1. train one agent against a hard-coded policy controlling the second agent
  2. freeze the first agent’s weights, and train the second agent against that
  3. unfreeze the first agent, and train both agents together
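To make the stage structure concrete, here is a minimal plain-Python sketch of that kind of schedule. This is not RLlib API: `STAGES`, `run_schedule`, and `train_step` are made-up names, and integer "weights" stand in for real policy parameters. The point is just that each stage declares which policies receive updates while the rest stay frozen.

```python
# Hypothetical stage schedule for the three-phase workflow above.
# Each stage names the policies that get gradient updates; others are frozen.
STAGES = [
    {"name": "stage_1", "iters": 3, "policies_to_train": ["agent_0"]},
    {"name": "stage_2", "iters": 3, "policies_to_train": ["agent_1"]},
    {"name": "stage_3", "iters": 3, "policies_to_train": ["agent_0", "agent_1"]},
]

def run_schedule(weights, train_step):
    """Run each stage in order, updating only that stage's trainable policies."""
    history = []
    for stage in STAGES:
        for _ in range(stage["iters"]):
            for policy_id in stage["policies_to_train"]:
                weights[policy_id] = train_step(weights[policy_id])
            history.append((stage["name"], dict(weights)))
    return history

# Toy "training": each update just increments the weight by one.
weights = {"agent_0": 0, "agent_1": 0}
history = run_schedule(weights, train_step=lambda w: w + 1)
```

After stage 1 only `agent_0` has moved; stage 2 advances only `agent_1`; stage 3 moves both. In RLlib terms, each stage would correspond to a different multi-agent config (e.g. a different set of trainable policies).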

Is there a preferred way of doing this, especially together with Ray Tune? The options I can think of are

  1. just use a single call to Tune with a standard RLlib Trainer, and use a callback to change the configuration mid-training. I'm not sure if this even works, or whether policies have to remain the same throughout training, but just to mention it.
  2. use completely separate calls. Take the final checkpoint of the first call to set the weights at the beginning of the second call, etc.
  3. use a custom Trainable that includes all the different stages, e.g. something like
def custom_multi_stage_trainable(config):
    # Stage 1: only agent_0 is set as a policy to train; a hard-coded policy
    # controls agent_1 in the policy map.
    trainer1 = DQN(...)
    for _ in range(stage_1_iters):
        yield trainer1.train()
    # Carry agent_0's learned weights over into stage 2.
    stage_1_weights_agent_0 = trainer1.get_weights(["agent_0"])
    trainer1.stop()

    # Stage 2: only agent_1 is set as a policy to train; agent_0 starts from
    # its stage-1 weights and stays frozen.
    trainer2 = DQN(...)
    trainer2.set_weights(stage_1_weights_agent_0)
    for _ in range(stage_2_iters):
        yield trainer2.train()
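For option 2, the hand-off between separate runs is just "persist weights, then restore them in a fresh process". A rough sketch of that plumbing, with `pickle` standing in for RLlib's actual checkpointing (in real code you would use `Algorithm.save()` / `get_weights()` / `set_weights()` instead; `save_stage_checkpoint` and `load_stage_checkpoint` are made-up helper names):

```python
import pickle
import tempfile
from pathlib import Path

def save_stage_checkpoint(weights, directory):
    """Stand-in for saving a checkpoint at the end of run 1."""
    path = Path(directory) / "stage_checkpoint.pkl"
    with open(path, "wb") as f:
        pickle.dump(weights, f)
    return path

def load_stage_checkpoint(path):
    """Stand-in for restoring that checkpoint at the start of run 2."""
    with open(path, "rb") as f:
        return pickle.load(f)

with tempfile.TemporaryDirectory() as tmp:
    # End of run 1: persist agent_0's trained weights.
    ckpt = save_stage_checkpoint({"agent_0": [0.1, 0.2]}, tmp)
    # Start of run 2: restore them before training agent_1.
    restored = load_stage_checkpoint(ckpt)
```

Because the checkpoint lives on disk between runs, stage 2 can be re-launched any number of times from the same stage-1 artifact, which is the main advantage of this option.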

Options 2 and 3 are pretty similar I think, and both work. Option 3 seems slightly cleaner if you’ll only ever want to tune hyperparameters for the whole workflow at once. Option 2 would make it easier to split training into separate runs, so you could re-run stage 2 multiple times from the same stage 1 checkpoint without rerunning stage 1. But I wonder if there are any other advantages or drawbacks to either approach that I haven’t thought of. Does anyone have experience with a similar workflow that they can share?

Also just wanted to check, if I yield trainer.train() inside a custom trainable function, is that 100% the same as just passing the Trainer directly to Ray Tune? I.e. do I get all metrics etc in exactly the same way?

Thank you!

  • For generally training adversarial policies, have you looked at Alpha Star?

  • You have hit on something here for which there is no established best practice. Your policies will likely take turns exploiting each other’s weaknesses. Depending on how long you train them, this might even result in back-and-forth situations. Is there a reason you train in such an alternating fashion rather than together?

  • Yes, Tune does not even “know” what algorithm it is training. It simply calls Trainable.train() and sees how it scores, fits parameters, watches resource usage, etc.
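That separation is easy to see with a toy driver. This is a plain-Python sketch, not Tune's actual internals: `toy_driver` is a made-up stand-in for the piece of Tune that pulls one result dict per iteration out of a generator-style function trainable, without knowing anything about the algorithm inside.

```python
def custom_trainable(config):
    """Function trainable in the generator style: yield one result per iteration."""
    score = 0.0
    for i in range(config["iters"]):
        score += config["lr"]  # stand-in for one real training iteration
        yield {"training_iteration": i + 1, "score": score}

def toy_driver(trainable, config):
    """Mimics Tune's view: it only ever sees dicts of metrics, nothing else."""
    return list(trainable(config))

results = toy_driver(custom_trainable, {"iters": 3, "lr": 0.5})
```

Whatever you yield is what the driver records, which is why the metrics you get from yielding `trainer.train()` inside a function match what you would see when running the Trainer directly.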

  • I’m not really thinking specifically about adversarial settings here, maybe I should have written “train with” instead of “train against”. I hadn’t thought about Alpha Star as being applicable to what I’m doing, but I’ll have another look into it, thank you.

  • One example where we do this is “chicken-and-egg” type situations. E.g. if a task can only be solved by two cooperating agents together, it can happen that agent 1 needs to act sensibly so that agent 2 can learn, and vice versa. Starting training with one skilled agent for a while can help training get off the ground much faster. (This is an extremely simplified example; I know that in such a simple situation there would be plenty of other things you could do.) Or, if we want to guide the agents toward one of many equilibria in the system, we could initially train with an agent that acts (even vaguely) like the equilibrium we want to reach, to bias the system in that direction. Or suppose we have a small amount of offline data we want to train on first, before doing online training; that would result in a similar workflow. Fixing one agent first was just an example: I’m more generally curious about workflows that include more than one training loop in sequence.

  • Got it, that’s good to know - thank you. There isn’t an easy way to still checkpoint in that case, is there? I imagine I would have to use the trainable class API instead of a function for that to work, right?
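On the checkpointing question: yes, the class API gives you explicit save/restore hooks. Below is a plain-Python sketch of that shape (it mirrors the `save_checkpoint` / `load_checkpoint` method names of Tune's Trainable class API, but `MultiStageTrainable` here is an ordinary class, not the real base class, and the stage logic is invented for illustration):

```python
import json
import tempfile
from pathlib import Path

class MultiStageTrainable:
    """Sketch of a class-style trainable whose checkpoint includes the stage."""

    def __init__(self, config):
        self.config = config
        self.stage = 0
        self.iteration = 0

    def step(self):
        self.iteration += 1
        if self.iteration % self.config["iters_per_stage"] == 0:
            self.stage += 1  # advance to the next training stage
        return {"stage": self.stage, "training_iteration": self.iteration}

    def save_checkpoint(self, checkpoint_dir):
        # Persist everything needed to resume mid-workflow, stage included.
        state = {"stage": self.stage, "iteration": self.iteration}
        (Path(checkpoint_dir) / "state.json").write_text(json.dumps(state))
        return checkpoint_dir

    def load_checkpoint(self, checkpoint_dir):
        state = json.loads((Path(checkpoint_dir) / "state.json").read_text())
        self.stage = state["stage"]
        self.iteration = state["iteration"]

t = MultiStageTrainable({"iters_per_stage": 2})
for _ in range(3):
    t.step()

with tempfile.TemporaryDirectory() as tmp:
    t.save_checkpoint(tmp)
    t2 = MultiStageTrainable({"iters_per_stage": 2})
    t2.load_checkpoint(tmp)  # resumes in the middle of stage 1
```

The key point for a multi-stage workflow is that the checkpoint must record which stage you are in, so a restored trial resumes in the correct phase rather than restarting from stage 1.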

Just a quick note here for anyone who comes across this and wants to do something similar: if you want to follow the code snippet above and wrap an RLlib algorithm inside a function trainable, make sure you call algorithm.stop() at the end as well, otherwise some necessary cleanup might not happen. Ideally you’d wrap the training loop in a try…finally block:

def custom_trainable(config):
    algorithm = DQN(config)
    try:
        for _ in range(...):
            yield algorithm.train()
    finally:
        # Cleanup happens no matter how the loop exits.
        algorithm.stop()

That way, stop() gets called even if the trial is ended early.