I have a question that should be very straightforward but is causing me a surprising amount of difficulty, perhaps because I've missed something important. I've created a very simple environment consisting of two timesteps followed by a reward and termination, and I'm trying to train PPO to predict values accurately on it with lambda = 0.
State A -action-> State B -action-> reward+termination
Code reproducing a run of two episodes of this environment under PPO is below:
import numpy as np

# compute_value_targets and unpad_data_if_necessary are RLlib's postprocessing
# helpers (exact import paths depend on the Ray version).

episode_lens = [3, 3]  # two episodes, including the extra timestep appended to each
vfps = [0.0, 0.95, 0.95, 0.0, 0.95, 0.95]  # value function predictions
rewards = [0.0, 1.0, 0.0, 0.0, 1.0, 0.0]
terminateds = [False, True, True, False, True, True]
truncateds = [False, False, False, False, False, False]
gamma = 0.99
lambda_ = 0.0

compute_value_targets(
    values=vfps,
    rewards=unpad_data_if_necessary(
        episode_lens,
        np.array(rewards),
    ),
    terminateds=unpad_data_if_necessary(
        episode_lens,
        np.array(terminateds),
    ),
    truncateds=unpad_data_if_necessary(
        episode_lens,
        np.array(truncateds),
    ),
    gamma=gamma,
    lambda_=lambda_,
)
(The extra timestep at the end of each episode is the one added by AddOneTsToEpisodesAndTruncate.)
Code output:
array([0., 1., 0., 0., 1., 0.], dtype=float32)
Shouldn't the value targets at positions 0 and 3 be gamma * 0.95 instead of zero? If I understand correctly, setting lambda_ = 0 should make each value target a one-step TD target, i.e. the immediate reward plus the discounted value of the next state.
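For reference, here is a quick sketch of the lambda = 0 (one-step TD) targets I would expect, computed by hand from the same arrays. This is my own reference code, not RLlib's:

import numpy as np

gamma = 0.99
vfps = np.array([0.0, 0.95, 0.95, 0.0, 0.95, 0.95])
rewards = np.array([0.0, 1.0, 0.0, 0.0, 1.0, 0.0])
terminateds = np.array([False, True, True, False, True, True])

# One-step TD target: r_t + gamma * V(s_{t+1}),
# with no bootstrapping past a terminated step.
targets = np.zeros_like(rewards)
for t in range(len(rewards)):
    if terminateds[t]:
        targets[t] = rewards[t]
    else:
        targets[t] = rewards[t] + gamma * vfps[t + 1]

print(targets)  # [0.9405, 1.0, 0.0, 0.9405, 1.0, 0.0]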
Looking deeper into the code, should AddOneTsToEpisodesAndTruncate remove the terminated flag from the timesteps before the new one that it adds? For example, suppose the following environment:
A -(+0.0)-> B -(+0.0)-> C -(+1.0)-> TERMINATE
A -(+0.0)-> TERMINATE
(from A, one action leads to B and the other terminates immediately, both with reward 0)
With the value targets the current code produces (lambda set to zero), the value function would converge to:
V(A) = 0.0
V(B) = 0.0
V(C) = 1.0
I would expect value bootstrapping, run correctly in conjunction with a random policy, to converge to:
V(A) = gamma**2 / 2
V(B) = gamma
V(C) = 1.0
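As a sanity check on those numbers, here is my own back-of-the-envelope Bellman backup, assuming the random policy picks each of A's two actions with probability 0.5:

gamma = 0.99

V_C = 1.0                        # C pays +1 and terminates
V_B = 0.0 + gamma * V_C          # B -> C deterministically, reward 0
V_A = 0.5 * (0.0 + gamma * V_B)  # 50% A -> B (reward 0), 50% terminate (reward 0)

print(V_A, V_B, V_C)  # 0.49005, 0.99, 1.0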
This discrepancy seems to mean that a PPO agent trained on this environment would fail to converge to the optimal policy, because the advantage of reaching state B instead of terminating immediately would be zero.
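Concretely (my own arithmetic): with the values the current targets produce, Q(A, move to B) = 0 + gamma * V(B) = 0 and Q(A, terminate) = 0, so both actions in A have zero advantage and there is no gradient signal to prefer B. With correct bootstrapping, the advantage of moving to B would be gamma * V(B) - V(A) = gamma**2 - gamma**2 / 2 ≈ 0.49.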
Am I misunderstanding something about the implementation, or is this a bug?