Setting `policies_to_train` on a per-episode granularity

High: Completely blocks me.

2. Environment:

  • Ray version: 2.49.1
  • Python version: 3.12.11
  • OS: Ubuntu

Hello, all. I’ve been working on a high-fidelity implementation of AlphaStar’s league algorithm within RLlib, and I think I’ve gotten pretty close. Along the way I reconciled the differences between the paper and the released pseudocode, and accounted for the eccentricities of my target environment, whose draw rate is high enough that I had to change the PFSP implementation a bit to compensate. However, I think I’ve hit a roadblock in terms of infrastructure:

AlphaStar notably differentiates the ‘student’ agent from the ‘teacher’ agent when assigning matches. The main exploiter targets the main agent the majority of the time, but the results of those matches are used only to optimize the exploiter, whereas main is updated only when it specifically seeks out a challenging exploiter. This ensures that poorly performing exploiters do not make up a disproportionate share of the main agent’s training data. Without this measure, training collapses, as seen in my current results (shown below).[1]

As the exploiter falls behind with no means of catching up, main is increasingly incentivized to pursue suboptimal strategies that exploit the exploiter’s failings but leave main vulnerable to stronger opponents.

Ideally, I would like to be able to specify in my `agent_to_module_mapping_fn` that `agent_to_train` should update its weights after the episode, whereas its opponent should not be updated, even if it is included in `policies_to_train`. I expect that my solution will have to be a little hacky to handle such an unconventional ask, but I feel like there must be a preferred way to do this that is more elegant than the alternatives. Any advice would be greatly appreciated. Here is my current mapping function:

import numpy as np

# Note: agent_names, wr, just_added and pfsp are defined elsewhere in my script.
def atm_fn(agent_id, episode, **kwargs):
    eid = hash(episode.id_)
    rng = np.random.default_rng(seed=abs(eid))
    r1 = rng.random()
    # The learning agent this episode is 'for', distributed evenly b/t agents
    agent_to_train = "main" if (r1 < 1/3) else "main_exploiter" if (r1 < 2/3) else "league_exploiter"
    # The parity of the episode hash decides whether agent 0 or agent 1 is the learner
    if (eid % 2 == 0) != (agent_id == 1):
        return agent_to_train
    # Select an opponent.
    if agent_to_train == "main":  # opponents for main
        rand = rng.random()
        if rand < .35:  # 35% self play
            return "main"
        elif rand < .85:  # 50% PFSP (any other agent)
            valid_options = filter(lambda s: s != 'main', agent_names)
        else:  # 15% any agent with > 70% WR against main, or SP if none
            valid_options = list(filter(lambda s: wr[s]['main'] > 0.7, agent_names))
            if len(valid_options) == 0:
                return "main"
    elif agent_to_train == "main_exploiter":  # opponents for ME
        wr_thresh_me = wr["main"]["main_exploiter"] / 9  # w/(w+l) >= 10%
        if wr["main_exploiter"]["main"] > wr_thresh_me and rng.random() > .5:
            return "main"  # 50% play versus main, if it's doing well
        # Otherwise PFSP against main's past copies
        valid_options = filter(lambda s: s[:6] == 'main_v', agent_names)
    else:  # opponents for LE (all past players; fig 1)
        valid_options = filter(lambda s: '_v' in s, agent_names)
    # Run PFSP on our options
    valid_options = filter(lambda s: s not in just_added, valid_options)
    return pfsp(agent_to_train, list(valid_options), wr, rng)

  1. There is a remedial mechanism for main exploiters that fall behind, but even with this implemented, exploiters require experience against main to become useful, and the process of getting this experience is detrimental to main’s learning.

I ended up using the episode ID to track the teacher and student agents, and I set a loss mask over the teacher agent’s episode results in the Connector linking the sampler and the learner. My results do look cleaner with this change, and I’ve included my code for others who might want to do the same. I’m thinking of releasing my replication on GitHub in a while.

Maybe I’ll run it on Tic Tac Toe and also the soccer game used for the existing league/self-play examples, with visualizations for each.
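
For readers unfamiliar with loss masks, here is a minimal sketch of the idea behind masking the teacher's rows: per-timestep losses whose mask is False are left out of the average, so they contribute nothing to the update. The helper `masked_mean` and the toy numbers are made up for illustration. This assumes the learner reduces its loss roughly like this and is not a copy of RLlib's implementation.

```python
def masked_mean(per_step_loss, loss_mask):
    """Average only the entries whose mask is True (hypothetical helper)."""
    kept = [loss for loss, keep in zip(per_step_loss, loss_mask) if keep]
    # With nothing left to learn from, report a zero loss.
    return sum(kept) / len(kept) if kept else 0.0

# Toy batch: the first three steps belong to the student, the last two to the teacher.
losses = [1.0, 2.0, 3.0, 10.0, 20.0]
mask = [True, True, True, False, False]

student_only = masked_mean(losses, mask)       # teacher rows ignored
unmasked = masked_mean(losses, [True] * 5)     # every row counted
print(student_only, unmasked)
```

In the toy batch the teacher's large losses dominate the unmasked average but vanish from the masked one, which is the effect wanted for episodes where the teacher should not learn.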

Results:

Match Generation Code
import numpy as np

# Get the agent that will be learning this episode, from the set of learning agents
def get_learning_agent(episode, policies_to_train):
  # Returns the agent ("X" or "O") and the ID of the 'student' policy.
  # Both depend only on the episode ID, so the post-processing step can recompute them.
  len_policies = len(policies_to_train)
  eid = hash(episode.id_) % (2*len_policies)
  agent_id = "X" if eid < len_policies else "O"
  policy_id = policies_to_train[eid%len_policies]
  return agent_id, policy_id
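
One thing to watch with `hash(episode.id_)`: Python salts `str` hashes per process unless `PYTHONHASHSEED` is fixed, and this function is called both when mapping agents and later when masking batches. Whether this bites in practice depends on how the sampling and learner processes are started, so treat the following as a precaution. It is a minimal sketch of a process-independent variant, assuming episode IDs are strings. `stable_episode_hash` and `get_learning_agent_stable` are hypothetical names, and the latter takes the ID string directly rather than an episode object.

```python
import hashlib

def stable_episode_hash(episode_id: str) -> int:
    """Map an episode ID string to a non-negative int, identically in every process."""
    digest = hashlib.sha256(episode_id.encode("utf-8")).digest()
    # Interpret the first 8 bytes as an unsigned 64-bit integer.
    return int.from_bytes(digest[:8], "big")

def get_learning_agent_stable(episode_id, policies_to_train):
    """Same selection rule as get_learning_agent, but with a stable hash."""
    n = len(policies_to_train)
    eid = stable_episode_hash(episode_id) % (2 * n)
    agent_id = "X" if eid < n else "O"
    return agent_id, policies_to_train[eid % n]

policies = ["main", "main_exploiter", "league_exploiter"]
agent, policy = get_learning_agent_stable("episode-0001", policies)
print(agent, policy)
```

Because the result is never negative, it can also be passed to `np.random.default_rng` as a seed without the `abs` call.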

def pfsp(agent, opponents, wr, rng):
  #weights = np.array([wr[agent][o] * wr[o][agent] for o in opponents])
  if ('exploiter' not in agent):
    # Main agents want to not-lose, and want to learn from opponents they lose to.
    weights = np.array([wr[o][agent] for o in opponents])
  else:
    # Exploiter agents want to win, and want to learn from opponents they don't win against
    weights = np.array([(1-wr[agent][o]) for o in opponents])
  wr_sum = weights.sum()
  if (wr_sum == 0):
    return rng.choice(opponents)
  return rng.choice(opponents, p=weights/wr_sum)
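
To make the weighting concrete, here is a small sketch that computes the sampling probabilities `pfsp` would use, without drawing a sample. `pfsp_probs` and the win-rate table are invented for illustration. `wr[a][b]` is read as the fraction of games `a` wins against `b`, which matches the lookups above.

```python
def pfsp_probs(agent, opponents, wr):
    """Return {opponent: probability} using the same weights as pfsp (illustrative)."""
    if "exploiter" not in agent:
        # Main agents: weight opponents by how often they beat us.
        weights = [wr[o][agent] for o in opponents]
    else:
        # Exploiters: weight opponents by how often we fail to beat them.
        weights = [1 - wr[agent][o] for o in opponents]
    total = sum(weights)
    if total == 0:
        # Fall back to a uniform choice, as pfsp does.
        return {o: 1 / len(opponents) for o in opponents}
    return {o: w / total for o, w in zip(opponents, weights)}

# Made-up win rates; draws account for the remainder of each pairing.
wr = {
    "main": {"main_v1": 0.6, "main_v2": 0.3},
    "main_v1": {"main": 0.1},
    "main_v2": {"main": 0.3},
}
probs = pfsp_probs("main", ["main_v1", "main_v2"], wr)
print(probs)
```

Here `main_v2`, which beats main three times as often as `main_v1`, ends up with three times the sampling probability.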

def create_atm_fn(policies_to_train, agent_names, wr, just_added):
  """
  Create an agent-to-module mapping function based on the released pseudocode.
   - Treat "win" as any favorable outcome: (win+draw)/all for main,
     win/(win+loss) for exploiters.
  policies_to_train is the variable in the config with the same name.
  agent_names is a list of all league agents' names.
  wr is a nested dictionary of win rates, indexed as wr[agent][opponent].
  just_added lists recently added agents, which are excluded from PFSP.
  """
  def atm_fn(agent_id, episode, **kwargs):
    # The learning agent this episode is 'for', distributed evenly b/t X and O.
    student_agent, student_policy = get_learning_agent(episode, policies_to_train)
    if agent_id == student_agent:
      return student_policy
    eid = hash(episode.id_)
    rng = np.random.default_rng(seed=abs(eid))
    # Select an opponent.
    if student_policy == "main":  # opponents for main
      rand = rng.random()
      if rand < .35:  # 35% self play
        return "main"
      elif rand < .85:  # 50% PFSP (any other agent)
        valid_options = filter(lambda s: s != 'main', agent_names)
      else:  # 15% any agent that beats main in > 70% of decisive games, SP otherwise
        # Adjust for draws: w > (.70/.30)*l is equivalent to w/(w+l) > 70%
        valid_options = list(filter(lambda s: wr[s]['main'] > (.70/.30)*wr['main'][s], agent_names))
        if len(valid_options) == 0:
          return "main"
    elif student_policy.startswith("main_exploiter"):  # opponents for ME
      # Look up this exploiter's own record, in case there is more than one exploiter
      wr_thresh_me = wr["main"][student_policy] / 9  # w/(w+l) > 10%
      if wr[student_policy]["main"] > wr_thresh_me:
        return "main"  # play versus main, if it's doing passably
      # Otherwise PFSP against main's past copies
      valid_options = filter(lambda s: s.startswith('main_v'), agent_names)
    else:  # opponents for LE (all past players; fig 1)
      valid_options = filter(lambda s: '_v' in s, agent_names)
    # Run PFSP on our options
    valid_options = filter(lambda s: s not in just_added, valid_options)
    return pfsp(student_policy, list(valid_options), wr, rng)
  return atm_fn
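
The two draw-adjusted thresholds in the mapping function are easier to check when written out. With `w` and `l` the win and loss fractions of whichever side is being tested (draws make up the rest), `w > (0.70/0.30) * l` is the same condition as `w / (w + l) > 0.70`, and `w > l / 9` is the same as `w / (w + l) > 0.10`. A quick sketch with invented numbers follows. `decisive_win_share` is a hypothetical helper and not part of the code above.

```python
def decisive_win_share(w, l):
    """Share of wins among decisive (non-drawn) games."""
    return w / (w + l)

# Invented record: 24% wins, 8% losses, 68% draws.
w, l = 0.24, 0.08

# 70% threshold, as used when picking strong opponents for main.
beats_70 = w > (0.70 / 0.30) * l
# 10% threshold, as used to decide whether a main exploiter may face main.
beats_10 = w > l / 9

print(decisive_win_share(w, l), beats_70, beats_10)
```

Writing the test in the multiplied-out form avoids a division by zero when a pairing has no decisive games yet.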

Post-Processing Code
    # Module-level imports this method relies on:
    #   from collections import defaultdict
    #   import torch
    #   from ray.rllib.core.columns import Columns
    def mask_teacher_batches(self, meps, batch):
      # Give every module's batch a loss mask; the default is to train on every row.
      for mid in batch:
        b_obs = batch[mid][Columns.OBS]
        if (Columns.LOSS_MASK not in batch[mid].keys()):
          batch[mid][Columns.LOSS_MASK] = torch.ones(
              b_obs.shape[0], dtype=torch.bool, device=b_obs.device)
      # Next unread row per module; assumes rows are laid out episode by episode, in meps order.
      start_indices = defaultdict(int)
      for mep in meps:
        student_agent, student_policy = get_learning_agent(mep, self.policies_to_train)
        x_ep, o_ep = mep.agent_episodes['X'], mep.agent_episodes['O']
        x_mid, o_mid = x_ep.module_id, o_ep.module_id
        x_l, o_l = len(x_ep), len(o_ep)
        x_s = start_indices[x_mid]
        start_indices[x_mid]+=x_l
        o_s = start_indices[o_mid]
        start_indices[o_mid]+=o_l
        # Mask out the teacher's rows so only the student learns from this episode.
        if (x_mid!=student_policy):
          batch[x_mid][Columns.LOSS_MASK][x_s:x_s+x_l] = False
        elif (o_mid!=student_policy):
          batch[o_mid][Columns.LOSS_MASK][o_s:o_s+o_l] = False
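
The offset bookkeeping in `mask_teacher_batches` can be tried out without RLlib. The sketch below is a simplified stand-in that uses plain lists and tuples instead of episodes and tensors. It assumes, as the method above does, that each module's rows appear in the batch in the same order as the episodes. It also masks every side whose module differs from the student's policy, which matches the `if`/`elif` above whenever one side is the student. `mask_teacher_rows` and the episode tuples are made up for illustration.

```python
from collections import defaultdict

def mask_teacher_rows(episodes, batch_sizes):
    """Build per-module boolean masks; False marks rows from the teacher side.

    episodes: list of (student_policy, [(module_id, n_steps), ...]) tuples,
              with one inner pair per agent in the episode.
    batch_sizes: {module_id: total number of rows for that module}.
    """
    masks = {mid: [True] * n for mid, n in batch_sizes.items()}
    start = defaultdict(int)  # next unread row per module
    for student_policy, sides in episodes:
        for module_id, n_steps in sides:
            s = start[module_id]
            start[module_id] += n_steps
            if module_id != student_policy:
                masks[module_id][s:s + n_steps] = [False] * n_steps
    return masks

episodes = [
    ("main", [("main", 3), ("main_exploiter", 2)]),            # exploiter is the teacher
    ("main_exploiter", [("main", 2), ("main_exploiter", 3)]),  # main is the teacher
]
masks = mask_teacher_rows(episodes, {"main": 5, "main_exploiter": 5})
print(masks)
```

Each module keeps only the rows from episodes in which it was the student, even though both modules appear in both episodes.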