Hi,
I would like to alter the actions taken by the agent during training. I thought that creating a custom trainer that subclasses the RLlib trainer would be the way to go. I created a decorator for compute_single_action
and a CustomPPOTrainer, like so:
from typing import List, Optional

from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import EnvInfoDict, PolicyID, TensorStructType
class CustomPPOTrainer(PPOTrainer):

    # Used as a decorator (at class-definition time) for compute_single_action below.
    def compute_custom_actions(regular_compute_actions_func):
        def inner(self, observation=None, state=None, **kwargs):
            if max(observation["rho"]) > 0.95:  # if this holds, compute actions as usual
                return regular_compute_actions_func(self, observation, state, **kwargs)
            else:
                return 0  # the action I would like to return in case the condition is not met
        return inner
    @compute_custom_actions
    @override(PPOTrainer)
    def compute_single_action(
            self,
            observation: Optional[TensorStructType] = None,
            state: Optional[List[TensorStructType]] = None,
            *,
            prev_action: Optional[TensorStructType] = None,
            prev_reward: Optional[float] = None,
            info: Optional[EnvInfoDict] = None,
            input_dict: Optional[SampleBatch] = None,
            policy_id: PolicyID = DEFAULT_POLICY_ID,
            full_fetch: bool = False,
            explore: Optional[bool] = None,
            timestep: Optional[int] = None,
            episode=None,
            unsquash_action: Optional[bool] = None,
            clip_action: Optional[bool] = None,
            # Kwargs placeholder for future compatibility.
            **kwargs):
        # Simply defer to the parent implementation; the decorator above
        # decides whether this is reached at all.
        return PPOTrainer.compute_single_action(
            self,
            observation,
            state,
            prev_action=prev_action,
            prev_reward=prev_reward,
            info=info,
            input_dict=input_dict,
            policy_id=policy_id,
            full_fetch=full_fetch,
            explore=explore,
            timestep=timestep,
            episode=episode,
            unsquash_action=unsquash_action,
            clip_action=clip_action,
            **kwargs)
When I run the trainer, it trains, but my change has no effect:
trainer = CustomPPOTrainer(env=my_env, config=config)

for step in range(1000):
    result = trainer.train()
    if step % 5 == 0:
        checkpoint = trainer.save()
        print("checkpoint saved at", checkpoint)
Could you point me to where I am going wrong and how I can achieve the desired result?