# How to concat rollout batches before computing GAE?

Hello,

I would like to know how to concat rollout batches before computing GAE.

I’m trying to change PPO to use the average reward setting instead of the discounted formulation.
In other words, I want to compute the TD errors as

$$\delta_t = R_{t+1} - \bar{R} + v(S_{t+1}) - v(S_t),$$

where $\bar{R}$ is an estimate of the average reward of policy $\pi$, which is independent of the starting state $S_0$.
Everything else stays exactly the same.
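A minimal NumPy sketch of this TD error follows. It assumes that `rewards[t]` holds $R_{t+1}$ and that `values[t]` and `next_values[t]` hold $v(S_t)$ and $v(S_{t+1})$. Here $\bar{R}$ is the average-reward estimate, and the function names are hypothetical. The sketch also checks why re-centering the rewards is enough. With rewards $R_{t+1} - \bar{R}$, the ordinary TD error $R_{t+1} - \bar{R} + \gamma\, v(S_{t+1}) - v(S_t)$ matches the average-reward TD error only when $\gamma = 1$. For $\gamma < 1$ the two differ by $(\gamma - 1)\, v(S_{t+1})$.

```python
import numpy as np


def differential_td_errors(rewards, values, next_values, avg_reward):
    """Average-reward TD errors: R_{t+1} - R_bar + v(S_{t+1}) - v(S_t)."""
    rewards, values, next_values = (np.asarray(x, dtype=np.float64)
                                    for x in (rewards, values, next_values))
    return rewards - avg_reward + next_values - values


def discounted_td_errors(rewards, values, next_values, gamma):
    """Ordinary TD errors: R_{t+1} + gamma * v(S_{t+1}) - v(S_t)."""
    rewards, values, next_values = (np.asarray(x, dtype=np.float64)
                                    for x in (rewards, values, next_values))
    return rewards + gamma * next_values - values


rewards = np.array([1.0, 0.0, 2.0])
values = np.array([0.5, 0.2, 0.1])
next_values = np.array([0.2, 0.1, 0.0])
avg_reward = rewards.mean()  # 1.0

deltas = differential_td_errors(rewards, values, next_values, avg_reward)
# Ordinary TD error on re-centered rewards with gamma = 1 gives the same numbers.
recentered = discounted_td_errors(rewards - avg_reward, values, next_values, gamma=1.0)
print(np.round(deltas, 6).tolist())  # [-0.3, -1.1, 0.9]
print(np.allclose(deltas, recentered))  # True
```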

In my understanding, the simplest way to implement it is to re-compute all rewards after rollout collection by subtracting the average of the collected rewards. In other words, I would like to do the following:

1. perform rollouts (e.g., compute $\pi(a|s)$, $v(s)$ and env#step)
2. concat batches
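3. compute the average reward (e.g., $\bar{R} = \frac{1}{N}\sum_{t=0}^{N-1} R_{t+1}$ over the $N$ collected transitions)
4. re-compute the rewards (e.g., $R_{t+1} \leftarrow R_{t+1} - \bar{R}$)
5. compute GAE as usual
6. backpropagate as usual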
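Steps 2-5 can be sketched in plain NumPy as follows. This is not RLlib's API. Each worker's output is assumed to be a dict with the hypothetical keys `rewards`, `vf_preds` (the $v(s)$ predictions from step 1) and `last_value` ($v$ of the state after the last step, or 0.0 if the episode ended). The average is taken over all workers' rewards. GAE, however, still runs once per fragment so that advantages never cross trajectory boundaries. `gamma` defaults to 1.0 to match the undiscounted average-reward setting. In Ray 1.x, `SampleBatch.concat_samples` and `ray.rllib.evaluation.postprocessing.compute_advantages` roughly cover the concatenation and GAE steps.

```python
import numpy as np


def gae(rewards, values, last_value, gamma=1.0, lam=0.95):
    """GAE over one trajectory fragment.

    rewards[t] is the reward after acting in S_t, values[t] = v(S_t), and
    last_value = v(S_T) for the state after the final step (0.0 if terminal).
    """
    rewards = np.asarray(rewards, dtype=np.float64)
    values = np.asarray(values, dtype=np.float64)
    next_values = np.append(values[1:], last_value)
    deltas = rewards + gamma * next_values - values
    advantages = np.zeros_like(deltas)
    running = 0.0
    for t in reversed(range(len(deltas))):
        running = deltas[t] + gamma * lam * running
        advantages[t] = running
    return advantages


def average_reward_gae(fragments, gamma=1.0, lam=0.95):
    """Steps 2-5: pool rewards from all fragments, re-center them, run GAE."""
    # Steps 2-3: the average reward is computed over *all* workers' data.
    all_rewards = np.concatenate(
        [np.asarray(f["rewards"], dtype=np.float64) for f in fragments])
    avg_reward = all_rewards.mean()

    columns = {"rewards": [], "vf_preds": [], "advantages": [], "value_targets": []}
    for f in fragments:
        values = np.asarray(f["vf_preds"], dtype=np.float64)
        # Step 4: re-compute rewards as R - R_bar.
        centered = np.asarray(f["rewards"], dtype=np.float64) - avg_reward
        # Step 5: GAE per fragment, on the re-centered rewards.
        adv = gae(centered, values, f["last_value"], gamma, lam)
        columns["rewards"].append(centered)
        columns["vf_preds"].append(values)
        columns["advantages"].append(adv)
        columns["value_targets"].append(adv + values)
    # One concatenated training batch for step 6.
    batch = {k: np.concatenate(v) for k, v in columns.items()}
    return batch, avg_reward


fragments = [
    {"rewards": [1.0, 1.0], "vf_preds": [0.0, 0.0], "last_value": 0.0},
    {"rewards": [3.0, 3.0], "vf_preds": [0.0, 0.0], "last_value": 0.0},
]
batch, avg = average_reward_gae(fragments, gamma=1.0, lam=1.0)
print(avg, batch["advantages"].tolist())  # 2.0 [-2.0, -1.0, 2.0, 1.0]
```

In this example the pooled average is 2.0. The first worker's re-centered rewards therefore come out negative and the second worker's positive. Averaging per worker would have re-centered both to zero.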

My issue is that, by construction, step 5 occurs before step 2, and I don't see any way to reverse them. How can I implement such an algorithm with RLlib? Is there any way to override PPO's default behavior?

Before passing the training batch to the loss function, you have to postprocess it. I'm pretty sure the value function predictions will be in the batch, because they are computed in the policy code of both the PPO torch and TensorFlow policies. (There you can add a new key to the train batch dict.)

To do this, override `postprocess_fn`, which is described in more detail here: ray/policy_template.py at master · ray-project/ray · GitHub
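Below is a sketch of the suggested hook. It assumes the Ray 1.x-style `postprocess_fn(policy, sample_batch, other_agent_batches=None, episode=None)` signature, so check that against your Ray version. The function name and the `centered_rewards` key are hypothetical, and plain dicts stand in for `SampleBatch`. In a real policy, PPO's GAE postprocessing would still have to run on the new column afterwards.

```python
import numpy as np


def postprocess_center_rewards(policy, sample_batch, other_agent_batches=None,
                               episode=None):
    """Hypothetical postprocess_fn that adds a re-centered reward column.

    Only `sample_batch` is visible here, i.e. the trajectory collected by the
    current rollout worker, so this "average reward" is a per-worker mean.
    """
    rewards = np.asarray(sample_batch["rewards"], dtype=np.float64)
    # Add a new key to the train batch dict instead of overwriting "rewards".
    sample_batch["centered_rewards"] = rewards - rewards.mean()
    return sample_batch


# Plain dicts stand in for SampleBatch objects from two different workers.
worker_1 = postprocess_center_rewards(None, {"rewards": [1.0, 1.0]})
worker_2 = postprocess_center_rewards(None, {"rewards": [3.0, 3.0]})
print(worker_1["centered_rewards"].tolist(), worker_2["centered_rewards"].tolist())
# [0.0, 0.0] [0.0, 0.0]
```

Each call sees only one worker's data, so both workers' rewards are re-centered to zero. The pooled average over both workers, by contrast, is 2.0.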

This is precisely what I did as a first version…

As I explained in the question, the problem with this strategy is that not all of the data is provided to `postprocess_fn`. The `SampleBatch` passed to `postprocess_fn` contains only the data from the current worker. I would like to have access to the data from all rollout workers so I can compute the average reward before computing GAE.

The painful way is to modify the execution plan right about here:

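```python
from ray.rllib.agents.ppo import PPOTrainer

def my_execution_plan(workers, config):
    # TODO: concat rollouts from all workers, then recompute rewards and GAE.
    pass

# with_updates() is the Ray 1.x way to derive a Trainer with a custom plan.
MyPPOTrainer = PPOTrainer.with_updates(execution_plan=my_execution_plan)
```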


I was reading the code, and I believe the only way to work around this double computation is to also override `postprocess_fn` (so that it does no computation on the first pass over the data) and then compute GAE later, after batch concatenation.