Updating policy_mapping_fn while using tune.run() and restoring from a checkpoint

  • Medium: It contributes to significant difficulty to complete my task, but I can work around it.

In my experiments, I wish to pre-train some agents with self-play on an environment (call it env a) before switching to training on a new environment (call it env b) where they play against each other. I therefore wish to restore the pre-trained policies to continue training on the new environment with a new policy_mapping_fn.

While I could achieve this using Algorithm.from_checkpoint() and then calling train(), I appreciate the helper functions of the ray.tune library, such as automated checkpointing and logging. Here is pseudocode of my training process.

config = PPOConfig().environment(env=a).multi_agent(policy_mapping_fn=y).rollouts(num_rollout_workers=0)

trial = tune.run("PPO", config=config, checkpoint_at_end=True)
restore = trial.checkpoint.dir_or_data

config = config.environment(env=b).multi_agent(policy_mapping_fn=z)

tune.run("PPO", config=config, restore=restore)

There seems to have been a change in behaviour between RLlib 2.0 and 2.4, however: in the second call to tune.run(), the policy_mapping_fn passed in the config is no longer used with Ray 2.4. Instead, the one restored from the checkpoint is used.

I have looked into callbacks, but I cannot find a suitable one.
Similarly, while I believe Algorithm.from_checkpoint() would set up the algorithm correctly and allow me to update the policy_mapping_fn, to my knowledge I cannot pass the resulting instance to a tune.run() call.

Is there a workaround where I can keep the behaviour from ray 2.0?

Hello @Muff2n ,

One example that might be inspiring is this one:

It uses the on_train_result hook of the callback to add a new policy and update the policy_mapping_fn.
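Roughly, that pattern looks like the sketch below. This is only a sketch: the policy ids, the "main" policy name, the agent id, and the every-10-iterations trigger are placeholders, not taken from the linked example.

from ray.rllib.algorithms.callbacks import DefaultCallbacks


class AddPolicyCallback(DefaultCallbacks):
    def on_train_result(self, *, algorithm, result, **kwargs):
        # Placeholder trigger: add a new opponent policy every 10 iterations.
        if result["training_iteration"] % 10 == 0:
            new_pid = f"opponent_{result['training_iteration']}"

            def mapping_fn(agent_id, episode, worker, **kw):
                # Placeholder mapping: one agent plays the new policy,
                # everyone else keeps playing "main".
                return new_pid if agent_id == "player_1" else "main"

            algorithm.add_policy(
                policy_id=new_pid,
                policy_cls=type(algorithm.get_policy("main")),
                policy_mapping_fn=mapping_fn,
            )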

If you need to load the checkpoint upon initialization of the callback you can override the on_algorithm_init hook of the callbacks and update the algorithm instance on the fly.

That is a useful example, thank you, because it shows how to update the workers.

However, I don’t know which callback is suitable. Ideally, I would do this once and only once, which suggests on_algorithm_init. However, I believe (to be confirmed) that this hook is called after __init__() but before the training state is restored, which presumably happens via load_checkpoint() and would then override the policy_mapping_fn again.

Do you have any thoughts on this matter please?

Yes, on_algorithm_init is called precisely here:

So you can update the state of the algorithm (which includes the policy_mapping_fn) using any API that you would use otherwise. Something like this might work:


from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.callbacks import DefaultCallbacks


class MyCustomCallback(DefaultCallbacks):
    def __init__(self):
        super().__init__()
        self.checkpoint_path = "<checkpoint_path>"

    def on_algorithm_init(self, *, algorithm, **kwargs):
        # Load the pre-trained Algorithm and copy its policies into the
        # freshly initialized one.
        base_algo = Algorithm.from_checkpoint(self.checkpoint_path)
        policy_map = base_algo.workers.local_worker().policy_map
        for pid, policy in policy_map.items():
            # READ the API of add_policy to learn more (it also accepts
            # policy_mapping_fn, policies_to_train, etc.).
            algorithm.add_policy(pid, policy=policy, policy_mapping_fn=...)
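
The callback class (not an instance) is then registered on the config before the second tune.run call, something like:

config = config.callbacks(MyCustomCallback)
tune.run("PPO", config=config, restore=restore)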

Thank you, I have managed to get this working.


Hi @Muff2n ,

Do you have a code snippet of how you ended up setting up the tune.run portion of this? I have a similar env a / env b setup for training, but my inherited DefaultCallbacks object never has its on_algorithm_init method run in Algorithm.setup().

My current setup, using effectively the same callback code as suggested above, is the following:


        env_b = "RockPaperScissorsCsaben"
        pmf_a = select_policy

        config_a = (
            AlgorithmConfig(algo_class=self.algorithm)
            .environment("RockPaperScissors")
            .framework(self.framework)
            .rollouts(
                num_rollout_workers=0,
                num_envs_per_worker=4,
                rollout_fragment_length=10,
            )
            .training(
                train_batch_size=200,
                gamma=0.9,
            )
            .multi_agent(
                policies={
                    "always_same": PolicySpec(policy_class=AlwaysSameHeuristic),
                    "beat_last": PolicySpec(policy_class=BeatLastHeuristic),
                    "always_slow": PolicySpec(policy_class=AlwaysTooSlowPolicy),
                    "learned": PolicySpec(
                        config=AlgorithmConfig.overrides(
                            model={"use_lstm": True},
                            framework_str=self.framework,
                        )
                    ),
                },
                policy_mapping_fn=pmf_a,
                policies_to_train=["learned"],
            )
            .reporting(metrics_num_episodes_for_smoothing=200)
            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
            .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
        )

        # heuristics break everything otherwise
        config_a.checkpointing(
            export_native_model_files=False, checkpoint_trainable_policies_only=True
        )

        from ray import tune

        trial_a = tune.run(
            "PPO",
            config=config_a,
            stop={"training_iteration": 1},
            checkpoint_at_end=True,
        )
        # restore_path = trial_a.checkpoint.dir_or_data
        restore_path = trial_a.get_best_checkpoint(
            trial_a.get_best_trial(),
            mode="max",
            return_path=True,
        )

        config_b = config_a.environment(env_b).multi_agent(policy_mapping_fn=pmf_a)
        from ray.rllib.algorithms.callbacks import MultiCallbacks

        from harness.callbacks import MyCustomCallback

        config_b["callbacks"] = MultiCallbacks([MyCustomCallback(restore_path)])
        from ray import air

        tuner = tune.Tuner(
            "PPO",
            param_space=config_b,
            run_config=air.RunConfig(
                stop={"training_iteration": 1},
            ),
        )

        tuner.fit()

The callbacks are implemented in this file:

in function run_indep.

checkpoints is a list of path strings pointing to policy checkpoints that can be loaded before tune.run is called on the second environment.

    if checkpoints is not None:
      class MyCallbacks(DefaultCallbacks):
        def __init__(self):
          super().__init__()
          self.checkpoints = checkpoints
          self.policy_mapping_fn = lambda aid, episode, worker, **kwargs: aid

        def on_algorithm_init(
            self,
            *,
            algorithm: "Algorithm",
            **kwargs,
        ) -> None:
          for checkpoint in self.checkpoints:
            policy = Policy.from_checkpoint(checkpoint)
            for p_id, p in policy.items():
              algorithm.add_policy(p_id, policy=p)

          algorithm.remove_policy(DEFAULT_POLICY_ID,
                                  policy_mapping_fn=self.policy_mapping_fn,
                                  policies_to_train=list(POLICIES.keys()))

      config = config.callbacks(MyCallbacks)
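
The second tune.run call on env b then just takes this config as before, something like (the checkpoint_at_end flag mirrors the earlier runs in this thread):

tune.run("PPO", config=config, checkpoint_at_end=True)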

Thank you! This really helped. Here is a snippet of how I ended up doing it for my use case (I used a factory method to set up the restore path).


callbacks.py

from ray.rllib.algorithms.callbacks import DefaultCallbacks


def create_callback_class(restore_path):
    class MyCustomCallbackWithRestorePath(DefaultCallbacks):
        def __init__(self):
            super().__init__()
            self.restore_path = restore_path

        # Your custom methods here...
        def on_algorithm_init(self, *, algorithm, **kwargs):
            """
            Remove the 'learned' policy and replace it with the checkpointed
            one for curriculum learning.
            """
            from ray.rllib.policy.policy import Policy

            policy = Policy.from_checkpoint(self.restore_path)
            algorithm.remove_policy("learned")
            for p_id, p in policy.items():
                algorithm.add_policy(p_id, policy=p)

    return MyCustomCallbackWithRestorePath


trainer.py


restore_path = trial_a.get_best_checkpoint(
    trial_a.get_best_trial(),
    mode="max",
    return_path=True,
)

config_b = (
    config_a.environment(env_b)
    .multi_agent(policy_mapping_fn=pmf_b)
    .callbacks(create_callback_class(restore_path))
)