Change policy mapping function in the middle of an algorithm

How severely does this issue affect your experience of using Ray?

  • High: It blocks me from completing my task.

I’m trying to implement my own RL algorithm based on PPO, and I ran into a problem when changing the policy mapping function between training iterations.
My implementation looks like this:

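from ray.rllib.algorithms.ppo import PPO

class MyAlgorithm(PPO):
    def training_step(self):
        def new_policy_mapping_fn(agent_id, episode, worker, **kwargs):
            # the implementation here is different for each call of training_step
            ...

        # Push the new mapping function to every rollout worker.
        self.workers.foreach_worker(lambda w: w.set_policy_mapping_fn(new_policy_mapping_fn))
        # Return the results dict so that algo.train() can report it.
        return super().training_step()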
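
To make "different for each call" concrete, here is a minimal sketch of how such a per-iteration mapping function could be built. It is only an illustration, not my actual code: the helper name make_policy_mapping_fn, the policy ids "policy_0" and "policy_1", and the assumption that agent ids are integers are all hypothetical. The sketch is plain Python and does not touch Ray at all.

```python
def make_policy_mapping_fn(iteration, policy_ids):
    """Return a mapping function bound to one training iteration (hypothetical helper)."""

    def policy_mapping_fn(agent_id, episode=None, worker=None, **kwargs):
        # Assumption: agent ids are integers. Rotate the assignment by the
        # iteration index so the same agent maps to a different policy each time.
        return policy_ids[(agent_id + iteration) % len(policy_ids)]

    return policy_mapping_fn


policy_ids = ["policy_0", "policy_1"]  # hypothetical policy ids
fn_first = make_policy_mapping_fn(0, policy_ids)
fn_second = make_policy_mapping_fn(1, policy_ids)
print(fn_first(0), fn_second(0))  # prints "policy_0 policy_1"
```

With a function like this, agent 0 should be sampled under "policy_0" in the first iteration and under "policy_1" in the second, so the keys of the sampled batch would be expected to change between iterations.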

I then ran my algorithm using algo.train().
Then comes the weird part.
In the first call to algo.train(), everything looks fine.
However, in the second call to algo.train(), I noticed that the keys of the sampled batch are still the policy ids used in the first iteration. In addition, the reward information in the results returned by algo.train() still corresponds to the policy mapping function from the first iteration, not the one set in the second iteration.
Is there a way to fix this?