Repeated obs space transfers obs_flat to the GPU unnecessarily

How severely does this issue affect your experience of using Ray?

  • None: Just asking a question out of curiosity.
  • Low: It annoys or frustrates me for a moment.
  • Medium: It contributes significant difficulty to completing my task, but I can work around it.
  • High: It blocks me from completing my task.

I’m using the Repeated obs space for my RL training, e.g.:

    obs_space = Repeated(gym.spaces.Box(-1, 1, shape=x_shape, dtype=np.float32), max_elements)

It works fine, but one thing I noticed is that during the policy training step the whole observation is put on the GPU. I.e., running this as part of the policy’s forward pass:

    def forward(self, input_dict):
        print("obs", input_dict["obs"].values.device)
        print("obs_flat", input_dict["obs_flat"].device)

prints this:

(PPO pid=2799690) obs cuda:0
(PPO pid=2799690) obs_flat cuda:0

The issue with this is that I’m not using obs_flat anywhere, but it has a large size (since it’s zero-padded to max_elements). E.g., if I have x_shape=10 and max_elements=100, then obs_flat is B x 100 x 10. Very often I only have a few elements, so there’s a huge memory and time overhead from unnecessarily transferring obs_flat to the GPU prior to policy.forward.
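To make the overhead concrete, here is a small sketch of the size calculation, using the numbers from the example above (the batch size and element count are hypothetical placeholders):

```python
import numpy as np

# Hypothetical sizes matching the example in the post.
batch_size = 32        # B
max_elements = 100
x_shape = 10

itemsize = np.dtype(np.float32).itemsize  # 4 bytes

# obs_flat is always zero-padded to max_elements, so its transfer size is fixed:
padded_bytes = batch_size * max_elements * x_shape * itemsize

# If only a few elements are actually present per sample,
# most of the transferred bytes are padding.
actual_elements = 3
actual_bytes = batch_size * actual_elements * x_shape * itemsize

print(padded_bytes)  # 128000 bytes moved to the GPU
print(actual_bytes)  # 3840 bytes of real data
```

So with only 3 of 100 elements populated, over 97% of the transferred buffer is zero padding.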

Is my understanding correct? If so, can this behaviour be disabled?

We calculate input_dict["obs_flat"] with tf1.keras.layers.Flatten() or torch.flatten(obs, start_dim=1), respectively.
Neither operation should take a significant amount of time; I’d expect them to utilize the underlying representations of obs.
It is not part of the documentation, but tf.reshape(), for example, explicitly says that it uses the underlying data buffer, so there is no copy overhead. I’d expect the same to be true for similar operations. Let us know if you disagree!

That’s a good point, I haven’t considered that.
PyTorch only says “this function may return the original object, a view, or copy” (see torch.flatten — PyTorch 1.13 documentation), so it’s a bit ambiguous.
I can benchmark it properly later if that’s useful; I just need to dig into the policy code for that.
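As a quick way to resolve the ambiguity before a full benchmark, one can compare the underlying data pointers to see whether torch.flatten returns a view or a copy (a sketch only; the tensor shapes here are illustrative, not taken from the actual policy code):

```python
import torch

# For a contiguous tensor (like a freshly built padded obs batch),
# torch.flatten can return a view that shares the same storage.
obs = torch.randn(4, 100, 10)
flat = torch.flatten(obs, start_dim=1)
print(obs.data_ptr() == flat.data_ptr())  # True: a view, no copy

# A non-contiguous tensor (e.g. after a transpose) forces a copy instead.
obs_t = obs.transpose(1, 2)
flat_t = torch.flatten(obs_t, start_dim=1)
print(obs_t.data_ptr() == flat_t.data_ptr())  # False: data was copied
```

Note this only settles the flatten cost on the CPU side; the host-to-device transfer of the padded buffer happens regardless of whether flatten itself copies.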

Great! Looking forward to it 🙂