Repeated obs space transfers obs_flat to the GPU unnecessarily

I’m using the Repeated obs space for my RL training, e.g.:

    obs_space = Repeated(gym.spaces.Box(-1, 1, shape=x_shape, dtype=np.float32), max_elements)

It works fine, but I noticed that during the policy training step the whole observation is put on the GPU. I.e., running this as part of the policy’s forward pass:

    def forward(self, input_dict):
        print("obs", input_dict["obs"].values.device)
        print("obs_flat", input_dict["obs_flat"].device)

prints this:

    (PPO pid=2799690) obs cuda:0
    (PPO pid=2799690) obs_flat cuda:0

The issue with this is that I’m not using obs_flat anywhere, but it is large (since it’s zero-padded to max_elements). E.g. if I have x_shape=(10,) and max_elements=100, then obs_flat is B x 100 x 10. Very often I only have a few elements, so transferring obs_flat to the GPU before policy.forward adds a large memory and time overhead for nothing.
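To put rough numbers on the overhead, here is a minimal back-of-the-envelope sketch. The batch size and the typical number of real elements are made-up values for illustration, not measurements from my run; only x_shape and max_elements come from the example above. It also ignores any extra bookkeeping entries the flattened tensor may carry.

```python
import numpy as np

# Values from the example above.
x_shape = (10,)
max_elements = 100

# Assumed values, for illustration only (not measured).
batch_size = 4096      # hypothetical train batch size B
real_elements = 3      # hypothetical typical number of non-padding elements

itemsize = np.dtype(np.float32).itemsize  # bytes per float32 value
values_per_element = int(np.prod(x_shape))

# Bytes moved to the GPU for the zero-padded tensor vs. the part that holds real data.
padded_bytes = batch_size * max_elements * values_per_element * itemsize
used_bytes = batch_size * real_elements * values_per_element * itemsize

print(f"padded obs_flat: {padded_bytes / 2**20:.2f} MiB")
print(f"actual data:     {used_bytes / 2**20:.2f} MiB")
print(f"overhead factor: {padded_bytes / used_bytes:.1f}x")
```

Under these assumptions the padded tensor is max_elements / real_elements times larger than the data I actually need, i.e. roughly 33x.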

Is my understanding correct? If so, can this behaviour be disabled?