Controlling compute_actions during training

Hi,

I would like to alter the actions taken by the agent during training. I thought that creating a custom trainer that subclasses the RLlib trainer would be the way to go. I created a decorator for compute_single_action and a CustomPPOTrainer like so:

from ray.rllib.agents.ppo import PPOTrainer 
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import AgentID, EnvInfoDict, EnvType, EpisodeID, \
    PartialTrainerConfigDict, PolicyID, ResultDict, TensorStructType, \
    TensorType, TrainerConfigDict
from ray.tune.logger import Logger, UnifiedLogger
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch


class CustomPPOTrainer(PPOTrainer):
    def compute_custom_actions(regular_compute_actions_func):
        # Decorator: wraps the regular compute_single_action so that the
        # returned action can be overridden under certain conditions.
        def inner(self, observation=None, state=None, **kwargs):
            if max(observation["rho"]) > 0.95:
                # Condition holds: compute the action as usual.
                return regular_compute_actions_func(self, observation, state, **kwargs)
            else:
                # The action I would like to return when the condition is not met.
                return 0
        return inner

    @compute_custom_actions
    @override(PPOTrainer)
    def compute_single_action(self,
            observation: Optional[TensorStructType] = None,
            state: Optional[List[TensorStructType]] = None,
            *,
            prev_action: Optional[TensorStructType] = None,
            prev_reward: Optional[float] = None,
            info: Optional[EnvInfoDict] = None,
            input_dict: Optional[SampleBatch] = None,
            policy_id: PolicyID = DEFAULT_POLICY_ID,
            full_fetch: bool = False,
            explore: Optional[bool] = None,
            timestep: Optional[int] = None,
            episode=None,
            unsquash_action: Optional[bool] = None,
            clip_action: Optional[bool] = None,
            # Kwargs placeholder for future compatibility.
            **kwargs):
        return PPOTrainer.compute_single_action(
            self,
            observation,
            state,
            prev_action=prev_action,
            prev_reward=prev_reward,
            info=info,
            input_dict=input_dict,
            policy_id=policy_id,
            full_fetch=full_fetch,
            explore=explore,
            timestep=timestep,
            episode=episode,
            unsquash_action=unsquash_action,
            clip_action=clip_action,
            **kwargs)

When I run the trainer, it trains, but my change has no effect:


trainer = CustomPPOTrainer(env=my_env, config=config)
for step in range(1000):
    result = trainer.train()
    
    if step % 5 == 0:
        checkpoint = trainer.save()
        print("checkpoint saved at", checkpoint)

Could you point me to where I am going wrong and how I can achieve the desired result?