I’m using the `Repeated` obs space for my RL training, e.g.:
```python
obs_space = Repeated(gym.spaces.Box(-1, 1, shape=x_shape, dtype=np.float32), max_elements)
```
It works fine, but I noticed that during the policy training step the whole observation, including `obs_flat`, is put on the GPU. For example, running this as part of the policy’s forward pass:
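```python
def forward(self, input_dict, state, seq_lens):
    # input_dict["obs"] is a RepeatedValues; .values is the padded [B, max_elements, ...] tensor
    print("obs", input_dict["obs"].values.device)
    print("obs_flat", input_dict["obs_flat"].device)
```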
prints this:
```
(PPO pid=2799690) obs cuda:0
(PPO pid=2799690) obs_flat cuda:0
```
The issue is that I’m not using `obs_flat` anywhere, yet it is large, since it’s zero-padded to `max_elements`. E.g., with `x_shape=10` and `max_elements=100`, `obs_flat` holds B x 100 x 10 values. Usually only a few elements are present, so unnecessarily transferring `obs_flat` to the GPU before `policy.forward` costs a lot of memory and time.
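To put rough numbers on it, here is a back-of-the-envelope sketch. It assumes float32 values, a hypothetical batch of 4096 samples, and that every sample has only 5 real elements (illustrative numbers, not measurements). It ignores any extra bookkeeping entries the flattened observation may contain.

```python
# Back-of-the-envelope estimate of padding overhead.
# Assumptions: float32 values; batch size and per-sample lengths are hypothetical.
FLOAT32_BYTES = 4


def padded_vs_used_bytes(batch_size, max_elements, x_dim, lengths):
    """Return (padded_bytes, used_bytes) for a zero-padded
    batch_size x max_elements x x_dim float32 tensor, where lengths[i]
    is the number of real (non-padding) elements in sample i."""
    assert len(lengths) == batch_size
    padded = batch_size * max_elements * x_dim * FLOAT32_BYTES
    used = sum(lengths) * x_dim * FLOAT32_BYTES
    return padded, used


# Hypothetical batch: 4096 samples, each with only 5 of 100 slots filled.
B = 4096
padded, used = padded_vs_used_bytes(B, max_elements=100, x_dim=10, lengths=[5] * B)
print(padded, used, padded / used)  # 16384000 819200 20.0
```

In this example about 95% of the bytes copied to the GPU are padding, i.e. 20x more data than the real elements occupy.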
Is my understanding correct? If so, can this behaviour be disabled?