How severe does this issue affect your experience of using Ray?
- Medium: It contributes to significant difficulty to complete my task, but I can work around it.
I often have multi-agent workflows that involve multiple stages, for instance:
- train one agent against a hard-coded policy controlling the second agent
- freeze the first agent’s weights, and train the second agent against that
- unfreeze the first agent, and train both agents together
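For concreteness, stage 1 of that workflow would be configured roughly like this (just a sketch; `HardCodedPolicy` stands in for whatever fixed policy controls the second agent, and the env name is a placeholder):

```python
from ray.rllib.algorithms.dqn import DQNConfig
from ray.rllib.policy.policy import PolicySpec

# Stage 1: agent_0 learns, agent_1 follows a fixed, hand-written policy.
stage_1_config = (
    DQNConfig()
    .environment("my_two_agent_env")  # placeholder for the registered multi-agent env
    .multi_agent(
        policies={
            "agent_0": PolicySpec(),                              # learned policy
            "agent_1": PolicySpec(policy_class=HardCodedPolicy),  # placeholder custom Policy subclass
        },
        policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,
        policies_to_train=["agent_0"],  # only agent_0 receives gradient updates
    )
)
```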
Is there a preferred way of doing this, especially together with Ray Tune? The options I can think of are:
- Just use a single call to `tuner.fit()` and a standard RLlib Trainer, and use a callback to change the configuration mid-training. Not sure if this even works, or if policies have to remain the same throughout training, but just to mention it.
- Use completely separate `tuner.fit()` calls. Take the final checkpoint of the first call to set weights at the beginning of the second call, etc. (a rough sketch of this is further down, after the code for option 3).
- Use a custom Trainable that includes all the different stages, e.g. something like
```python
def custom_multi_stage_trainable(config):
    # Stage 1: only agent_0 is set as a policy to train; a hard-coded policy
    # for agent_1 goes into the policy map.
    trainer1 = DQN(...)
    for _ in range(stage_1_iters):
        yield trainer1.train()
    stage_1_weights_agent_0 = trainer1.get_weights(["agent_0"])

    # Stage 2: only agent_1 is set as a policy to train; agent_0 starts from
    # its stage-1 weights and stays frozen.
    trainer2 = DQN(...)
    trainer2.set_weights(stage_1_weights_agent_0)
    for _ in range(stage_2_iters):
        yield trainer2.train()
    ...
```
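For option 2, I imagine the driver code would look roughly like this. Again just a sketch: `stage_2_config` would be built like `stage_1_config` above but with `policies_to_train=["agent_1"]`, the names (`RestoreAgent0`, `stage_1_ckpt`) are mine, and restoring the stage-1 weights via an `on_algorithm_init` callback is just one mechanism I can think of, not necessarily the intended one.

```python
from ray import air, tune
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.policy.policy import Policy

# Stage 1: train agent_0 against the hard-coded agent_1 policy.
results_1 = tune.Tuner(
    "DQN",
    param_space=stage_1_config.to_dict(),
    run_config=air.RunConfig(stop={"training_iteration": stage_1_iters}),
).fit()
stage_1_ckpt = results_1.get_best_result(
    metric="episode_reward_mean", mode="max"
).checkpoint


# Stage 2: freeze agent_0 at its stage-1 weights and train only agent_1.
class RestoreAgent0(DefaultCallbacks):
    def on_algorithm_init(self, *, algorithm, **kwargs):
        # Pull agent_0's policy out of the stage-1 checkpoint and copy its weights over.
        restored = Policy.from_checkpoint(stage_1_ckpt.to_directory(), policy_ids=["agent_0"])
        algorithm.set_weights({"agent_0": restored["agent_0"].get_weights()})


results_2 = tune.Tuner(
    "DQN",
    param_space=stage_2_config.callbacks(RestoreAgent0).to_dict(),  # only agent_1 trains
    run_config=air.RunConfig(stop={"training_iteration": stage_2_iters}),
).fit()
```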
Options 2 and 3 are pretty similar, I think, and both work. Option 3 seems slightly cleaner if you'll only ever want to tune hyperparameters for the whole workflow at once. Option 2 would make it easier to split training into separate runs, so you could re-run stage 2 multiple times from the same stage-1 checkpoint without rerunning stage 1. But I wonder if there are any other advantages or drawbacks to either approach that I haven't thought of. Does anyone have experience with a similar workflow that they can share?
Also, I just wanted to check: if I `yield trainer.train()` inside a custom trainable function, is that 100% the same as passing the Trainer directly to Ray Tune? I.e., do I get all metrics etc. in exactly the same way?
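Concretely, these are the two setups I'm comparing (sketch only; `my_config` and `num_iters` are placeholders):

```python
from ray import tune
from ray.rllib.algorithms.dqn import DQN

# Variant A: hand the algorithm to Tune directly.
tune.Tuner("DQN", param_space=my_config.to_dict()).fit()


# Variant B: wrap the training loop in a custom function trainable.
def my_trainable(config):
    trainer = DQN(config=config)
    for _ in range(num_iters):
        yield trainer.train()  # <- does this get reported to Tune exactly like in variant A?


tune.Tuner(my_trainable, param_space=my_config.to_dict()).fit()
```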
Thank you!