Resuming/extending rllib tune experiments

So, I have custom PPOTrainer / PPOConfig classes and have implemented `__getstate__` and `__setstate__` methods for them.

import logging
import numpy as np
from collections import OrderedDict
from typing import List, Optional, Type, Union, Dict
import math
import time

from ray.util.debug import log_once
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.execution.rollout_ops import (
from ray.rllib.execution.train_ops import (
from ray.rllib.utils.annotations import ExperimentalAPI
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import (

from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LEARNER_STATS_KEY
from ray.rllib.utils.typing import AlgorithmConfigDict, ResultDict
from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
from ray.rllib.utils.metrics import (

from ray.rllib.algorithms.ppo.ppo import PPO, PPOConfig
from ray.rllib.utils.annotations import override

from rollout_wrapper import CRW

import torch

logger = logging.getLogger(__name__)

class NoveltyRewardPPOConfig(PPOConfig):
    """PPO config extended with the novelty / meta-reward hyperparameters.

    The extra settings live as plain attributes on the config. ``to_dict``
    exports them and ``update_from_dict`` re-imports them under the keys
    listed in ``_EXTRA_KEYS``.
    """

    # (config-dict key, attribute name) pairs. Only "random_reward" differs:
    # it is stored on the instance as ``return_noise``.
    _EXTRA_KEYS = (
        ("k", "k"),
        ("meta_reward", "meta_reward"),
        ("s", "s"),
        ("win_rate_threshold", "win_rate_threshold"),
        ("mc_threshold", "mc_threshold"),
        ("novelty_threshold", "novelty_threshold"),
        ("random_reward", "return_noise"),
        ("dominant_strategy", "dominant_strategy"),
    )

    def __init__(self, algo_class=None):
        # Initialise the inherited PPO settings first; without this call the
        # base-class attributes that ``to_dict`` relies on are never created.
        super().__init__(algo_class=algo_class or NoveltyRewardPPO)

        # knn to calculate meta-reward against
        self.k = 5
        # whether to use the meta-reward or not
        self.meta_reward = True
        # hyperparameters for the self-play callback
        # how often to snapshot the current policy
        self.s = 1000
        # when applicable, above what win-rate threshold should we snapshot the policy
        self.win_rate_threshold = 0.95
        # above what win-rate threshold should we switch to the meta-reward
        self.mc_threshold = 0.25
        # threshold for when a current profile / policy should be saved into the respective buffers
        self.novelty_threshold = 0.15
        # flag to determine if you return the novelty reward or a random number
        self.return_noise = False
        # get the dominant strategy
        self.dominant_strategy = []

    def to_dict(self):
        """Return the base config dict plus the novelty settings.

        ``return_noise`` is exported under the key ``"random_reward"``.
        """
        config = super().to_dict()
        for key, attr in self._EXTRA_KEYS:
            config[key] = getattr(self, attr)
        return config

    def hparams(self, *,
                knn: Optional[int] = 5,
                meta_reward: Optional[bool] = True,
                snapshot_timer: Optional[int] = 25,
                novelty_threshold: Optional[float] = 0.5,
                win_rate_threshold: Optional[float] = 0.95,
                mc_threshold: Optional[float] = 0.8,
                random_reward: Optional[bool] = False,
                dominant_strategy: Optional[tuple] = (1, 2, 6)) -> "NoveltyRewardPPOConfig":
        """Set the novelty hyperparameters in place and return ``self`` for chaining.

        Note: every argument has a default, so an argument that is left out is
        reset to the default of this signature (which differs from the
        ``__init__`` defaults for several fields) instead of keeping its
        current value.

        Args:
            knn: stored as ``k``.
            meta_reward: stored as ``meta_reward``.
            snapshot_timer: stored as ``s``.
            novelty_threshold: stored as ``novelty_threshold``.
            win_rate_threshold: stored as ``win_rate_threshold``.
            mc_threshold: stored as ``mc_threshold``.
            random_reward: stored as ``return_noise``.
            dominant_strategy: copied into a list; ``None`` gives an empty list.

        Returns:
            This config object (not a copy).
        """
        self.k = knn
        self.meta_reward = meta_reward
        self.s = snapshot_timer
        self.win_rate_threshold = win_rate_threshold
        self.mc_threshold = mc_threshold
        self.novelty_threshold = novelty_threshold
        self.return_noise = random_reward
        self.dominant_strategy = [] if dominant_strategy is None else list(dominant_strategy)
        return self

    def update_from_dict(self, config_dict) -> "AlgorithmConfig":
        """Apply ``config_dict`` to this config and return the result of the base update.

        The novelty keys are consumed here and mapped onto their attributes
        (``"random_reward"`` -> ``return_noise``); everything else is passed to
        the parent implementation. The caller's dict is left untouched.
        """
        # Work on a shallow copy so popping our keys does not mutate the caller's dict.
        remaining = dict(config_dict)
        for key, attr in self._EXTRA_KEYS:
            if key in remaining:
                setattr(self, attr, remaining.pop(key))
        return super().update_from_dict(remaining)

class NoveltyRewardPPO(PPO):
    @classmethod
    def get_default_config(cls) -> AlgorithmConfig:
        """Return a fresh ``NoveltyRewardPPOConfig`` as this algorithm's default config.

        Declared as a classmethod so it can be called on the class itself
        (``NoveltyRewardPPO.get_default_config()``) as well as on an instance.
        """
        return NoveltyRewardPPOConfig()

    def __init__(self, config=None, env=None, logger_creator=None):
        # todo make this a real replay buffer object to take
        #  advantage of things like prioritized sampling to determine
        #  which samples to use for meta reward-shaping / learning
        self.reward_buffer = OrderedDict()
        self.passed_mc = False
        self.meta_reward_value = 0
        self.n_opponents = 0
        self.learner_key = 'main0' if not config['env_config'].get('name', False) == 'connect_four' else 'main'
        # configs passed in through the custom_model_config
        # cmc = config['multiagent']['policies'][config['env_config']["players_ids"][1]][3]
        _config = config.get('model', {}).get('custom_model_config', {})
        # _config = config
        # Assume passed a customPPOConfig object that contains these extra parameters
        # knn to calculate meta-reward against
        self.k = _config.get("k", 5)
        # whether to use the meta-reward or not
        self.meta_reward = _config.get("meta_reward", True)
        # hyperparameters for the self-play callback
        # how often to snapshot the current policy
        self.s = _config.get("s", 500)
        # when applicable, above what win-rate threshold should we snapshot the policy
        self.win_rate_threshold = _config.get("win_rate_threshold", 0.95)
        # above what win-rate threshold should we switch to the meta-reward
        self.mc_threshold = _config.get("mc_threshold", 0.8)
        # threshold for when a current profile / policy should be saved into the respective buffers
        self.novelty_threshold = _config.get("novelty_threshold", 0.5)
        # if the novelty reward is greater than the threshold, then we save the current profile and policy
        self.save_policy_trigger = False
        # flag to determine if you return the novelty reward or a random number
        self.return_noise = _config.get('random_reward', False)
        # get the dominant strategy
        self.dominant_strategy = _config.get('dominant_strategy', [])

        # at the end so that the self-play callback has access to these attributes above
        super().__init__(config, env, logger_creator)

        # is the current policy a dual-path policy?
        self.dual_path = False
        self.exploit = False
        if env is not None:
            obs_dim = env.observation_space.shape[0]
            if obs_dim == 127:
                self.dual_path = True
            env = self.env_creator(config['env_config'])
            obs_dim = env.observation_space.shape[0]
            if obs_dim == 127:
                self.dual_path = True

        # initialize a worker to compute rewards on demand
        self.policy_config = {'model': {'custom_model': _config.get('network', 'mlp')}}
        # custom rollout worker: band-aid fix for rllib not
        # letting me do e.g., algorithm.evaluate(p1, p2)
        self.crw = CRW(self.env_creator,

    class BufferElement:
        BufferElement is a class that represents a single element of the buffer.
        It is used to compute the distance between the current reward and the
        rewards in the full buffer.

        Each dimension of the reward is represented by a single value that is the
        reward of the agent that is playing in that dimension.

        def __init__(self, reward_dict, mask, neighbors=10):
            """Store one reward profile together with its comparison mask.

            Args:
                reward_dict: an ``OrderedDict`` of rewards; its values, in
                    insertion order, become the profile vector ``values``.
                mask: kept as given; other methods use it to index ``values``.
                neighbors: kept as ``k``.
            """
            assert isinstance(reward_dict, OrderedDict)
            self.mask = mask
            self.k = neighbors
            self.reward_dict = reward_dict
            self.reward_keys = self.reward_dict.keys()
            # [r_1, ..., r_n]
            self.values = np.array([*self.reward_dict.values()])

        def add(self, k, v):
            """Insert or overwrite one reward entry and refresh the derived fields.

            ``reward_keys`` and ``values`` are rebuilt from the dict, and a
            trailing ``False`` is appended to ``mask`` on every call (also when
            ``k`` was already present).
            """
            rewards = self.reward_dict
            rewards[k] = v
            self.reward_keys = rewards.keys()
            self.values = np.array(list(rewards.values()))
            self.mask = np.append(self.mask, False)

        def __sub__(self, other):
            """Score this profile against a buffer of stored profiles.

            Args:
                other: expected to be an ``OrderedDict`` whose values are
                    ``BufferElement`` objects.

            Returns:
                A ``(total_reward, credited_rewards)`` pair. For a non-empty
                ``OrderedDict`` both come from ``_assign_credit``. Otherwise
                (empty buffer, or any other type) the total is ``1`` and every
                reward key is credited ``1``.
            """
            if isinstance(other, OrderedDict) and len(other) > 0:
                # One row per stored element, restricted to the dimensions our
                # mask selects. A single-element buffer yields a 1-row matrix.
                buffer_values = np.array([o.values[self.mask] for o in other.values()])
                total_reward, credited_rewards = self._assign_credit(buffer_values)
            else:
                # if len(other) == 0: nothing to compare against yet
                total_reward = 1
                credited_rewards = [total_reward for _ in range(len(self.reward_keys))]

            return total_reward, credited_rewards

        def _assign_credit(self, masked_profile_matrix):
            # let's do some credit assignment!
            profile_matrix = torch.tensor(masked_profile_matrix, requires_grad=True)
            current_profile = torch.tensor(self.values[self.mask], requires_grad=True)
            torch_mask = torch.tensor(self.mask, requires_grad=False)

            difference = current_profile - profile_matrix
            dists = torch.norm(difference, dim=1)
            sorted_dists, sorted_indices = torch.sort(dists)
            important_indices = sorted_indices[:self.k]

            indicator = torch.zeros(profile_matrix.shape[0])
            indicator[important_indices] = 1

            total_reward = torch.sum(dists * indicator) / torch.sum(indicator)
            # normalized_total_reward = total_reward / torch_mask.sum()
            normalized_total_reward = total_reward / (2 * torch.sqrt(torch.tensor(torch_mask.sum(),

            # get the gradient with respect to the current profile vector

            # save total reward to python value
            # divide by the number of agents involved in the novelty calculation
            # this should normalize the score so it doesn't always go up?
            total_reward_python = normalized_total_reward.item()

            # push the grad through the softmax
            current_profile_grad = current_profile.grad
            # ??? should I split the meta-reward into the individual rewards?
            # by the softmax of the gradient? Or just provide the gradient?

            # version 0: just provide the meta-reward
            reward = [total_reward_python for _ in range(len(current_profile_grad))]

            # version 1: split the meta-reward into the individual rewards
            # by the softmax of the gradient
            # current_profile_grad_01 = torch.softmax(current_profile_grad, dim=0)
            # grad_reward = current_profile_grad_01 * total_reward
            # # get the values back out of the tensor
            # grad_reward = grad_reward.detach().numpy()

            # version 2: just provide the gradient
            # reward = current_profile_grad.detach().numpy()

            # version 3: weight the total reward by the gradient
            # reward = total_reward_python * current_profile_grad.detach().numpy()

            # version 4: weight the total reward by the abs marginal contribution of each agent
            # get the nearest neighbor differences and
            # compute the mean contribution of each dimension
            # mean_contribution_per_snapshot = torch.mean(difference[sorted_indices][:self.k], dim=0)
            # # split the scalar value reward into the individual contributions
            # percent_contrib = (torch.abs(mean_contribution_per_snapshot) /
            #                    torch.sum(torch.abs(mean_contribution_per_snapshot))).detach().numpy()
            # reward = total_reward_python * percent_contrib

            reward = self._add_in_zeros(reward)
            return total_reward_python, reward

        def _add_in_zeros(self, reward):
            """Expand a masked reward vector to one entry per mask position.

            Args:
                reward: sequence with one value per truthy entry of ``self.mask``,
                    in mask order.

            Returns:
                A numpy array as long as ``self.mask``: positions where the mask
                is truthy take the next value from ``reward``; the rest are 0.
            """
            padded = []
            r_index = 0
            for m in self.mask:
                if m:
                    # consume the next masked-in reward
                    padded.append(reward[r_index])
                    r_index += 1
                else:
                    # masked-out dimension: no credit
                    padded.append(0)
            return np.array(padded)

        def __getstate__(self):
            """Return a plain-Python snapshot of this element.

            The reward keys, reward values and mask are converted to lists;
            ``k`` is stored unchanged.
            """
            # make dict from ordered dict
            return {
                'reward_keys': list(self.reward_keys),
                'reward_values': self.values.tolist(),
                'mask': self.mask.tolist(),
                'k': self.k,
            }

        def __setstate__(self, state):
            """Rebuild the element from a state dict.

            ``state`` must provide 'reward_keys', 'reward_values', 'mask' and
            'k'. The keys and values are zipped back into an ``OrderedDict``,
            and ``values`` / ``mask`` are turned into numpy arrays again.
            """
            self.reward_dict = OrderedDict(zip(state['reward_keys'], state['reward_values']))
            self.reward_keys = self.reward_dict.keys()
            self.values = np.array([*self.reward_dict.values()])
            self.mask = np.array(state['mask'])
            self.k = state['k']

        def __repr__(self):
            """Return a debug string showing the reward dict, the mask and ``k``."""
            return f"BufferElement({self.reward_dict}, {self.mask}, {self.k})"

    def _update_batch_for_trainable_with_meta_reward(self, train_batch: "MultiAgentBatch"):
        This function updates the batch for the main0 policy with the meta reward.
        The meta reward is the distance between the current reward and the rewards
        in the buffer.
        policy_ids_in_batch = list(train_batch.policy_batches.keys())
        snap_ids_in_batch = [pid for pid in policy_ids_in_batch if not pid == self.learner_key]

        main_agent_state = self.get_policy(self.learner_key).get_state()

        main_batch = train_batch.policy_batches[self.learner_key]
        snap_batches = {pid: train_batch.policy_batches[pid] for pid in snap_ids_in_batch}

        # split batches by episode
        episodic_main_batches = main_batch.split_by_episode()
        episodic_snap_batches = {pid: snap_batches[pid].split_by_episode() for pid in snap_ids_in_batch}

        # if self.dual_path:
        #     # split apart batches used for optimization
        #     opt_batches = SampleBatch()
        #     nov_batches = SampleBatch()
        #     for ep_batch in episodic_main_batches:
        #         # exploitation episode
        #         if ep_batch[ep_batch.OBS][0][-1] == 0:
        #             opt_batches = opt_batches.concat(ep_batch)
        #         # exploration episode
        #         else:
        #             nov_batches = nov_batches.concat(ep_batch)
        #     episodic_main_batches = nov_batches.split_by_episode()

        # calculate average score of each snapshop policy using total reward
        snap_scores = {pid: np.mean([ep_batch[ep_batch.REWARDS].sum() for ep_batch in episodic_snap_batches[pid]])
                       for pid in snap_ids_in_batch}

        # fill in missing values in the buffer sample
        # this uses current main0
        # if this wasn't in the training batch, mask it out of the meta_reward calculation
        # but still fill the value in for use later!!
        comparison_mask = []
        for n in list(np.arange(1, self.n_opponents + 1)):
            pid = 'random_r0' if n == 0 else f'main_v{n}'
            if pid not in snap_scores:
                opponent_pol = self.get_policy(pid)
                if opponent_pol is not None:
                    opponent_state = opponent_pol.get_state()
                    opponent_state = self.get_policy('random').get_state()
                rewards = self.crw.get_discounted_rewards(opponent_state, main_agent_state)
                    snap_scores[pid] = rewards['random_r0']
                except KeyError:
                    # 0 = random
                    # 1 = main
                    snap_scores[pid] = rewards[0]
                    snap_scores[pid] = 0

        comparison_mask = np.array(comparison_mask)

        # sort snapshot rewards by name e.g., main_v2
        sorted_snap_scores = OrderedDict()
        for k in sorted(snap_scores, key=lambda x: int(x.split('_')[1][1:])):
            sorted_snap_scores[k] = snap_scores[k]

        # calculate distance between snapshot policy rewards and buffer examples
        snap_sample_reward = self.BufferElement(sorted_snap_scores, comparison_mask, neighbors=self.k)
        total_reward, credited_rewards = snap_sample_reward - self.reward_buffer

        # save the meta reward value for logging
        self.meta_reward_value = total_reward

        # add to buffer if total reward is greater than novelty threshold
        if total_reward >= self.novelty_threshold:
            self.reward_buffer[self.iteration] = snap_sample_reward
            self.save_policy_trigger = True

        main_policy = self.get_policy(self.learner_key)

        # set rewards in main0 batch
        # give each batch the percent of the total reward that the opponent contributed
        ordered_keys = sorted_snap_scores.keys()
        recombined_main_batch = SampleBatch()
        for ep_batch in episodic_main_batches:
            episode_id = ep_batch[ep_batch.EPS_ID][0]
            # find which opponent was used in this episode
            opponent_id = None
            for pid in snap_ids_in_batch:
                for ep in episodic_snap_batches[pid]:
                    if ep[ep.EPS_ID][0] == episode_id:
                        opponent_id = pid
                if opponent_id is not None:
            if opponent_id is None or opponent_id == self.learner_key:
                # what the fuck, this shouldn't be possible

            reward_index = list(ordered_keys).index(opponent_id)
            ep_batch[ep_batch.REWARDS] = np.zeros_like(ep_batch[ep_batch.REWARDS])
            if self.return_noise:
                ep_batch[ep_batch.REWARDS][-1] = torch.randn(1)
                ep_batch[ep_batch.REWARDS][-1] = credited_rewards[reward_index]

            # recalculate advantages using new rewards
            #  use the main policy since that's the batch we're updating
            updated_batch = compute_gae_for_sample_batch(policy=main_policy,

            recombined_main_batch = recombined_main_batch.concat(updated_batch)

        # if self.dual_path:
        #     # add optimization batches back in
        #     recombined_main_batch = recombined_main_batch.concat(opt_batches)
        # place updated version back into the train_batch we were sent
        train_batch.policy_batches[self.learner_key] = recombined_main_batch
        return train_batch

    def used_dom_to_win(self, train_batch: "MultiAgentBatch") -> float:
        """Return the fraction of learner episodes that contain the dominant strategy.

        An episode counts when ``self.dominant_strategy`` appears, in order and
        contiguously, somewhere in its action sequence. Each episode counts at
        most once, so the result is always within [0, 1].

        NOTE(review): despite the name, nothing here checks whether the
        episode was actually won - confirm that is acceptable for the caller.

        Args:
            train_batch: multi-agent batch; only the ``self.learner_key`` policy
                batch is inspected, split by episode.

        Returns:
            ``matching_episodes / total_episodes``, or ``0.0`` when the batch
            has no episodes or no dominant strategy is configured.
        """
        main_batches = train_batch.policy_batches[self.learner_key].split_by_episode()
        if not main_batches:
            # avoid dividing by zero on an empty learner batch
            return 0.0

        pattern = np.asarray(self.dominant_strategy)
        width = len(pattern)
        if width == 0:
            # an empty strategy can never be "used"
            return 0.0

        dominant_strategy_wins = 0
        for batch in main_batches:
            actions = np.asarray(batch[batch.ACTIONS])
            # slide a window of the strategy's length across the episode's actions
            for i in range(len(actions) - width + 1):
                if np.all(actions[i:i + width] == pattern):
                    dominant_strategy_wins += 1
                    # one hit is enough: count each episode only once
                    break

        return dominant_strategy_wins / len(main_batches)

    def training_step(self) -> ResultDict:
        # Collect SampleBatches from sample workers until we have a full batch.
        if self.config.count_steps_by == "agent_steps":
            train_batch = synchronous_parallel_sample(
            train_batch = synchronous_parallel_sample(
                worker_set=self.workers, max_env_steps=self.config.train_batch_size

        train_batch = train_batch.as_multi_agent()

        self.percent_dom_used = self.used_dom_to_win(train_batch)
        self.save_policy_trigger = False
        # update batch for main0 policy with meta reward
        # from the other snapshot policies
        # todo put in an if statement to only do this if:
        #   2. we have added a snapshot policy to the self-play process
        if self.passed_mc and self.meta_reward:
            train_batch = self._update_batch_for_trainable_with_meta_reward(train_batch)
            self.meta_reward_value = 0

        self._counters[NUM_AGENT_STEPS_SAMPLED] += train_batch.agent_steps()
        self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps()

        # Standardize advantages
        train_batch = standardize_fields(train_batch, ["advantages"])
        # Train
        if self.config._enable_rl_trainer_api:
            train_results = self.trainer_runner.update(train_batch)
        elif self.config.simple_optimizer:
            train_results = train_one_step(self, train_batch)
            train_results = multi_gpu_train_one_step(self, train_batch)

        policies_to_update = list(train_results.keys())

        global_vars = {
            "timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
            "num_grad_updates_per_policy": {
                pid: self.workers.local_worker().policy_map[pid].num_grad_updates
                for pid in policies_to_update

        # Update weights - after learning on the local worker - on all remote
        # workers.
        if self.workers.num_remote_workers() > 0:
            with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
                from_worker = None
                if self.config._enable_rl_trainer_api:
                    from_worker = self.trainer_runner

        if self.config._enable_rl_trainer_api:
            kl_dict = {
                pid: pinfo[LEARNER_STATS_KEY].get("kl")
                for pid, pinfo in train_results.items()
            # triggers a special update method on RLOptimizer to update the KL values.

            return train_results

        # For each policy: Update KL scale and warn about possible issues
        for policy_id, policy_info in train_results.items():
            # Update KL loss with dynamic scaling
            # for each (possibly multiagent) policy we are training
            kl_divergence = policy_info[LEARNER_STATS_KEY].get("kl")

            # Warn about excessively high value function loss
            scaled_vf_loss = (
                    self.config.vf_loss_coeff * policy_info[LEARNER_STATS_KEY]["vf_loss"]
            policy_loss = policy_info[LEARNER_STATS_KEY]["policy_loss"]
            if (
                    and self.config.get("model", {}).get("vf_share_layers")
                    and scaled_vf_loss > 100
                    "The magnitude of your value function loss for policy: {} is "
                    "extremely large ({}) compared to the policy loss ({}). This "
                    "can prevent the policy from learning. Consider scaling down "
                    "the VF loss by reducing vf_loss_coeff, or disabling "
                    "vf_share_layers.".format(policy_id, scaled_vf_loss, policy_loss)
            # Warn about bad clipping configs.
            mean_reward = train_batch.policy_batches[policy_id]["rewards"].mean()
            if (
                    and mean_reward > self.config.vf_clip_param
                self.warned_vf_clip = True
                    f"The mean reward returned from the environment is {mean_reward}"
                    f" but the vf_clip_param is set to {self.config['vf_clip_param']}."
                    f" Consider increasing it for policy: {policy_id} to improve"
                    " value function convergence."

        # Update global vars on local worker as well.

        return train_results

    # @override(PPO)
    def __getstate__(self) -> Dict:
        """Returns current state of Algorithm, sufficient to restore it from scratch.

        On top of the parent's state dict this adds the reward buffer, the
        novelty hyperparameters and the self-play bookkeeping fields, using
        the same keys that ``__setstate__`` pops.

        Returns:
            The current state dict of this Algorithm, which can be used to sufficiently
            restore the algorithm from scratch without any other information.
        """
        state = super().__getstate__()

        # save the reward buffer
        state['reward_buffer'] = self.reward_buffer

        # novelty / meta-reward hyperparameters
        state['k'] = self.k
        state['meta_reward'] = self.meta_reward
        state['s'] = self.s
        state['win_rate_threshold'] = self.win_rate_threshold
        state['mc_threshold'] = self.mc_threshold
        state['novelty_threshold'] = self.novelty_threshold
        # stored under the config key name, not the attribute name
        state['random_reward'] = self.return_noise
        state['dominant_strategy'] = self.dominant_strategy

        # policy description and mode flags
        state['policy_config'] = self.policy_config
        state['dual_path'] = self.dual_path
        state['exploit'] = self.exploit

        # self-play bookkeeping
        state['passed_mc'] = self.passed_mc
        state['n_opponents'] = self.n_opponents
        state['learner_key'] = self.learner_key
        return state

    # @override(PPO)
    def __setstate__(self, state) -> None:
        """Sets the algorithm to the provided state.

            state: The state dict to restore this Algorithm instance to. `state` may
                have been returned by a call to an Algorithm's `__getstate__()` method.
        # TODO (sven): Validate that our config and the config in state are compatible.
        #  For example, the model architectures may differ.
        #  Also, what should the behavior be if e.g. some training parameter
        #  (e.g. lr) changed?

        # restore the reward buffer
        self.reward_buffer = state.pop('reward_buffer')
        self.k = state.pop('k')
        self.meta_reward = state.pop('meta_reward')
        self.s = state.pop('s')
        self.win_rate_threshold = state.pop('win_rate_threshold')
        self.mc_threshold = state.pop('mc_threshold')
        self.novelty_threshold = state.pop('novelty_threshold')
        self.return_noise = state.pop('random_reward')
        self.dominant_strategy = state.pop('dominant_strategy')

        self.policy_config = state.pop('policy_config')
        self.dual_path = state.pop('dual_path')
        self.exploit = state.pop('exploit')

        self.passed_mc = state.pop('passed_mc')

        self.n_opponents = state.pop('n_opponents')
        self.learner_key = state.pop('learner_key')


        self.crw = CRW(self.env_creator,

This should, in theory, allow me to resume the experiment and load from a checkpoint. However, tune keeps throwing an error, which I eventually figured out is resolved only if the restore path points to the results folder rather than to a checkpoint inside the experiment folder.

tuner = tune.Tuner.restore("/mnt/e/PycharmProjects/results/sp_c4/", # NoveltyRewardPPO_LAMBDA_nov_c4_march1000_30iters_mc0.25_novths0.15_dualnets_END_29c26_00000_0_2023-09-06_11-42-42/checkpoint_000015
                           trainable=algo_class, resume_unfinished=True)
 results =

For completeness, here’s my tune call:

results = tune.Tuner(
                local_dir=os.path.join('..', '..', 'results'),
                        "training_iteration": "iter",
                        "timesteps_total": "ts",
                        "policy_reward_mean/main": "rew",
                        "win_rate": "winrate",
                        "league_size": "leaguesize",
                        "batch_meta_reward": "nov",
                    checkpoint_frequency=args.iters // 10,
                    WandbLoggerCallback(project="mapo", entity="aadharna",

Why is this? This seems like a really bad design, no? Now, if I run 100 experiments and decide I want to continue experiment number 14 and double the length, I can’t do so because tune will simply pick the latest one in the results folder to continue.

Similarly, how could I alter the arguments so that, if it previously trained for 1000 iterations, I can now continue for another 1000 iterations? If I simply attempt to restore that experiment, it'll keep stop_iters == 1000 and then immediately kill the run.

I guess the right thing to do is to just use the algorithm API, but then I lose out on the wandb callback.

Okay. So, I rewrote the code to use the algorithm API and am manually handling the wandb syncing / artifacting. But once I did that and then saved some info that lived in the self-play callback into the algorithm object, then I was able to restore the algorithm using the full checkpoint directory and pass in some updated arguments that would allow the experiment to continue past its termination conditions.

I still think this should be solvable while using the tune API, as I now have to track a bunch of metrics myself, but it works.

1 Like