Odd win-rate behavior when reconfiguring the mapping function mid-league

1. Severity of the issue: (select one)
Medium: Significantly affects my productivity but can find a workaround.

2. Environment:

  • Ray version: Latest
  • Python version: 3.10
  • OS: Ubuntu

3. What happened vs. what you expected:

Brief context: My league system, based on the one in the example code, saves a frozen snapshot of the main agent whenever its win rate against all prior snapshots exceeds a certain threshold.

  • Expected: During the epoch after a new opponent is added, the main agent plays against it and the other opponents, recording accurate win rates for all of them.
  • Actual: During the epoch after a new opponent is added, the win rate against the new agent is nearly identical to the win rate against a random agent. However, this issue is exclusive to the epoch immediately after the addition of the new opponent, and plausible win rates against this agent are reported in every epoch that follows.

I’ve gone over my code several times to try to identify an error in logic that would cause this, but haven’t found anything. As best I can tell, the env runner is reporting that the newest opponent is being deployed against the main agent, but is actually deploying a different opponent. Did I miss something while putting together my callback function, or is there a propagation issue on RLlib’s end?
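
For what it's worth, here is the quick sanity check I'm planning to drop into on_episode_end while debugging (the helper name is made up); it just warns if the episode reports an opponent module id that the env runner's MultiRLModule doesn't actually contain:

def _sanity_check_opponent(episode, env_runner, opposing_module_id):
    """Warn if the episode claims an opponent that this env runner doesn't even hold."""
    available = set(env_runner.module.keys())  # module ids in this runner's MultiRLModule
    if opposing_module_id not in available:
        print(
            f"WARNING: episode {episode.id_} reports opponent {opposing_module_id}, "
            f"but this env runner only holds {sorted(available)}"
        )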

My Code (adapted from example)
from collections import defaultdict

import numpy as np

from ray.rllib.callbacks.callbacks import RLlibCallback
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.utils.metrics import ENV_RUNNER_RESULTS

class SelfPlayCallback(RLlibCallback):
    def __init__(self, win_rate_threshold):
        super().__init__()
        # 0=main_v0 (initial frozen opponent), 1=main_v1 (1st snapshot of "main"), etc.
        self.current_opponent = 0
        self.win_rate_threshold = win_rate_threshold
        # Report the matchup counters (who played against whom?).
        self._matching_stats = defaultdict(int)
        # Hacky workaround for the new-opponent win-rate bug described in this issue.
        self.just_added = False

    def on_episode_end(
        self,
        *,
        episode,
        env_runner,
        metrics_logger,
        env,
        env_index,
        rl_module,
        **kwargs,
    ) -> None:
        # Compute the win/loss for this episode and log it with a window of 1000.
        main_agent = 'X' if episode.module_for('X') == "main" else 'O'  # agent id that the "main" module controlled
        opposing_agent = episode.module_for('X') if main_agent == 'O' else episode.module_for('O')
        opposing_agent = opposing_agent.split('_v')[-1]  # e.g. "main_v3" -> "3"
        rewards = episode.get_rewards()
        assert main_agent in rewards
        main_won = rewards[main_agent][-1] == 1.0
        main_lost = rewards[main_agent][-1] == -1.0
        metrics_logger.log_value(
            f"win_rate_{opposing_agent}",
            main_won,
            window=1000,
        )
        metrics_logger.log_value(
            f"loss_rate_{opposing_agent}",
            main_lost,
            window=1000,
        )

    def update_atm_fn(self, algorithm, loss_rates):
        # Turn per-opponent loss rates into sampling probabilities for the
        # agent-to-module mapping fn, favoring opponents "main" still loses to.
        base_probs = loss_rates + 0.01  # add a small chance for everything
        base_probs = base_probs / base_probs.sum()
        print(f"Updating ATM fn: {base_probs}")
        # Reweight and (if applicable) add to agent randomizer
        def agent_to_module_mapping_fn(agent_id, episode, **kwargs):
            opponent = "main_v{}".format(
                np.random.choice(list(range(1, self.current_opponent + 1)), p=base_probs)
            )
            if (hash(episode.id_) % 2 == 0) != (agent_id == 'X'):
                self._matching_stats[("main", opponent)] += 1
                return "main"
            else:
                return opponent
        # Set new mapping function
        algorithm.config._is_frozen = False
        algorithm.config.multi_agent(policy_mapping_fn=agent_to_module_mapping_fn)
        algorithm.config.freeze()
        # Add to (training) EnvRunners.
        def _add(_env_runner, _module_spec=None):
            _env_runner.config.multi_agent(
                policy_mapping_fn=agent_to_module_mapping_fn,
            )
            return MultiRLModuleSpec.from_module(_env_runner.module)
        algorithm.env_runner_group.foreach_env_runner(_add)

    def on_train_result(self, *, algorithm, metrics_logger=None, result, **kwargs):
        print(f"Iter={algorithm.iteration}:")
        print(f"Matchups: {dict(self._matching_stats)}")
        f_wr = f_lr = 0
        worst_ratio = 1 # worst ratio must exceed threshold
        loss_rates = np.array(
            [result[ENV_RUNNER_RESULTS][f"loss_rate_{i}"] for i in range(self.current_opponent + 1)]
        )
        for i in range(self.current_opponent + 1):
            # Once later snapshots exist, leave main_v0 out of the threshold check.
            if i == 0 and self.current_opponent != 0:
                continue
            win_rate = result[ENV_RUNNER_RESULTS][f"win_rate_{i}"]
            loss_rate = loss_rates[i]
            sum_rates = win_rate + loss_rate
            ratio = win_rate / sum_rates if sum_rates > 0 else 1
            print(f"Opponent {i}: win-rate={win_rate:.2f} loss-rate={loss_rate:.2f} ... ", end="" if i == self.current_opponent else "\n")
            f_wr, f_lr = win_rate, loss_rate
            if ratio < worst_ratio:
                worst_ratio = ratio

        # If the win/loss ratio clears the threshold against every tracked opponent ->
        # snapshot the current "main" policy and add it to the league as a new frozen
        # opponent; only "main" keeps training.
        if (self.just_added):
            self.just_added = False
            print("just added new agent; allowing an epoch to propagate.")
        elif f_wr > 0 and worst_ratio > self.win_rate_threshold:
            self.current_opponent += 1
            new_module_id = f"main_v{self.current_opponent}"
            print(f"adding new opponent to the mix ({new_module_id}).")

            # Reset stored values
            for i in range(self.current_opponent):
                metrics_logger.set_value(
                    f"win_rate_{i}",
                    0,
                    window=1000,
                )
                metrics_logger.set_value(
                    f"loss_rate_{i}",
                    0,
                    window=1000,
                )
            loss_rates = np.append(loss_rates, [0.5])  # seed the new opponent with a neutral loss rate
            main_module = algorithm.get_module("main")
            algorithm.add_module(
                module_id=new_module_id,
                module_spec=RLModuleSpec.from_module(main_module), # Copy main module specs
            )
            # TODO (sven): Maybe we should move this convenience step back into
            #  `Algorithm.add_module()`? Would be less explicit, but also easier.
            algorithm.set_state(
                {
                    "learner_group": {
                        "learner": {
                            "rl_module": {
                                new_module_id: main_module.get_state(),
                            }
                        }
                    }
                }
            )
            self.just_added = True
        else:
            print("not good enough; will keep learning ...")

        # Update mapping function, reweighting and adding new module if needed
        if self.current_opponent > 0:
            self.update_atm_fn(algorithm, loss_rates[1:])

        # +2 = main + random
        result["league_size"] = self.current_opponent + 2
Config (in case it's relevant)
import numpy as np
import functools

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.algorithms.ppo.torch.default_ppo_torch_rl_module import DefaultPPOTorchRLModule

from ray.rllib.examples.rl_modules.classes.action_masking_rlm import (
    ActionMaskingTorchRLModule,
)

specs = {
    "rand": RLModuleSpec(
        module_class=RandHeuristicRLM,
    ),
    "block_win": RLModuleSpec(
        module_class=BlockWinHeuristicRLM,
    ),
    "perfect": RLModuleSpec(
        module_class=PerfectHeuristicRLM,
    ),
}
heuristics = list(specs.keys())

for n in ['main', 'main_v0']:  # the learned "main" policy and the initial frozen opponent
    specs[n] = RLModuleSpec(
        module_class=CMAPPOActionMaskingTorchRLModule,
        model_config={
            "head_fcnet_hiddens": (64, 64),
        },
    )

# A shared critic. Moot in this use-case, but I'm running other experiments that make use of it
single_agent_env = TicTacToe()
specs[SHARED_CRITIC_ID] = RLModuleSpec(
    module_class=ActionMaskingSharedCriticTorchRLModule,
    observation_space=single_agent_env.observation_spaces['X'],
    action_space=single_agent_env.action_spaces['X'],
    learner_only=True,  # only build on the learner
)

# League stuff
win_rate_threshold = 0.95 # wins / wins+losses, wins > 2
def agent_to_module_mapping_fn(agent_id, episode, **kwargs):
    # agent_id is "X" or "O" -> which module it gets depends on the episode ID.
    # This way, we make sure that both modules sometimes play "X" (start player)
    # and sometimes "O" (player to move 2nd).
    return "main" if ((hash(episode.id_) % 2 == 0) != (agent_id == 'X')) else "main_v0"

config = (
    CMAPPOConfig()
    .environment(TicTacToe, env_config={})
    .callbacks( # set up our league
        functools.partial(SelfPlayCallback,
            win_rate_threshold=win_rate_threshold,
        )
    )
    .env_runners(
        num_env_runners=0,
        num_envs_per_env_runner=1,
        batch_mode="complete_episodes", # use_logits needs this to work.
    )
    .evaluation(
        custom_evaluation_function=custom_eval_function,
        evaluation_interval=10, # every K training steps
        evaluation_duration=100,
        evaluation_config={
            'heuristics': heuristics,
            'ensemble': ['main'], # policies to use when calculating winrates
        }
    )
    .multi_agent(
          policies=['main','main_v0']+heuristics,
          policy_mapping_fn=agent_to_module_mapping_fn,
          # Only the learned policy should be trained.
          policies_to_train=['main'],
      )
    .training(
        learner_class=CMAPPOTorchLearner,
        lr=1e-4,
    )
    .rl_module(
        rl_module_spec=MultiRLModuleSpec(
            rl_module_specs=specs
        ),
    )
)

algo = config.build_algo()

Example of the issue:

Iter=87:
Matchups: {('main', 'main_v1'): 26189, ('main', 'main_v2'): 4518, ('main', 'main_v3'): 2686}
Opponent 1: win-rate=0.77 loss-rate=0.03 ... 
Opponent 2: win-rate=0.19 loss-rate=0.00 ... 
Opponent 3: win-rate=0.15 loss-rate=0.00 ... adding new opponent to the mix (main_v4).
Updating ATM fn: [0.0637213  0.01766564 0.01766564 0.90094743]
iter=87 R=[('O', '-0.17'), ('X', '0.17')]

Iter=88:
Matchups: {('main', 'main_v1'): 26236, ('main', 'main_v2'): 4527, ('main', 'main_v3'): 2694, ('main', 'main_v4'): 570}
Opponent 1: win-rate=0.75 loss-rate=0.00 ... 
Opponent 2: win-rate=0.07 loss-rate=0.00 ... 
Opponent 3: win-rate=0.16 loss-rate=0.13 ... 
Opponent 4: win-rate=0.93 loss-rate=0.02 ... just added new agent; allowing an epoch to propagate.
Updating ATM fn: [0.07095589 0.0501293  0.72382283 0.15509198]
iter=88 R=[('O', '-0.07'), ('X', '0.07')]

Iter=89:
Matchups: {('main', 'main_v1'): 26279, ('main', 'main_v2'): 4549, ('main', 'main_v3'): 3058, ('main', 'main_v4'): 661}
Opponent 1: win-rate=0.84 loss-rate=0.09 ... 
Opponent 2: win-rate=0.17 loss-rate=0.00 ... 
Opponent 3: win-rate=0.28 loss-rate=0.00 ... 
Opponent 4: win-rate=0.20 loss-rate=0.38 ... not good enough; will keep learning ...
Updating ATM fn: [0.18971604 0.01995395 0.01995395 0.77037606]
iter=89 R=[('O', '-0.30'), ('X', '0.30')]

As seen above, the win rate against opponent 4 is reported as 93% in the epoch right after it's instantiated, then drops to a far more plausible value in the epochs that follow. I do adjust the probability weights to (generally) favor the new opponent, but reweighting alone shouldn't cause a swing this drastic, especially since later snapshots are good enough at defensive play to force a draw the majority of the time.
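
In case it's useful for triage, the next diagnostic I plan to run is sketched below (helper and attribute names are mine, and I'm assuming get_state() returns a nested dict of arrays/tensors like the modules above do): stash the snapshot's state at the moment it's added, then ask every training EnvRunner whether its copy of the new module actually matches that snapshot yet.

import numpy as np
import tree  # dm-tree, ships with Ray

def runner_copies_match_snapshot(algorithm, module_id, snapshot_state):
    """Compare each training EnvRunner's copy of `module_id` to the stashed snapshot."""
    expected = tree.flatten(snapshot_state)

    def _compare(env_runner):
        if module_id not in env_runner.module:
            return "missing"  # the module hasn't even been added on this runner yet
        actual = tree.flatten(env_runner.module[module_id].get_state())
        same = len(actual) == len(expected) and all(
            np.array_equal(np.asarray(a), np.asarray(e))
            for a, e in zip(actual, expected)
        )
        return "match" if same else "different weights"

    return algorithm.env_runner_group.foreach_env_runner(_compare)

# In on_train_result, right after algorithm.add_module(...) / algorithm.set_state(...):
#     self._pending_check = (new_module_id, main_module.get_state())
# and at the top of the next on_train_result:
#     if getattr(self, "_pending_check", None):
#         print(runner_copies_match_snapshot(algorithm, *self._pending_check))
#         self._pending_check = None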

Also, my method for reweighting the mapping function each epoch (unfreezing the config, rebuilding the closure, and pushing it to every env runner) feels a little hacky, but I wasn't able to find a neater way to do it. I'd appreciate any advice on that front if there's a better approach; one alternative I've been considering is sketched below.
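
The alternative (untested, names are mine) would be to install a single mapping fn up front that samples from a small mutable opponent pool, and then only push updated ids/probabilities to the env runners each epoch instead of re-installing a new closure. With num_env_runners=0 as in my config this is trivially consistent; with remote runners the pool would need to live somewhere each runner process can actually mutate, so treat it as a sketch.

import numpy as np

# Mutable opponent pool; updated in place instead of rebuilding the mapping fn.
OPPONENT_POOL = {"ids": ["main_v0"], "probs": np.array([1.0])}

def agent_to_module_mapping_fn(agent_id, episode, **kwargs):
    opponent = str(np.random.choice(OPPONENT_POOL["ids"], p=OPPONENT_POOL["probs"]))
    return "main" if (hash(episode.id_) % 2 == 0) != (agent_id == 'X') else opponent

def push_opponent_pool(algorithm, ids, probs):
    # Only the distribution changes; the mapping fn itself stays installed as-is.
    def _update(env_runner):
        OPPONENT_POOL["ids"] = list(ids)
        OPPONENT_POOL["probs"] = np.asarray(probs, dtype=float)
    algorithm.env_runner_group.foreach_env_runner(_update)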

TL;DR: The env runner appears to be reporting that it deployed one module as the main policy’s opponent while actually having deployed a different one.