How to generically override the reset_config function?

When using an RLlib trainer, I noticed that a reset_config function is provided; however, it simply returns False.

This suggests that we, the users, are meant to inherit from Trainable and override that method. However, if we already have a remote class that contains a trainer object, what is the correct way to handle the override?

@ray.remote
class remote_trainer:
    def __init__(self):
        self.trainer = PPOTrainer(ppoConfig, env_name)
        ...

    def train(self):
        result = self.trainer.train()
        self.logger.log(result)
        return result

    def evaluate(self):
        # run a rollout with the trainer
        return rollout_statistics

    def reset(self, new_config):
        self.trainer.reset_config(new_config)  # this will return False

If you were only using one algorithm, then I could see just making my remote-trainer class inherit from e.g. PPOTrainer. But if we want it to be optimization-algorithm agnostic, then where should I override the reset function? The hacky solution I can think of is to replace the function outright, e.g.

@ray.remote
class remote_trainer:
    def __init__(self, update_config_function):
        self.trainer = PPOTrainer(ppoConfig, env_name)
        self.trainer.reset_config = update_config_function
    ...

But that would break things, since we lose the self reference that way.
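I suppose I could preserve the self reference by binding the function as a method with types.MethodType (a sketch, assuming update_config_function takes (self, new_config)), but that still feels like a hack:

import types

@ray.remote
class remote_trainer:
    def __init__(self, update_config_function):
        self.trainer = PPOTrainer(ppoConfig, env_name)
        # Bind the plain function so it receives the trainer
        # instance as `self` when called.
        self.trainer.reset_config = types.MethodType(
            update_config_function, self.trainer)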

I am at a loss as to which class I should be overriding here, since, in my application, I want the user to be able to select from any of the various RLlib Trainers.


Hey @aadharna, great question! 🙂
You could also just create a new PPOTrainer sub-class by doing this:

from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.utils import add_mixins
from ray.rllib.utils.annotations import override
from ray.tune import Trainable

class ResetConfigOverride:
    @override(Trainable)
    def reset_config(self, new_config):
        # do something with new_config
        return True  # <- signals successful reset

MyPPOTrainerWithCustomResetConfig = add_mixins(PPOTrainer, ResetConfigOverride)
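The mixed-in class can then be used just like the original trainer. A hypothetical usage, where ppo_config and "CartPole-v0" stand in for your own config dict and environment name:

trainer = MyPPOTrainerWithCustomResetConfig(config=ppo_config, env="CartPole-v0")
trainer.train()
trainer.reset_config({"lr": 1e-4})  # now runs the mixin's override and returns True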

To be thorough: would this work as a generic override, provided that the ResetConfigOverride class handles any necessary optimization-algorithm-specific information?

from ray.rllib.agents import ddpg, dqn, es, impala, maml, ppo, sac

def get_default_trainer_config_and_constructor(opt_algo):
    if opt_algo == "OpenAIES":
        return es.DEFAULT_CONFIG.copy(), es.ESTrainer
    elif opt_algo == "PPO":
        return ppo.DEFAULT_CONFIG.copy(), ppo.PPOTrainer
    elif opt_algo == 'MAML':
        return maml.DEFAULT_CONFIG.copy(), maml.MAMLTrainer
    elif opt_algo == 'DDPG':
        return ddpg.DEFAULT_CONFIG.copy(), ddpg.DDPGTrainer
    elif opt_algo == 'DQN':
        return dqn.DEFAULT_CONFIG.copy(), dqn.DQNTrainer
    elif opt_algo == 'SAC':
        return sac.DEFAULT_CONFIG.copy(), sac.SACTrainer
    elif opt_algo == 'IMPALA':
        return impala.DEFAULT_CONFIG.copy(), impala.ImpalaTrainer
    else:
        raise ValueError('Pick another opt_algo')
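For instance, assuming the override only needs to touch hyperparameters, a minimal algorithm-agnostic version could just merge the new values into the stored config. This is a sketch only; any algorithm-specific state derived from the config (optimizers, rollout workers, schedules) would still need to be rebuilt by hand:

from ray.rllib.utils.annotations import override
from ray.tune import Trainable

class ResetConfigOverride:
    @override(Trainable)
    def reset_config(self, new_config):
        # Merge the new hyperparameters into the existing config.
        # Nothing derived from the config is rebuilt here.
        self.config.update(new_config)
        return True  # signal to Tune that the reset succeeded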

And then in a class I have that parses argument files

...
self.trainer_config, self.trainer_constr = get_default_trainer_config_and_constructor(self.file_args.opt_algo)
if self.file_args.custom_trainer_config_override:
    self.trainer_constr = add_mixins(self.trainer_constr, ResetConfigOverride)

The parser class can then send the trainer constructor and config dict off to be instantiated as a remote trainer, and that remote trainer can call reset_config any time we need to change e.g. hyperparameters.
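Putting it together, a sketch of the algorithm-agnostic remote trainer (the parser name and env_name argument are assumptions based on the description above):

import ray

@ray.remote
class remote_trainer:
    def __init__(self, trainer_constr, trainer_config, env_name):
        # Works with any constructor returned by
        # get_default_trainer_config_and_constructor, including the
        # add_mixins-wrapped one.
        self.trainer = trainer_constr(config=trainer_config, env=env_name)

    def train(self):
        return self.trainer.train()

    def reset(self, new_config):
        # Runs the mixin's override when it was added; otherwise the
        # default Trainable.reset_config, which just returns False.
        return self.trainer.reset_config(new_config)

Hypothetical usage, continuing from the parser class above:

handle = remote_trainer.remote(parser.trainer_constr, parser.trainer_config, "CartPole-v0")
ray.get(handle.reset.remote({"lr": 1e-4}))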