Medium: It contributes significant difficulty to completing my task, but I can work around it.
In my experiments, I want to pre-train agents with self-play on one environment (call it env a) before switching to training on a new environment (call it env b) where they play against each other. I therefore want to restore the pre-trained policies and continue training on the new environment with a new policy_mapping_fn.
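For concreteness, the mapping switch looks roughly like this (the policy and agent names here are illustrative placeholders, not my actual config):

```python
# Illustrative policy_mapping_fn pair; policy/agent IDs are placeholders.
def mapping_fn_env_a(agent_id, episode=None, worker=None, **kwargs):
    # Self-play pre-training on env a: every agent shares one policy.
    return "shared_policy"

def mapping_fn_env_b(agent_id, episode=None, worker=None, **kwargs):
    # Competitive training on env b: each agent gets its own policy.
    return f"policy_{agent_id}"
```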
While I could achieve this using Algorithm.from_checkpoint() and then calling train(), I appreciate the helper features of the ray.tune library, such as automated checkpointing and logging. Here is pseudocode of my training process.
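Roughly (this is a sketch rather than my exact setup; "PPO" and the config objects are placeholders):

```python
# Sketch only; assumes Ray 2.x tune.run. config_env_a / config_env_b are
# placeholder configs containing the respective policy_mapping_fn.
from ray import tune

# Phase 1: self-play pre-training on env a.
analysis = tune.run(
    "PPO",
    config=config_env_a,   # includes the self-play policy_mapping_fn
    checkpoint_at_end=True,
)
checkpoint = analysis.get_last_checkpoint()

# Phase 2: restore the pre-trained policies and continue on env b
# with a *new* policy_mapping_fn.
tune.run(
    "PPO",
    config=config_env_b,   # includes the new policy_mapping_fn
    restore=str(checkpoint),
)
```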
There seems to have been a change in behaviour between RLlib 2.0 and 2.4, however: in the second call to tune.run(), the policy_mapping_fn passed in the config is no longer used with Ray 2.4. Instead, the one restored from the checkpoint is used.
I have looked into callbacks, but I cannot find a suitable one.
Similarly, while I believe Algorithm.from_checkpoint() would set up the algorithm correctly and allow me to update the policy_mapping_fn, to my knowledge I cannot pass the resulting instance to a tune.run() call.
Is there a workaround that lets me keep the behaviour from Ray 2.0?
That is a useful example, thank you, because it shows how to update the workers.
However, I don’t know which callback is suitable. Ideally, I would do this once and only once, which suggests on_algorithm_init. I believe (to be confirmed), however, that on_algorithm_init is called after init() but before the training state is restored; restoring presumably calls load_checkpoint(), which would then override the policy_mapping_fn.
Yes, on_algorithm_init is called precisely here:
So you can update the state of the algorithm (which includes the policy_mapping_fn) using any API that you would use otherwise. Something like this might work:
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.callbacks import DefaultCallbacks

class RestoreCallback(DefaultCallbacks):
    def __init__(self):
        super().__init__()
        self.checkpoint_path = "<checkpoint_path>"

    def on_algorithm_init(self, algo):
        base_algo = Algorithm.from_checkpoint(self.checkpoint_path)
        policy_map = base_algo.workers.local_worker().policy_map
        for pid, policy in policy_map.items():
            # READ the API of add_policy to learn more.
            algo.add_policy(pid, policy, policy_mapping_fn, ...)
Do you have a code snippet of how you ended up setting up the tune.run portion of this? I have a similar env a / env b setup for training, but my inherited DefaultCallbacks object never has its on_algorithm_init function run in Algorithm.setup().
My current setup, with effectively the same callback code as suggested, is the following:
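Roughly like this (a sketch; RestoreCallback stands for the DefaultCallbacks subclass from the earlier snippet, and the env/algorithm names and multi-agent arguments are placeholders):

```python
# Sketch of wiring the init callback into tune.run; assumes the Ray 2.x
# AlgorithmConfig API. policies / new_mapping_fn are placeholders.
from ray import tune
from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("env_b")
    # Register the callback *class* (not an instance); it restores the
    # pre-trained policies inside on_algorithm_init.
    .callbacks(RestoreCallback)
    .multi_agent(policies=policies, policy_mapping_fn=new_mapping_fn)
)

tune.run("PPO", config=config.to_dict())
```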
Thank you! This really helped. Here is a snippet of how I ended up doing it for my use case (I used a factory method for setting up the restore path).
from ray.rllib.algorithms.callbacks import DefaultCallbacks

class RestoreCallback(DefaultCallbacks):
    def __init__(self, restore_path):
        super().__init__()
        self.restore_path = restore_path

    # Your custom methods here...

    def on_algorithm_init(self, algorithm):
        """Remove the 'learned' policy and replace it with the checkpointed
        one, for curriculum learning."""
        from ray.rllib.policy.policy import Policy  # , PolicySpec

        # Policy.from_checkpoint() returns a dict mapping policy IDs to policies.
        policy = Policy.from_checkpoint(self.restore_path)
        for p_id, p in policy.items():
            ...
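The factory method itself can be a simple closure that bakes the restore path into a zero-argument class, since the callbacks config expects a class rather than an instance. A Ray-free sketch of the pattern (the plain class here is a stand-in for the DefaultCallbacks subclass):

```python
def make_restore_callback(restore_path):
    """Return a zero-arg callback class bound to a specific restore path."""
    class RestoreCallback:  # stand-in for a DefaultCallbacks subclass
        def __init__(self):
            self.restore_path = restore_path
    return RestoreCallback

# Usage: pass the returned class wherever a callbacks class is expected.
CallbackCls = make_restore_callback("/tmp/env_a_checkpoint")
```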