TorchMultiCategorical with logits calculated in the constructor


I am experimenting with autoregressive action distributions, using the RLlib examples as a starter kit.

I am trying to solve a dummy environment that outputs a random number (e.g. 5); the goal is to provide two numbers which, when added to the observation, give a target value (e.g. 10). You can find the complete definition of the environment here.
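For readers without access to the linked definition, here is a minimal pure-Python stand-in for such an environment. It is only a sketch: the class name, the observation range, the action range (0..10 per number) and the reward (negative distance to the target) are all assumptions, not the actual environment. It does show why an autoregressive distribution seems like a natural fit here: the best second number depends on which first number was chosen.

```python
import random


class DummyTargetEnv:
    """Hypothetical stand-in for the environment described above.

    The observation is a random integer; the agent picks two integers
    and is rewarded for making observation + a1 + a2 hit TARGET.
    """

    TARGET = 10
    MAX_ACTION = 10  # assumed size of each discrete sub-action space (0..10)

    def __init__(self, seed=None):
        self.rng = random.Random(seed)
        self.obs = None

    def reset(self):
        # Assumed observation range: 0..TARGET inclusive.
        self.obs = self.rng.randint(0, self.TARGET)
        return self.obs

    def step(self, action):
        a1, a2 = action
        for a in (a1, a2):
            if not 0 <= a <= self.MAX_ACTION:
                raise ValueError(f"action {a} out of range")
        total = self.obs + a1 + a2
        # Assumed reward: 0 when the target is hit, negative distance otherwise.
        reward = -abs(self.TARGET - total)
        done = True  # one-step episodes
        return self.obs, reward, done, {}


env = DummyTargetEnv(seed=0)
obs = env.reset()
_, reward, done, _ = env.step((DummyTargetEnv.TARGET - obs, 0))
print(reward, done)
```

Any split of `TARGET - obs` into two numbers is optimal, so the two sub-actions are strongly coupled.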

While working on this, I found that the AR models perform much worse than a naive approach, which is completely counter-intuitive. I started digging deeper and created a FakeTorchMultiCategorical action distribution. It mimics the behavior of TorchMultiCategorical, but instead of accepting the concatenated logits of both actions, it accepts the model’s internal features and calculates and concatenates the logits inside its constructor (see here). I also verified that the algorithm I am using, PPO, isn’t doing anything strange between model inference and action distribution instantiation. Inside its loss I found:

logits, state = model(train_batch)
curr_action_dist = dist_class(logits, model)

This looks fine. I have only moved the logits calculation from the model into dist_class, which is similar in spirit to one of your examples.

This change, however, completely breaks training: the agent can no longer reach decent performance. You can see the training plots here. The plots (e.g. losses and entropy) of the baseline run are much better than those of the other runs. So the question is: why is there a difference between the baseline run and the fake_multicategorical run?

Do you have any idea what I am doing wrong?

One idea I have concerns the following line:

prev_action_dist = dist_class(train_batch[SampleBatch.ACTION_DIST_INPUTS], model)

After the first SGD iteration, this line computes the final logits using the updated model but the old hidden state (features/context). However, this is not used in training, only to log KL metrics:

action_kl = prev_action_dist.kl(curr_action_dist)
mean_kl = reduce_mean_valid(action_kl)
# ...
policy._mean_kl = mean_kl
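A small pure-Python sketch of the suspected effect (no torch or RLlib; the linear head, the weights and the "SGD update" are made up, and the trunk is assumed unchanged for simplicity). It computes KL(prev || curr) two ways: from logits stored at rollout time, as with the baseline's ACTION_DIST_INPUTS, and from stored features passed through the already-updated head, as in the FakeTorchMultiCategorical setup.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]


def categorical_kl(p_logits, q_logits):
    """KL(p || q) for two categorical distributions given as logits."""
    p = softmax(p_logits)
    q = softmax(q_logits)
    return sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q))


def linear_head(weights, features):
    """Hypothetical action head: one row of weights per action logit."""
    return [sum(w * f for w, f in zip(row, features)) for row in weights]


# Values recorded during the rollout (old model weights).
head_old = [[0.5, -0.2], [0.1, 0.3], [-0.4, 0.2]]
features_old = [1.0, 2.0]                          # what the fake setup stores
logits_old = linear_head(head_old, features_old)   # what the baseline stores

# The head after one SGD step (made-up update for illustration).
head_new = [[0.7, -0.2], [0.1, 0.1], [-0.4, 0.4]]
features_new = features_old  # assumption: the trunk did not change
curr_logits = linear_head(head_new, features_new)

# Baseline: prev dist built from stored logits, so it reflects the policy change.
kl_baseline = categorical_kl(logits_old, curr_logits)
# Fake setup: prev dist rebuilt from stored features with the *new* head.
prev_logits_fake = linear_head(head_new, features_old)
kl_fake = categorical_kl(prev_logits_fake, curr_logits)

print(kl_baseline > 0, kl_fake)  # -> True 0.0
```

In the second case the head's update is invisible to the KL, so mean_kl (and anything derived from it) no longer measures how far the policy actually moved.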


This part seems off:

You pass the features through the autoregressive action layers in the forward call of the model. Then RLlib will end up passing those outputs through the layers again inside the action distribution.

No. Only the BaselineModel has these passes in its forward method, and that model is meant to work with the built-in TorchMultiCategorical distribution. All the other models (including the one for fake_multicategorical) override the forward method and compute only the features/context there (see here).

In short: FakeTorchMultiCategorical is my custom reimplementation of the built-in TorchMultiCategorical action distribution. The only difference between the two is that the former expects the model to output context features and computes the final logits in its constructor, while the latter expects the logits to be computed in the model. In theory, they should produce exactly the same output. But they don’t, and I can’t understand why.
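The "in theory" claim can be checked with a toy pure-Python sketch. Everything here is hypothetical: TinyModel, logits_from_features and both distribution classes are illustrative stand-ins, and the constructor signature is simplified (the real TorchMultiCategorical also takes the action space). Within a single forward pass, with the same weights, the two designs give identical probabilities; they can only diverge when a distribution is rebuilt later from stored inputs after the weights have changed, which is where the prev_action_dist reconstruction discussed earlier comes in.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]


class TinyModel:
    """Hypothetical model: shared trunk plus one linear head per sub-action."""

    def __init__(self, trunk, heads):
        self.trunk = trunk  # rows of weights -> feature dimensions
        self.heads = heads  # one weight matrix per sub-action

    def features(self, obs):
        return [math.tanh(sum(w * o for w, o in zip(row, obs))) for row in self.trunk]

    def logits_from_features(self, feats):
        out = []
        for head in self.heads:
            out.extend(sum(w * f for w, f in zip(row, feats)) for row in head)
        return out


class MultiCategorical:
    """Baseline style: receives already concatenated logits."""

    def __init__(self, inputs, model, input_lens):
        self.probs = []
        start = 0
        for n in input_lens:
            self.probs.append(softmax(inputs[start:start + n]))
            start += n


class FakeMultiCategorical(MultiCategorical):
    """Fake style: receives features and computes the logits in the constructor."""

    def __init__(self, inputs, model, input_lens):
        super().__init__(model.logits_from_features(inputs), model, input_lens)


model = TinyModel(
    trunk=[[0.2, -0.1], [0.4, 0.3]],
    heads=[[[1.0, 0.5], [-0.3, 0.2], [0.1, 0.1]],
           [[0.2, -0.6], [0.9, 0.4], [0.0, 0.3]]],
)
feats = model.features([5.0, 1.0])
baseline = MultiCategorical(model.logits_from_features(feats), model, [3, 3])
fake = FakeMultiCategorical(feats, model, [3, 3])
print(baseline.probs == fake.probs)  # -> True
```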


Sorry, my fault, I misread.

BTW, if your kl_coeff is > 0, then action_kl is used in the loss here.
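A simplified, pure-Python sketch of how the KL term enters the PPO objective. The structure (-surrogate + kl_coeff * KL + vf_loss_coeff * vf_loss - entropy_coeff * entropy, averaged over the batch) follows RLlib's torch ppo_surrogate_loss as far as I recall it; the function name and argument layout are hypothetical. With kl_coeff = 0 the action_kl term drops out entirely, so a mis-built prev_action_dist then only affects the logged metrics.

```python
def ppo_total_loss(surrogate, action_kl, vf_loss, entropy,
                   kl_coeff, vf_loss_coeff=1.0, entropy_coeff=0.0):
    """Hypothetical helper: mean over per-sample PPO loss terms."""
    terms = [
        -s + kl_coeff * k + vf_loss_coeff * v - entropy_coeff * e
        for s, k, v, e in zip(surrogate, action_kl, vf_loss, entropy)
    ]
    return sum(terms) / len(terms)


batch = dict(surrogate=[1.0, 0.5], vf_loss=[0.2, 0.1], entropy=[1.0, 1.0])
# With kl_coeff = 0, a wrong action_kl has no effect on the loss.
print(ppo_total_loss(action_kl=[0.0, 0.0], kl_coeff=0.0, **batch)
      == ppo_total_loss(action_kl=[5.0, 3.0], kl_coeff=0.0, **batch))  # -> True
```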

Thanks @mannyv, you’re right. With kl_coeff set to zero I get exactly the same plots for both runs (although performance dropped drastically, which I hope can be restored by tuning clip_param :)).
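For reference, clip_param controls PPO's clipped surrogate, sketched below in pure Python (the function name is hypothetical and 0.3 is just an example value). As far as I know, RLlib computes the probability ratio from the current distribution and the action log-probs stored in the batch (ACTION_LOGP), not from prev_action_dist, so this term should not be affected by the stale reconstruction discussed above.

```python
import math


def clipped_surrogate(logp_new, logp_old, advantages, clip_param=0.3):
    """Per-sample PPO objective: min(r * A, clip(r, 1 - eps, 1 + eps) * A)."""
    out = []
    for ln, lo, a in zip(logp_new, logp_old, advantages):
        ratio = math.exp(ln - lo)  # pi_new(a|s) / pi_old(a|s)
        clipped = min(max(ratio, 1 - clip_param), 1 + clip_param)
        out.append(min(ratio * a, clipped * a))  # pessimistic bound
    return out


# Ratio 2 with a positive advantage is capped at 1 + clip_param.
print(clipped_surrogate([math.log(2.0)], [0.0], [1.0]))
```

A smaller clip_param keeps each update closer to the rollout policy, which is roughly the job the KL penalty was doing.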

Nevertheless, it would be worth fixing prev_action_dist, as even the official example probably suffers from this issue. Ideally, an old_model (a copy of the model from before the SGD updates) would be available in ppo_surrogate_loss. Or do you have a better idea? I am happy to prepare a PR.
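One way the old_model idea could look, as a pure-Python sketch: snapshot the model before the SGD phase and rebuild the previous distribution's logits from stored features through that frozen copy. FeatureHeadModel, sgd_phase and the fake "optimizer step" are hypothetical, not RLlib API; in a real torch setup copy.deepcopy on the module would play the same role, at the cost of keeping a second copy of the weights in memory.

```python
import copy


class FeatureHeadModel:
    """Hypothetical model whose action head is applied outside forward()."""

    def __init__(self, head):
        self.head = head  # list of weight rows, one per logit

    def logits_from_features(self, feats):
        return [sum(w * f for w, f in zip(row, feats)) for row in self.head]


def sgd_phase(model, stored_features, num_steps, lr=0.1):
    """Freeze a copy before updating and use it to rebuild prev logits."""
    old_model = copy.deepcopy(model)  # hypothetical 'old_model' for the loss
    prev_logits = None
    for _ in range(num_steps):
        # prev dist inputs: stored features through the *frozen* head.
        prev_logits = old_model.logits_from_features(stored_features)
        # Stand-in for an optimizer step that changes the live head.
        model.head = [[w + lr for w in row] for row in model.head]
    return old_model, prev_logits


model = FeatureHeadModel([[0.5, -0.2], [0.1, 0.3]])
feats = [1.0, 2.0]
rollout_logits = model.logits_from_features(feats)
_, prev_logits = sgd_phase(model, feats, num_steps=3)
print(prev_logits == rollout_logits)  # -> True
```

Here prev_logits still match what the policy produced at rollout time, even though the live head has moved.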


There is an open issue for this but it has not gotten much attention yet.