VTrace: Understanding the shapes of batches

Hi. I’m trying to understand how the shapes of the different batches of data are related in the VTrace Torch policy in IMPALA.

In particular, I want to understand how the train batch of observations is related to the batches of actions and rewards.

In vtrace_torch_policy.py, the method build_vtrace_loss receives the train_batch, a dictionary with an 'obs' key that holds a batch of the agent’s observations during training (train_batch['obs'] has shape 88 x SHAPE_OF_OBS_PER_TIMESTEP).

I assume this 88 comes from the rollout_fragment_length parameter, which is set to 22.

Let’s assume that SHAPE_OF_OBS_PER_TIMESTEP is 1, so each agent observation has a single dimension and the batch of observations is 88 x 1.
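To make this concrete, here is a minimal sketch (dummy tensors only, not RLlib code) that reproduces the shapes I am describing, under the assumption that the observation dimension is 1:

```python
import torch

# Dummy tensors with the shapes I observe inside build_vtrace_loss
# (assuming SHAPE_OF_OBS_PER_TIMESTEP == 1):
obs = torch.zeros(88, 1)    # like train_batch["obs"]
actions = torch.zeros(88)   # like train_batch["actions"]
rewards = torch.zeros(88)   # like train_batch["rewards"]

print(obs.shape, actions.shape, rewards.shape)
# torch.Size([88, 1]) torch.Size([88]) torch.Size([88])
```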

Now, in vtrace_torch.py, there is the method from_importance_weights, which computes something related to an estimate of the Q value:
deltas = clipped_rhos * (rewards + discounts * values_t_plus_1 - values)
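For reference, I believe this line corresponds to the temporal-difference term from the IMPALA paper, delta_t = rho_t * (r_t + gamma * V(x_{t+1}) - V(x_t)), with rho_t clipped from above (the clipped_rhos in the code), so deltas should be one value per timestep.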

This deltas tensor does NOT have the same shape as train_batch['obs'] from before; instead, it has shape 21 x 4.

So the batches of rewards, values, discounts and values_t_plus_1 inside vtrace_torch.py have shape 21 x 4, while the batches of observations (and actions and rewards) in vtrace_torch_policy.py have a first dimension of 88.
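Again as a minimal, self-contained sketch (dummy tensors, not the actual RLlib variables), this is the shape situation I see inside from_importance_weights:

```python
import torch

# Dummy tensors with the shapes I observe inside from_importance_weights:
clipped_rhos = torch.ones(21, 4)
rewards = torch.zeros(21, 4)
discounts = torch.full((21, 4), 0.99)
values = torch.zeros(21, 4)
values_t_plus_1 = torch.zeros(21, 4)

# Same line as quoted above, just with the dummy tensors:
deltas = clipped_rhos * (rewards + discounts * values_t_plus_1 - values)
print(deltas.shape)  # torch.Size([21, 4])
```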

Because 22 times 4 is 88, I assume there is a “missing” value in the batches inside vtrace_torch.py, but I’m not sure why, nor why the actual shape is 21 x 4. Why 4 columns? Is the “missing” value related to the first timestep, since we cannot estimate Q at time t = 0, before having any observation?
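Here is my current guess as a sketch (this is purely an assumption on my side, not something I have verified in the RLlib code): the 88 flat timesteps are 4 rollout fragments of length 22 concatenated together, reshaped into a time-major [22, 4] tensor, and then one timestep per column is dropped, leaving [21, 4]:

```python
import torch

# Stand-in for one of the flat length-88 batches (e.g. the rewards):
flat = torch.arange(88, dtype=torch.float32)

# IF the 4 fragments are simply concatenated, this would make it time-major:
time_major = flat.reshape(4, 22).t()   # shape [22, 4]

# ...and dropping one timestep per column would give the shape I see:
trimmed = time_major[:-1]              # or time_major[1:]? either way [21, 4]

print(time_major.shape, trimmed.shape)
# torch.Size([22, 4]) torch.Size([21, 4])
```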

My goal is to do some computation combining the batches from the two files (observations from one, Q estimates from the other), but with these different shapes I don’t know how they are related to each other.
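To show the kind of thing I am after (hypothetical names, and it only makes sense if my guess about the layout above is correct, which is exactly what I am unsure of):

```python
import torch

obs = torch.zeros(88, 1)     # stands in for train_batch["obs"] from vtrace_torch_policy.py
deltas = torch.zeros(21, 4)  # stands in for the deltas from vtrace_torch.py

# Bring the observations into the same (assumed) time-major layout as deltas...
obs_time_major = obs.reshape(4, 22, -1).permute(1, 0, 2)  # [22, 4, 1]?
obs_aligned = obs_time_major[:-1]                          # [21, 4, 1]?

# ...so that I can combine them per timestep:
result = obs_aligned.squeeze(-1) * deltas
print(result.shape)  # torch.Size([21, 4])
```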