How to generically override the reset_config function?

When using an RLlib trainer, I noticed that a reset_config method is provided; however, the default implementation simply returns False.

This says to me that we, the users, are meant to inherit from Trainable and override that method. However, if we already have a remote class that contains a trainer object, what is the correct way to handle the override?

class remote_trainer:
    def __init__(self):
        self.trainer = PPOTrainer(ppoConfig, env_name)

    def train(self):
        result = self.trainer.train()
        return result

    def evaluate(self):
        # run a rollout with the trainer
        return rollout_statistics

    def reset(self, new_config):
        self.trainer.reset_config(new_config)  # this will return False

If I were only using one algorithm, I could see just making my remote_trainer class inherit from, e.g., PPOTrainer. But if we want it to be optimization-algorithm agnostic, where should I override the reset function? The hacky solution I can think of is to replace the function outright, e.g.

class remote_trainer:
    def __init__(self, update_config_function):
        self.trainer = PPOTrainer(ppoConfig, env_name)
        self.trainer.reset_config = update_config_function

But that would break things: assigning a plain function to the instance attribute means it is never bound, so the trainer's self is not passed in when it is called.
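To illustrate the binding problem with a minimal, stand-in Trainer class (not RLlib): plain attribute assignment stores the function unbound, but Python's `types.MethodType` can bind a free function to an existing instance, preserving the `self` reference. All names here are hypothetical:

```python
import types

class Trainer:
    def reset_config(self, new_config):
        return False  # default behavior, as in Tune's Trainable

def my_reset_config(self, new_config):
    # Hypothetical replacement; ``self`` is the trainer instance it is bound to.
    self.config = new_config
    return True

trainer = Trainer()

# trainer.reset_config = my_reset_config  # unbound: calling it would raise
# TypeError, because no ``self`` is passed. Binding it explicitly works:
trainer.reset_config = types.MethodType(my_reset_config, trainer)

print(trainer.reset_config({"lr": 1e-4}))  # True
print(trainer.config)                      # {'lr': 0.0001}
```

This monkey-patches a single instance, which is still fragile; the mixin approach below in the thread is cleaner because it produces a proper subclass.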

I am at a loss as to what class I should be overriding here since, in my application, I want the user to be able to select from any of the various RLlibTrainers.


Hey @aadharna, great question! :slight_smile:
You could also just create a new PPOTrainer subclass by doing this:

from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.utils import add_mixins

class ResetConfigOverride:
    def reset_config(self, new_config):
        # do something with new_config
        return True  # <- signals a successful reset

MyPPOTrainerWithCustomResetConfig = add_mixins(PPOTrainer, [ResetConfigOverride])
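For anyone curious what `add_mixins` effectively produces: roughly a new subclass whose MRO puts the mixin ahead of the base class, so the mixin's `reset_config` shadows the default one. A self-contained sketch with a stand-in base class (not the real PPOTrainer; names are illustrative):

```python
# Stand-in for an RLlib trainer whose reset_config just returns False.
class FakeTrainer:
    def reset_config(self, new_config):
        return False

class ResetConfigOverride:
    def reset_config(self, new_config):
        self.config = new_config  # naive reset; real code may need more work
        return True

# Roughly what add_mixins builds: a dynamic subclass with the mixin first
# in the MRO, so its reset_config wins over the base class's.
PatchedTrainer = type("PatchedTrainer", (ResetConfigOverride, FakeTrainer), {})

t = PatchedTrainer()
print(t.reset_config({"gamma": 0.99}))  # True
```

Because method resolution order is left to right, any base trainer class can be patched this way without touching its source.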

To be thorough: would this work as a generic override, provided that the ResetConfigOverride class handles any necessary algorithm-specific information?

def get_default_trainer_config_and_constructor(opt_algo):
    if opt_algo == "OpenAIES":
        return es.DEFAULT_CONFIG.copy(), es.ESTrainer
    elif opt_algo == "PPO":
        return ppo.DEFAULT_CONFIG.copy(), ppo.PPOTrainer
    elif opt_algo == "MAML":
        return maml.DEFAULT_CONFIG.copy(), maml.MAMLTrainer
    elif opt_algo == "DDPG":
        return ddpg.DEFAULT_CONFIG.copy(), ddpg.DDPGTrainer
    elif opt_algo == "DQN":
        return dqn.DEFAULT_CONFIG.copy(), dqn.DQNTrainer
    elif opt_algo == "SAC":
        return sac.DEFAULT_CONFIG.copy(), sac.SACTrainer
    elif opt_algo == "IMPALA":
        return impala.DEFAULT_CONFIG.copy(), impala.ImpalaTrainer
    else:
        raise ValueError(f"Unsupported opt_algo: {opt_algo}")

And then in a class I have that parses argument files

self.trainer_config, self.trainer_constr = get_default_trainer_config_and_constructor(self.file_args.opt_algo)
if self.file_args.custom_trainer_config_override:
    self.trainer_constr = add_mixins(self.trainer_constr, [ResetConfigOverride])

The parser class can then hand the trainer constructor and config dict off to be instantiated as a remote trainer, and that remote trainer can call reset_config whenever we need to change, e.g., hyperparameters.
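Putting the pieces together, here is a hedged end-to-end sketch of the flow using stub trainer classes instead of the real RLlib ones (all class names, the registry, and `build_trainer` are hypothetical stand-ins): algorithm-agnostic lookup, mixin composition, then a `reset_config` call that now succeeds.

```python
# Stub trainers standing in for RLlib's; reset_config returns False by default.
class StubPPOTrainer:
    def __init__(self, config, env):
        self.config, self.env = config, env

    def reset_config(self, new_config):
        return False  # mimics the default Trainable behavior

class StubESTrainer(StubPPOTrainer):
    pass

class ResetConfigOverride:
    def reset_config(self, new_config):
        self.config = new_config  # naive reset; real code may rebuild workers
        return True

# Hypothetical registry mapping algo names to (default config, constructor).
REGISTRY = {
    "PPO": ({"lr": 5e-5}, StubPPOTrainer),
    "OpenAIES": ({"lr": 1e-2}, StubESTrainer),
}

def build_trainer(opt_algo, env):
    if opt_algo not in REGISTRY:
        raise ValueError(f"Unsupported opt_algo: {opt_algo}")
    config, constr = REGISTRY[opt_algo]
    # Compose the mixin ahead of the base class so its reset_config wins.
    patched = type(constr.__name__ + "WithReset", (ResetConfigOverride, constr), {})
    return patched(dict(config), env)

trainer = build_trainer("PPO", "CartPole-v1")
print(trainer.reset_config({"lr": 1e-4}))  # True
```

In the real setup, `build_trainer` would return the patched constructor and config to a Ray remote actor, which instantiates the trainer and calls `reset_config` on hyperparameter changes.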