@renos So if I understand correctly, you essentially want to add auxiliary losses to your policy loss so that your visual representation is trained jointly with the policy-specific parameters?
RLlib provides a custom_loss() hook on the model (ModelV2) that covers exactly this kind of use case. You can take a look at rllib/examples/custom_loss_and_metric.py to see how it's used in practice. The loss_inputs argument gives you access to the train_batch sampled from the replay buffer, so within that method you can compute a reward-prediction loss or any other auxiliary loss that would improve the representation.
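Roughly, it would look something like the sketch below (not the exact example file): a TorchModelV2 subclass that overrides custom_loss() to add a reward-prediction term. The class name, the linear auxiliary head, the flat observation assumption, and the 0.1 loss weight are all just placeholders for illustration.

```python
import torch.nn as nn

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork
from ray.rllib.policy.sample_batch import SampleBatch


class RewardPredictionModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        # Base network producing policy logits and value estimates.
        self.base = FullyConnectedNetwork(obs_space, action_space, num_outputs,
                                          model_config, name + "_base")
        # Hypothetical auxiliary head: predict the reward from the (flat) obs.
        self.reward_head = nn.Linear(obs_space.shape[0], 1)

    def forward(self, input_dict, state, seq_lens):
        logits, _ = self.base(input_dict, state, seq_lens)
        return logits, []

    def value_function(self):
        return self.base.value_function()

    def custom_loss(self, policy_loss, loss_inputs):
        # loss_inputs is the train batch; use it to build the auxiliary loss.
        obs = loss_inputs[SampleBatch.OBS].float()
        rewards = loss_inputs[SampleBatch.REWARDS].float()
        pred = self.reward_head(obs).squeeze(-1)
        self.aux_loss = nn.functional.mse_loss(pred, rewards)
        # For torch policies, policy_loss is a list of loss terms; add the
        # (arbitrarily weighted) auxiliary term to each of them.
        return [loss + 0.1 * self.aux_loss for loss in policy_loss]
```

You'd then register it with ModelCatalog.register_custom_model("reward_pred_model", RewardPredictionModel) and point your config's "model": {"custom_model": "reward_pred_model"} at it, so the auxiliary loss is optimized together with the regular policy loss.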