I have custom PPO trainer / PPOConfig subclasses and have implemented __getstate__ and __setstate__ methods for them:
import logging
import numpy as np
from collections import OrderedDict
from typing import List, Optional, Type, Union, Dict
import math
import time
from ray.util.debug import log_once
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.execution.rollout_ops import (
standardize_fields,
)
from ray.rllib.execution.train_ops import (
train_one_step,
multi_gpu_train_one_step,
)
from ray.rllib.utils.annotations import ExperimentalAPI
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import (
Deprecated,
DEPRECATED_VALUE,
deprecation_warning,
)
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LEARNER_STATS_KEY
from ray.rllib.utils.typing import AlgorithmConfigDict, ResultDict
from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
from ray.rllib.utils.metrics import (
NUM_AGENT_STEPS_SAMPLED,
NUM_ENV_STEPS_SAMPLED,
SYNCH_WORKER_WEIGHTS_TIMER,
)
from ray.rllib.algorithms.ppo.ppo import PPO, PPOConfig
from rollout_wrapper import CRW
import torch
logger = logging.getLogger(__name__)
class NoveltyRewardPPOConfig(PPOConfig):
def __init__(self, algo_class=None):
super().__init__(algo_class=algo_class)
# knn to calculate meta-reward against
self.k = 5
# whether to use the meta-reward or not
self.meta_reward = True
# hyperparameters for the self-play callback
# how often to snapshot the current policy
self.s = 1000
# when applicable, above what win-rate threshold should we snapshot the policy
self.win_rate_threshold = 0.95
# above what win-rate threshold should we switch to the meta-reward
self.mc_threshold = 0.25
# threshold for when a current profile / policy should be saved into the respective buffers
self.novelty_threshold = 0.15
# flag to determine if you return the novelty reward or a random number
self.return_noise = False
# get the dominant strategy
self.dominant_strategy = []
def to_dict(self):
config = super().to_dict()
config["k"] = self.k
config["meta_reward"] = self.meta_reward
config["s"] = self.s
config["win_rate_threshold"] = self.win_rate_threshold
config["mc_threshold"] = self.mc_threshold
config['novelty_threshold'] = self.novelty_threshold
config['random_reward'] = self.return_noise
config['dominant_strategy'] = self.dominant_strategy
return config
def hparams(self, *,
knn: Optional[int] = 5,
meta_reward: Optional[bool] = True,
snapshot_timer: Optional[int] = 25,
novelty_threshold: Optional[float] = 0.5,
win_rate_threshold: Optional[float] = 0.95,
mc_threshold: Optional[float] = 0.8,
random_reward: Optional[bool] = False,
dominant_strategy: Optional[tuple] = (1, 2, 6)) -> "PPOConfig":
"""Returns a copy of this config with the given hyperparameters"""
self.k = knn
self.meta_reward = meta_reward
self.s = snapshot_timer
self.win_rate_threshold = win_rate_threshold
self.mc_threshold = mc_threshold
self.novelty_threshold = novelty_threshold
self.return_noise = random_reward
self.dominant_strategy = list(dominant_strategy)
return self
def update_from_dict(self, config_dict) -> "AlgorithmConfig":
for k in ["k", "meta_reward", "s", "win_rate_threshold", "mc_threshold", 'novelty_threshold',
'random_reward', 'dominant_strategy']:
if k in config_dict:
setattr(self, k, config_dict.pop(k))
return super().update_from_dict(config_dict)
class NoveltyRewardPPO(PPO):
@classmethod
@override(PPO)
def get_default_config(cls) -> AlgorithmConfig:
return NoveltyRewardPPOConfig()
def __init__(self, config=None, env=None, logger_creator=None):
# todo make this a real replay buffer object to take
# advantage of things like prioritized sampling to determine
# which samples to use for meta reward-shaping / learning
self.reward_buffer = OrderedDict()
self.passed_mc = False
self.meta_reward_value = 0
self.n_opponents = 0
        self.learner_key = 'main0' if config['env_config'].get('name') != 'connect_four' else 'main'
# configs passed in through the custom_model_config
# cmc = config['multiagent']['policies'][config['env_config']["players_ids"][1]][3]
_config = config.get('model', {}).get('custom_model_config', {})
# _config = config
#
# Assume passed a customPPOConfig object that contains these extra parameters
# knn to calculate meta-reward against
self.k = _config.get("k", 5)
# whether to use the meta-reward or not
self.meta_reward = _config.get("meta_reward", True)
# hyperparameters for the self-play callback
# how often to snapshot the current policy
self.s = _config.get("s", 500)
# when applicable, above what win-rate threshold should we snapshot the policy
self.win_rate_threshold = _config.get("win_rate_threshold", 0.95)
# above what win-rate threshold should we switch to the meta-reward
self.mc_threshold = _config.get("mc_threshold", 0.8)
# threshold for when a current profile / policy should be saved into the respective buffers
self.novelty_threshold = _config.get("novelty_threshold", 0.5)
# if the novelty reward is greater than the threshold, then we save the current profile and policy
self.save_policy_trigger = False
# flag to determine if you return the novelty reward or a random number
self.return_noise = _config.get('random_reward', False)
# get the dominant strategy
self.dominant_strategy = _config.get('dominant_strategy', [])
# at the end so that the self-play callback has access to these attributes above
super().__init__(config, env, logger_creator)
# is the current policy a dual-path policy?
self.dual_path = False
self.exploit = False
if env is not None:
obs_dim = env.observation_space.shape[0]
if obs_dim == 127:
self.dual_path = True
else:
env = self.env_creator(config['env_config'])
obs_dim = env.observation_space.shape[0]
if obs_dim == 127:
self.dual_path = True
# initialize a worker to compute rewards on demand
self.policy_config = {'model': {'custom_model': _config.get('network', 'mlp')}}
# custom rollout worker: band-aid fix for rllib not
# letting me do e.g., algorithm.evaluate(p1, p2)
self.crw = CRW(self.env_creator,
config['env_config'],
self.policy_config)
class BufferElement:
"""
BufferElement is a class that represents a single element of the buffer.
It is used to compute the distance between the current reward and the
rewards in the full buffer.
Each dimension of the reward is represented by a single value that is the
reward of the agent that is playing in that dimension.
"""
def __init__(self, reward_dict, mask, neighbors=10):
assert isinstance(reward_dict, OrderedDict)
self.reward_dict = reward_dict
self.reward_keys = reward_dict.keys()
self.k = neighbors
self.mask = mask
# [r_1, ..., r_n]
self.values = np.array(list(self.reward_dict.values()))
# self.sample_mean = np.mean(self._values)
# self.values = self._values - self.sample_mean
def add(self, k, v):
self.reward_dict[k] = v
self.reward_keys = self.reward_dict.keys()
self.values = np.array(list(self.reward_dict.values()))
self.mask = np.append(self.mask, False)
def __sub__(self, other):
# assume is a list of BufferElement objects
if isinstance(other, OrderedDict) and len(other) >= 2:
buffer_values = np.array([o.values[self.mask] for o in other.values()])
total_reward, credited_rewards = self._assign_credit(buffer_values)
elif isinstance(other, OrderedDict) and len(other) == 1:
buffer_values = np.array([list(other.values())[0].values[self.mask]])
total_reward, credited_rewards = self._assign_credit(buffer_values)
else:
# if len(other) == 0:
total_reward = 1
credited_rewards = [total_reward for _ in range(len(self.reward_keys))]
return total_reward, credited_rewards
def _assign_credit(self, masked_profile_matrix):
# let's do some credit assignment!
profile_matrix = torch.tensor(masked_profile_matrix, requires_grad=True)
current_profile = torch.tensor(self.values[self.mask], requires_grad=True)
torch_mask = torch.tensor(self.mask, requires_grad=False)
difference = current_profile - profile_matrix
dists = torch.norm(difference, dim=1)
sorted_dists, sorted_indices = torch.sort(dists)
important_indices = sorted_indices[:self.k]
indicator = torch.zeros(profile_matrix.shape[0])
indicator[important_indices] = 1
total_reward = torch.sum(dists * indicator) / torch.sum(indicator)
# normalized_total_reward = total_reward / torch_mask.sum()
normalized_total_reward = total_reward / (2 * torch.sqrt(torch.tensor(torch_mask.sum(),
dtype=torch.float32)))
# get the gradient with respect to the current profile vector
normalized_total_reward.backward()
# save total reward to python value
# divide by the number of agents involved in the novelty calculation
# this should normalize the score so it doesn't always go up?
total_reward_python = normalized_total_reward.item()
# push the grad through the softmax
current_profile_grad = current_profile.grad
# ??? should I split the meta-reward into the individual rewards?
# by the softmax of the gradient? Or just provide the gradient?
# version 0: just provide the meta-reward
reward = [total_reward_python for _ in range(len(current_profile_grad))]
# version 1: split the meta-reward into the individual rewards
# by the softmax of the gradient
# current_profile_grad_01 = torch.softmax(current_profile_grad, dim=0)
# grad_reward = current_profile_grad_01 * total_reward
# # get the values back out of the tensor
# grad_reward = grad_reward.detach().numpy()
# version 2: just provide the gradient
# reward = current_profile_grad.detach().numpy()
# version 3: weight the total reward by the gradient
# reward = total_reward_python * current_profile_grad.detach().numpy()
# version 4: weight the total reward by the abs marginal contribution of each agent
# get the nearest neighbor differences and
# compute the mean contribution of each dimension
# mean_contribution_per_snapshot = torch.mean(difference[sorted_indices][:self.k], dim=0)
# # split the scalar value reward into the individual contributions
# percent_contrib = (torch.abs(mean_contribution_per_snapshot) /
# torch.sum(torch.abs(mean_contribution_per_snapshot))).detach().numpy()
# reward = total_reward_python * percent_contrib
reward = self._add_in_zeros(reward)
return total_reward_python, reward
def _add_in_zeros(self, reward):
foo = []
r_index = 0
for m in self.mask:
if m:
foo.append(reward[r_index])
r_index += 1
else:
foo.append(0)
return np.array(foo)
def __getstate__(self):
state = {}
# make dict from ordered dict
state['reward_keys'] = list(self.reward_keys)
state['reward_values'] = self.values.tolist()
state['mask'] = self.mask.tolist()
state['k'] = self.k
return state
def __setstate__(self, state):
keys = state['reward_keys']
values = state['reward_values']
self.reward_dict = OrderedDict(zip(keys, values))
self.reward_keys = self.reward_dict.keys()
self.values = np.array(list(self.reward_dict.values()))
self.mask = np.array(state['mask'])
self.k = state['k']
def __repr__(self):
return f"BufferElement({self.reward_dict}, {self.mask}, {self.k})"
def _update_batch_for_trainable_with_meta_reward(self, train_batch: "MultiAgentBatch"):
"""
This function updates the batch for the main0 policy with the meta reward.
The meta reward is the distance between the current reward and the rewards
in the buffer.
"""
policy_ids_in_batch = list(train_batch.policy_batches.keys())
snap_ids_in_batch = [pid for pid in policy_ids_in_batch if not pid == self.learner_key]
main_agent_state = self.get_policy(self.learner_key).get_state()
main_batch = train_batch.policy_batches[self.learner_key]
snap_batches = {pid: train_batch.policy_batches[pid] for pid in snap_ids_in_batch}
# split batches by episode
episodic_main_batches = main_batch.split_by_episode()
episodic_snap_batches = {pid: snap_batches[pid].split_by_episode() for pid in snap_ids_in_batch}
# if self.dual_path:
# # split apart batches used for optimization
# opt_batches = SampleBatch()
# nov_batches = SampleBatch()
# for ep_batch in episodic_main_batches:
# # exploitation episode
# if ep_batch[ep_batch.OBS][0][-1] == 0:
# opt_batches = opt_batches.concat(ep_batch)
# # exploration episode
# else:
# nov_batches = nov_batches.concat(ep_batch)
#
# episodic_main_batches = nov_batches.split_by_episode()
        # calculate average score of each snapshot policy using total reward
snap_scores = {pid: np.mean([ep_batch[ep_batch.REWARDS].sum() for ep_batch in episodic_snap_batches[pid]])
for pid in snap_ids_in_batch}
# fill in missing values in the buffer sample
# this uses current main0
# if this wasn't in the training batch, mask it out of the meta_reward calculation
# but still fill the value in for use later!!
comparison_mask = []
for n in list(np.arange(1, self.n_opponents + 1)):
pid = 'random_r0' if n == 0 else f'main_v{n}'
if pid not in snap_scores:
opponent_pol = self.get_policy(pid)
if opponent_pol is not None:
opponent_state = opponent_pol.get_state()
else:
opponent_state = self.get_policy('random').get_state()
rewards = self.crw.get_discounted_rewards(opponent_state, main_agent_state)
                try:
                    snap_scores[pid] = rewards['random_r0']
                except KeyError:
                    # 0 = random
                    # 1 = main
                    snap_scores[pid] = rewards[0]
comparison_mask.append(True)
comparison_mask = np.array(comparison_mask)
# sort snapshot rewards by name e.g., main_v2
sorted_snap_scores = OrderedDict()
for k in sorted(snap_scores, key=lambda x: int(x.split('_')[1][1:])):
sorted_snap_scores[k] = snap_scores[k]
# calculate distance between snapshot policy rewards and buffer examples
snap_sample_reward = self.BufferElement(sorted_snap_scores, comparison_mask, neighbors=self.k)
total_reward, credited_rewards = snap_sample_reward - self.reward_buffer
# save the meta reward value for logging
self.meta_reward_value = total_reward
# add to buffer if total reward is greater than novelty threshold
if total_reward >= self.novelty_threshold:
self.reward_buffer[self.iteration] = snap_sample_reward
self.save_policy_trigger = True
main_policy = self.get_policy(self.learner_key)
# set rewards in main0 batch
# give each batch the percent of the total reward that the opponent contributed
ordered_keys = sorted_snap_scores.keys()
recombined_main_batch = SampleBatch()
for ep_batch in episodic_main_batches:
episode_id = ep_batch[ep_batch.EPS_ID][0]
# find which opponent was used in this episode
opponent_id = None
for pid in snap_ids_in_batch:
for ep in episodic_snap_batches[pid]:
if ep[ep.EPS_ID][0] == episode_id:
opponent_id = pid
break
if opponent_id is not None:
break
if opponent_id is None or opponent_id == self.learner_key:
                # this shouldn't be possible; skip this episode
continue
reward_index = list(ordered_keys).index(opponent_id)
ep_batch[ep_batch.REWARDS] = np.zeros_like(ep_batch[ep_batch.REWARDS])
if self.return_noise:
ep_batch[ep_batch.REWARDS][-1] = torch.randn(1)
else:
ep_batch[ep_batch.REWARDS][-1] = credited_rewards[reward_index]
# recalculate advantages using new rewards
# use the main policy since that's the batch we're updating
updated_batch = compute_gae_for_sample_batch(policy=main_policy,
sample_batch=ep_batch)
recombined_main_batch = recombined_main_batch.concat(updated_batch)
# if self.dual_path:
# # add optimization batches back in
# recombined_main_batch = recombined_main_batch.concat(opt_batches)
# place updated version back into the train_batch we were sent
train_batch.policy_batches[self.learner_key] = recombined_main_batch
return train_batch
def used_dom_to_win(self, train_batch: "MultiAgentBatch") -> float:
# check if we used the dominant strategy
# calculate the number of rollouts where we won via the dominant strategy
main_batches = train_batch.policy_batches[self.learner_key].split_by_episode()
# bar = train_batch.policy_batches['random'].split_by_episode()
dominant_strategy_wins = 0
for batch in main_batches:
actions = batch[batch.ACTIONS]
# calculate if we used the dominant strategy
# actions = [0, 1, 2, 3, 4, 5, 6, 7, 8, 5, 3, 4, 6, 6, 1, 2, 6]
# dominant_strategy = [1, 2, 6]
# check if the dominant strategy (in the order given) occurs in the actions list
# if it does, we used the dominant strategy
# if it doesn't, we didn't use the dominant strategy
start = len(self.dominant_strategy)
end = len(actions)
for i in range(end - start + 1):
subset = np.array(actions[i:i + start])
if subset.size > 0:
match = subset == self.dominant_strategy
if np.all(match):
dominant_strategy_wins += 1
break
return dominant_strategy_wins / len(main_batches)
@override(PPO)
def training_step(self) -> ResultDict:
# Collect SampleBatches from sample workers until we have a full batch.
if self.config.count_steps_by == "agent_steps":
train_batch = synchronous_parallel_sample(
worker_set=self.workers,
max_agent_steps=self.config.train_batch_size,
)
else:
train_batch = synchronous_parallel_sample(
worker_set=self.workers, max_env_steps=self.config.train_batch_size
)
train_batch = train_batch.as_multi_agent()
self.percent_dom_used = self.used_dom_to_win(train_batch)
self.save_policy_trigger = False
# update batch for main0 policy with meta reward
# from the other snapshot policies
# todo put in an if statement to only do this if:
# 2. we have added a snapshot policy to the self-play process
if self.passed_mc and self.meta_reward:
train_batch = self._update_batch_for_trainable_with_meta_reward(train_batch)
else:
self.meta_reward_value = 0
self._counters[NUM_AGENT_STEPS_SAMPLED] += train_batch.agent_steps()
self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps()
# Standardize advantages
train_batch = standardize_fields(train_batch, ["advantages"])
# Train
if self.config._enable_rl_trainer_api:
train_results = self.trainer_runner.update(train_batch)
elif self.config.simple_optimizer:
train_results = train_one_step(self, train_batch)
else:
train_results = multi_gpu_train_one_step(self, train_batch)
policies_to_update = list(train_results.keys())
global_vars = {
"timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
"num_grad_updates_per_policy": {
pid: self.workers.local_worker().policy_map[pid].num_grad_updates
for pid in policies_to_update
},
}
# Update weights - after learning on the local worker - on all remote
# workers.
if self.workers.num_remote_workers() > 0:
with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
from_worker = None
if self.config._enable_rl_trainer_api:
from_worker = self.trainer_runner
self.workers.sync_weights(
from_worker=from_worker,
policies=list(train_results.keys()),
global_vars=global_vars,
)
if self.config._enable_rl_trainer_api:
kl_dict = {
pid: pinfo[LEARNER_STATS_KEY].get("kl")
for pid, pinfo in train_results.items()
}
# triggers a special update method on RLOptimizer to update the KL values.
self.trainer_runner.additional_update(kl_values=kl_dict)
return train_results
# For each policy: Update KL scale and warn about possible issues
for policy_id, policy_info in train_results.items():
# Update KL loss with dynamic scaling
# for each (possibly multiagent) policy we are training
kl_divergence = policy_info[LEARNER_STATS_KEY].get("kl")
self.get_policy(policy_id).update_kl(kl_divergence)
# Warn about excessively high value function loss
scaled_vf_loss = (
self.config.vf_loss_coeff * policy_info[LEARNER_STATS_KEY]["vf_loss"]
)
policy_loss = policy_info[LEARNER_STATS_KEY]["policy_loss"]
if (
log_once("ppo_warned_lr_ratio")
and self.config.get("model", {}).get("vf_share_layers")
and scaled_vf_loss > 100
):
logger.warning(
"The magnitude of your value function loss for policy: {} is "
"extremely large ({}) compared to the policy loss ({}). This "
"can prevent the policy from learning. Consider scaling down "
"the VF loss by reducing vf_loss_coeff, or disabling "
"vf_share_layers.".format(policy_id, scaled_vf_loss, policy_loss)
)
# Warn about bad clipping configs.
train_batch.policy_batches[policy_id].set_get_interceptor(None)
mean_reward = train_batch.policy_batches[policy_id]["rewards"].mean()
if (
log_once("ppo_warned_vf_clip")
and mean_reward > self.config.vf_clip_param
):
self.warned_vf_clip = True
logger.warning(
f"The mean reward returned from the environment is {mean_reward}"
f" but the vf_clip_param is set to {self.config['vf_clip_param']}."
f" Consider increasing it for policy: {policy_id} to improve"
" value function convergence."
)
# Update global vars on local worker as well.
self.workers.local_worker().set_global_vars(global_vars)
return train_results
# @override(PPO)
def __getstate__(self) -> Dict:
"""Returns current state of Algorithm, sufficient to restore it from scratch.
Returns:
The current state dict of this Algorithm, which can be used to sufficiently
restore the algorithm from scratch without any other information.
"""
state = super().__getstate__()
# save the reward buffer
state['reward_buffer'] = self.reward_buffer
state['k'] = self.k
state['meta_reward'] = self.meta_reward
state['s'] = self.s
state['win_rate_threshold'] = self.win_rate_threshold
state['mc_threshold'] = self.mc_threshold
state['novelty_threshold'] = self.novelty_threshold
state['random_reward'] = self.return_noise
state['dominant_strategy'] = self.dominant_strategy
state['policy_config'] = self.policy_config
state['dual_path'] = self.dual_path
state['exploit'] = self.exploit
state['passed_mc'] = self.passed_mc
state['n_opponents'] = self.n_opponents
state['learner_key'] = self.learner_key
return state
# @override(PPO)
def __setstate__(self, state) -> None:
"""Sets the algorithm to the provided state.
Args:
state: The state dict to restore this Algorithm instance to. `state` may
have been returned by a call to an Algorithm's `__getstate__()` method.
"""
# TODO (sven): Validate that our config and the config in state are compatible.
# For example, the model architectures may differ.
# Also, what should the behavior be if e.g. some training parameter
# (e.g. lr) changed?
# restore the reward buffer
self.reward_buffer = state.pop('reward_buffer')
self.k = state.pop('k')
self.meta_reward = state.pop('meta_reward')
self.s = state.pop('s')
self.win_rate_threshold = state.pop('win_rate_threshold')
self.mc_threshold = state.pop('mc_threshold')
self.novelty_threshold = state.pop('novelty_threshold')
self.return_noise = state.pop('random_reward')
self.dominant_strategy = state.pop('dominant_strategy')
self.policy_config = state.pop('policy_config')
self.dual_path = state.pop('dual_path')
self.exploit = state.pop('exploit')
self.passed_mc = state.pop('passed_mc')
self.n_opponents = state.pop('n_opponents')
self.learner_key = state.pop('learner_key')
super().__setstate__(state)
self.crw = CRW(self.env_creator,
self.config['env_config'],
self.policy_config)
This should, in theory, allow me to resume the experiment and load from a checkpoint. However, Tune keeps throwing an error, which I eventually figured out goes away only if the restore path points not at a checkpoint inside a trial folder, but at the experiment directory itself:
tuner = tune.Tuner.restore("/mnt/e/PycharmProjects/results/sp_c4/", # NoveltyRewardPPO_LAMBDA_nov_c4_march1000_30iters_mc0.25_novths0.15_dualnets_END_29c26_00000_0_2023-09-06_11-42-42/checkpoint_000015
trainable=algo_class, resume_unfinished=True)
results = tuner.fit()
For completeness, here’s my tune call:
results = tune.Tuner(
algo_class,
param_space=alg_config,
run_config=air.RunConfig(
name='sp_c4',
stop=stop,
local_dir=os.path.join('..', '..', 'results'),
progress_reporter=CLIReporter(
metric_columns={
"training_iteration": "iter",
"timesteps_total": "ts",
"policy_reward_mean/main": "rew",
"win_rate": "winrate",
"league_size": "leaguesize",
"batch_meta_reward": "nov",
},
sort_by_metric=True,
),
checkpoint_config=air.CheckpointConfig(
checkpoint_at_end=True,
checkpoint_frequency=args.iters // 10,
),
callbacks=[
WandbLoggerCallback(project="mapo", entity="aadharna",
save_checkpoints=True,
api_key="key")
]
),
).fit()
Why is this? It seems like really bad design, no? If I run 100 experiments and later decide I want to continue experiment number 14 and double its length, I can't, because Tune will simply pick the latest run in the results folder to continue.
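The only workaround I can think of is to bypass Tune for the restore and load the specific trial checkpoint through RLlib directly, something like this (untested sketch; from_checkpoint is the generic Algorithm classmethod, and the path is the trial checkpoint from the comment above):

# load only the one trial I actually want to continue, straight from its checkpoint folder
ckpt_path = ("/mnt/e/PycharmProjects/results/sp_c4/"
             "NoveltyRewardPPO_LAMBDA_nov_c4_march1000_30iters_mc0.25_novths0.15_dualnets_END_29c26_00000_0_2023-09-06_11-42-42/"
             "checkpoint_000015")
# from_checkpoint is a classmethod on Algorithm, so it should also work on my subclass
algo = NoveltyRewardPPO.from_checkpoint(ckpt_path)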
Relatedly, how can I alter the arguments so that a run that previously trained for 1000 iterations continues for another 1000? If I simply restore the experiment, it keeps stop_iters == 1000 and immediately kills the run.
I guess the right thing to do is to just use the Algorithm API directly, but then I lose the wandb callback.
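For concreteness, this is roughly what the Algorithm-API route would look like (rough, untested sketch; alg_config is the same config object I pass to the Tuner as param_space, and the wandb project/entity and metric names are just the ones from my callback and progress reporter):

import wandb

# rebuild the algorithm from my config and keep training for another 1000 iterations,
# logging to wandb by hand since I no longer get the WandbLoggerCallback
algo = alg_config.build()
algo.restore(ckpt_path)  # same trial checkpoint path as in the sketch above

wandb.init(project="mapo", entity="aadharna")
for _ in range(1000):
    result = algo.train()
    wandb.log({
        "training_iteration": result["training_iteration"],
        "win_rate": result.get("win_rate"),
        "batch_meta_reward": result.get("batch_meta_reward"),
    })
    if result["training_iteration"] % 100 == 0:
        algo.save()
wandb.finish()

That would also solve the extra-iterations problem, but it means re-implementing the checkpointing and logging that Tune already gives me. Is there a cleaner way to do this while staying inside Tune?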