Hi,
I am trying to apply reward shaping after the rollout data has been collected but before PPO starts training the network on it. According to the official RLlib documentation (see screenshot), the recommended way to do this is the `on_postprocess_trajectory` callback.
However, when I implemented the callback class (see the minimal example below), the `on_postprocess_trajectory` method is never called during training.
The following is my test code:
class InspectBatchCallback(RLlibCallback):
    """Print a marker line each time RLlib invokes one of the overridden hooks.

    The class is a debugging aid: every hook only prints, so the console
    output shows which callbacks actually fire for a given configuration.

    NOTE(review): per the RLlib callbacks reference, ``on_learn_on_batch`` and
    ``on_postprocess_trajectory`` appear to be old-API-stack hooks. If the
    algorithm runs on the new API stack (EnvRunner/Learner), they are not
    expected to be invoked at all, which would explain a missing
    "POSTPROCESS TRAJECTORY CALLED" line. Confirm against the installed Ray
    version; batch manipulation such as reward shaping is presumably done in
    a custom learner connector there.
    """

    def __init__(self):
        super().__init__()
        # Incremented once per ``on_episode_end`` call on this instance.
        self.episode_cnt = 0

    def on_environment_created(
        self,
        *,
        env_runner,
        metrics_logger: Optional[MetricsLogger] = None,
        env: gym.Env,
        **kwargs,
    ):
        """Report that an environment has been created."""
        print("ENVIRONMENT CREATED")

    def on_algorithm_init(
        self,
        *,
        algorithm,
        metrics_logger: Optional[MetricsLogger] = None,
        **kwargs,
    ):
        """Report that the algorithm has been initialised."""
        print("ALGORITHM INIT")

    def on_episode_created(self, *, episode: EpisodeV2, **kwargs):
        """Report that an episode object has been created."""
        print("EPISODE CREATED")

    def on_episode_start(
        self,
        *,
        worker=None,
        base_env=None,
        policies: Optional[Dict[PolicyID, Policy]] = None,
        **kwargs,
    ):
        """Report that an episode has started.

        ``worker`` and ``base_env`` are optional on purpose: they were
        required keyword-only arguments before, so any caller that does not
        supply them (the new-API-stack env runner passes ``episode``,
        ``env_runner``, ``env``, ... instead — TODO confirm for your Ray
        version) made this hook raise ``TypeError``. Callers that do pass
        them keep working unchanged.
        """
        print("EPISODE START")

    def on_episode_step(self, *, episode, env, **kwargs):
        """Report that an episode has advanced by one step."""
        print("EPISODE STEP CALLED")

    def on_episode_end(self, *, episode, metrics_logger, **kwargs):
        """Count the finished episode and report the running total."""
        self.episode_cnt += 1
        print("EPISODE END")
        print(f"Episode {self.episode_cnt} finished.")
        print()

    def on_learn_on_batch(self, *, policy, train_batch, result, **kwargs):
        """Report a learn-on-batch call (see the class-level review note)."""
        print("ON LEARN ON BATCH CALLED")

    def on_postprocess_trajectory(
        self,
        worker,
        episode,
        agent_id,
        policy_id,
        policies,
        postprocessed_batch,
        original_batches,
        **kwargs,
    ):
        """Report a trajectory-postprocessing call (see the class-level review note)."""
        print("POSTPROCESS TRAJECTORY CALLED")

    def on_train_result(self, *, algorithm, result, **kwargs):
        """Report that a training iteration has produced a result."""
        print("TRAIN RESULT CALLED")
# Shut down any Ray runtime already attached to this process, then start one.
ray.shutdown()
ray.init(ignore_reinit_error=True)

# PPO on CartPole-v1 with the torch framework, two env runners sampling
# complete episodes, one learner, and the inspection callbacks attached.
config = (
    PPOConfig()
    .environment("CartPole-v1")
    .framework("torch")
    .env_runners(num_env_runners=2, batch_mode="complete_episodes")
    .callbacks(InspectBatchCallback)
    .learners(num_learners=1)
)

# Stop criteria handed to Tune for this run.
stop_criteria = {
    "training_iteration": 3,
    "time_total_s": 10,
}
run_config = tune.RunConfig(
    name="ppo_cartpole_exp",
    verbose=1,
    stop=stop_criteria,
)

# Build the tuner for the registered "PPO" trainable and run the experiment.
tuner = tune.Tuner("PPO", param_space=config, run_config=run_config)
tuner.fit()