Input_dict's data structure

How severely does this issue affect your experience of using Ray?

  • High: It blocks me from completing my task.

Dear Ray Community,

Hi. Thanks for your attention to this question.

I’m new to RLlib, and I’m confused about the data structure of the input_dict passed to my customized model.

While looking into rllib/examples/models/rnn_model.py, it seems that each attribute inside input_dict[“obs”] has the shape [Batch_size * Sequence_length, 1]. This deduction comes from the add_time_dimension function in ray/rllib/policy/rnn_sequencing.py. I understand that this flat structure is used because it remains compatible with models that don’t use an RNN:

        padded_inputs = torch.as_tensor(padded_inputs)
        padded_batch_size = padded_inputs.shape[0]

        # Dynamically reshape the padded batch to introduce a time dimension.
        new_batch_size = seq_lens.shape[0]
        time_size = padded_batch_size // new_batch_size
        batch_major_shape = (new_batch_size, time_size) + padded_inputs.shape[1:]
        padded_outputs = padded_inputs.view(batch_major_shape)

        if time_major:
            # Swap the batch and time dimensions
            padded_outputs = padded_outputs.transpose(0, 1)
        return padded_outputs
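If my reading is right, the shape arithmetic in the snippet above can be checked with a tiny pure-Python sketch (no Ray or Torch needed; the function name below is my own, and only the arithmetic mirrors add_time_dimension):

```python
def add_time_dimension_shape(padded_batch_size, num_seqs, feature_shape,
                             time_major=False):
    # Mirror only the shape arithmetic of the snippet above:
    # time_size is the padded row count divided by the number of sequences.
    time_size = padded_batch_size // num_seqs
    if time_major:
        # transpose(0, 1) swaps the batch and time dimensions
        return (time_size, num_seqs) + feature_shape
    return (num_seqs, time_size) + feature_shape

# 3 sequences, each padded to length 4, one scalar feature per step:
# a flat [12, 1] batch becomes a time-ranked [3, 4, 1] tensor.
print(add_time_dimension_shape(12, 3, (1,)))                   # (3, 4, 1)
print(add_time_dimension_shape(12, 3, (1,), time_major=True))  # (4, 3, 1)
```

So a [Batch_size * Sequence_length, 1] input reshapes cleanly into [Batch_size, Sequence_length, 1], which is what made me suspect my own [Sequence_length, 1] result.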

I would like to confirm that my deduction and understanding are correct, because if they are, then something is wrong with my code, which produces a [Sequence_length, 1] shape instead.

Hi, any update on this? I believe I’m facing a related problem.

Hi, sorry for the late update.

It seems that everything works fine now if I follow this guide:
Models, Preprocessors, and Action Distributions — Ray 2.39.0

It’s worth mentioning that, according to the doc:

Note that the inputs arg entering forward_rnn is already a time-ranked single tensor (not an input_dict !) with shape (B x T x ...)

AND

If you further want to customize and need more direct access to the complete (non time-ranked) input_dict , you can also override your Model’s forward method directly (as you would do with a non-RNN ModelV2).
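To make the two override points concrete, here is a minimal stand-in in pure Python (this is not the RLlib API and imports no Ray code; the class and method names only mimic the ModelV2 conventions the docs describe, and the list slicing emulates add_time_dimension):

```python
class SketchRNNModel:
    """Illustrative stand-in showing where forward() and forward_rnn() sit."""

    def forward(self, input_dict, state, seq_lens):
        # forward() sees the flat, non-time-ranked batch: [B*T, F].
        flat_obs = input_dict["obs"]
        b = len(seq_lens)
        t = len(flat_obs) // b
        # Emulate add_time_dimension: reshape the flat rows into [B, T, F].
        time_ranked = [flat_obs[i * t:(i + 1) * t] for i in range(b)]
        return self.forward_rnn(time_ranked, state, seq_lens)

    def forward_rnn(self, inputs, state, seq_lens):
        # forward_rnn() sees a time-ranked input: [B, T, F].
        return inputs, state

model = SketchRNNModel()
obs = [[float(i)] for i in range(12)]          # 12 flat rows, F = 1
out, _ = model.forward({"obs": obs}, state=[], seq_lens=[4, 4, 4])
print(len(out), len(out[0]), len(out[0][0]))   # 3 4 1
```

Overriding forward_rnn keeps the automatic time-ranking; overriding forward instead gives direct access to the complete flat input_dict, as the docs note.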

This example provided by Ray could be quite helpful: ray/rllib/examples/_old_api_stack/models/rnn_model.py at master · ray-project/ray