How to disable flattened Dict or Tuple observation in ComplexInputNetwork

I was working through the “A multi-input capable model for Tuple observation spaces (for PPO)” example and ran into a problem.
The observation space in the environment is defined as:

from gym.spaces import Box, Discrete
from gym.spaces import Tuple as tp

obs_spaces = {
    "actor1": tp((
        Box(float("-inf"), float("inf"), (84, 84, 1)),
        Discrete(7),
    )),
}

The observation returned by the environment is:

# image: array of shape (84, 84, 1)
# a: int
return {"actor1": (image, a)}

The model I use is ComplexInputNetwork from ray/complex_input_net.py (ray-project/ray, master branch on GitHub).
Training file:

import torch.nn as nn
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2

class ComplexInputNetwork(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        ...
    def forward(self, input_dict, state, seq_lens):
        ...
    def value_function(self):
        ...

ModelCatalog.register_custom_model("testmodel", ComplexInputNetwork)

# Inside the PPO trainer config:
"model": {
    "custom_model": "testmodel",
    # Extra kwargs to be passed to your model's c'tor.
    "custom_model_config": {},
},
The policy is PPO.
Running it fails with:

  File "C:\ProgramData\Anaconda3\Lib\site-packages\ray\rllib\examples\b_comlex.py", line 250, in forward
    cnn_out, _ = self.cnns[i]({"obs": component})
  File "C:\ProgramData\Anaconda3\lib\site-packages\ray\rllib\models\modelv2.py", line 213, in __call__
    res = self.forward(restored, state or [], seq_lens)
  File "C:\ProgramData\Anaconda3\lib\site-packages\ray\rllib\models\torch\visionnet.py", line 192, in forward
    self._features = self._features.permute(0, 3, 1, 2)
RuntimeError: number of dims don't match in permute

I printed some intermediate values and found that the observation is already flattened when it reaches ComplexInputNetwork.forward. In ComplexInputNetwork.__init__, hasattr(obs_space, "original_space") is True, but by the time forward(self, input_dict, state, seq_lens) is called through ray/modelv2.py (ray-project/ray, master branch on GitHub), input_dict["obs"] has been flattened, which causes the error in this part of the code:

def forward(self, input_dict, state, seq_lens):
    # Push image observations through our CNNs.
    outs = []
    for i, component in enumerate(input_dict["obs"]):
        if i in self.cnns:
            cnn_out, _ = self.cnns[i]({"obs": component})
            outs.append(cnn_out)
        elif i in self.one_hot:
            if component.dtype in [torch.int32, torch.int64, torch.uint8]:
                outs.append(
                    one_hot(component, self.original_space.spaces[i]))
            else:
                outs.append(component)
        else:
            outs.append(torch.reshape(component, [-1, self.flatten[i]]))
    # Concat all outputs and the non-image inputs.
    out = torch.cat(outs, dim=1)
    # Push through (optional) FC-stack (this may be an empty stack).
    out, _ = self.post_fc_stack({"obs": out}, [], None)

The input dimension is:
torch.Size([7057])
# 84*84*1 + 1 = 7057
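To see why a flat input breaks the CNN branch, here is a minimal sketch that uses NumPy arrays as a stand-in for torch tensors (an assumption for illustration; np.transpose plays the role of permute). When input_dict["obs"] is a flat [batch, 7057] array instead of a tuple, the enumerate loop in forward walks over batch rows rather than tuple components, so the CNN receives a 1-D vector and the NHWC-to-NCHW transpose cannot be applied.

```python
import numpy as np

BATCH = 32
IMAGE_SHAPE = (84, 84, 1)
image_size = int(np.prod(IMAGE_SHAPE))  # 84 * 84 * 1 = 7056

# Flattened observation: image pixels plus a single value for `a`.
flat_obs = np.zeros((BATCH, image_size + 1), dtype=np.float32)
print(flat_obs.shape)  # (32, 7057)

# forward() does `for i, component in enumerate(input_dict["obs"])`.
# On a flat 2-D array this yields batch rows, not tuple components.
i, component = next(enumerate(flat_obs))
print(i, component.shape)  # 0 (7057,)

# The vision net permutes NHWC -> NCHW, which needs a 4-D input.
try:
    np.transpose(component, (0, 3, 1, 2))
except ValueError as err:
    print("transpose failed:", err)

# Recovering the image part requires slicing it out and reshaping to 4-D.
images = flat_obs[:, :image_size].reshape((BATCH,) + IMAGE_SHAPE)
print(np.transpose(images, (0, 3, 1, 2)).shape)  # (32, 1, 84, 84)
```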
How can I set this up so that the input is not flattened and keeps its original structure, ([84, 84, 1], 1)?
Thank you very much

CC @sven1977, can you help with this?

Thank you very much for your help. I updated the Ray version, but the error is still not resolved. The error is now:

Expected flattened obs shape of […, 7063], got torch.Size([32, 7057])

This is because the observation I return,

# 84×84×1 + 1 = 7057
{"actor1": (image, a)}

flattens to 7057 values (a contributes a single value), whereas the declared space

# 84×84×1 + 7 = 7063
Tuple(Box(float("-inf"), float("inf"), (84, 84, 1)), Discrete(7))

flattens to 7063 values, hence the error Expected flattened obs shape of […, 7063], got torch.Size([32, 7057]). After I changed Discrete to Box, the error went away for now, but I still haven't found the function that converts a Discrete value to one-hot. I am going to study the source code carefully.

Discrete spaces are automatically converted to one-hot in RLlib.
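As an illustration of what that one-hot conversion means for the sizes above, here is a minimal NumPy sketch. The helpers one_hot and flatten_obs are hypothetical and are not RLlib's preprocessor code; in RLlib this happens automatically, so you should not need to do it by hand. It only shows how a Discrete(7) value becomes 7 entries, giving 7056 + 7 = 7063 values in total.

```python
import numpy as np

IMAGE_SHAPE = (84, 84, 1)
NUM_DISCRETE = 7  # Discrete(7)

def one_hot(index: int, n: int) -> np.ndarray:
    """Hypothetical helper: encode a Discrete(n) value as a length-n float vector."""
    vec = np.zeros(n, dtype=np.float32)
    vec[index] = 1.0
    return vec

def flatten_obs(image: np.ndarray, a: int) -> np.ndarray:
    """Hypothetical helper: image pixels followed by the one-hot encoding of `a`."""
    pixels = image.reshape(-1).astype(np.float32)
    return np.concatenate([pixels, one_hot(a, NUM_DISCRETE)])

image = np.random.rand(*IMAGE_SHAPE).astype(np.float32)
flat = flatten_obs(image, 3)
print(flat.shape)            # (7063,)
print(flat[-NUM_DISCRETE:])  # [0. 0. 0. 1. 0. 0. 0.]
```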

Thank you for your reply. I will study the source code again.

I got the same error; it's caused by the Discrete obs_space. I changed it to Box as well, and that helped. The reason is that Discrete(7) is preprocessed by the “OneHotPreprocessor” into 7 values, whereas a Box value is just flattened (if it has more than one dimension), so a single-value Box stays 1 value.
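The size rule described here can be written as a small recursive function. The Box, Discrete and TupleSpace classes below are hypothetical minimal stand-ins (not gym's classes) that keep only what matters for the count: a Box contributes the product of its shape, a Discrete(n) contributes n one-hot slots, and a Tuple contributes the sum of its parts. This reproduces the 7063 vs. 7057 mismatch and shows why swapping Discrete(7) for a 1-element Box makes the sizes agree.

```python
import math
from typing import Tuple

# Hypothetical minimal stand-ins for gym spaces; only their sizes matter here.
class Box:
    def __init__(self, shape: Tuple[int, ...]):
        self.shape = shape

class Discrete:
    def __init__(self, n: int):
        self.n = n

class TupleSpace:
    def __init__(self, *spaces):
        self.spaces = spaces

def flat_size(space) -> int:
    """Values after preprocessing: Box is flattened, Discrete is one-hot encoded."""
    if isinstance(space, Box):
        return math.prod(space.shape)
    if isinstance(space, Discrete):
        return space.n  # one slot per possible value
    if isinstance(space, TupleSpace):
        return sum(flat_size(s) for s in space.spaces)
    raise TypeError(f"unsupported space: {space!r}")

with_discrete = TupleSpace(Box((84, 84, 1)), Discrete(7))
with_box = TupleSpace(Box((84, 84, 1)), Box((1,)))
print(flat_size(with_discrete))  # 7063
print(flat_size(with_box))       # 7057
```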

For @sven1977 or someone from the Ray team:
The problem is in modelv2.py, in “_unpack_obs” (line 385), because the “obs” variable passed in doesn't match the preprocessed obs_space. From what I saw while debugging, Discrete values are counted as only 1 value (in catalog.py, “get_action_shape”, used in “get_action_placeholder”, line 303 and following) instead of by the length of their one-hot encoding.
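To make the mismatch concrete, here is a minimal NumPy sketch of the unpacking step: splitting a flat [batch, 7063] batch back into images and a one-hot part. The function unpack is hypothetical and is not RLlib's _unpack_obs; its error message merely imitates the one reported above. A batch whose Discrete part was counted as a single value has only 7057 columns and cannot be split this way.

```python
import numpy as np

IMAGE_SHAPE = (84, 84, 1)
NUM_DISCRETE = 7
IMAGE_SIZE = int(np.prod(IMAGE_SHAPE))  # 7056

def unpack(flat_batch: np.ndarray):
    """Hypothetical: split [batch, 7063] into images [batch, 84, 84, 1] and one-hot [batch, 7]."""
    expected = IMAGE_SIZE + NUM_DISCRETE
    if flat_batch.shape[-1] != expected:
        raise ValueError(
            f"Expected flattened obs shape of [..., {expected}], got {flat_batch.shape}")
    images = flat_batch[:, :IMAGE_SIZE].reshape((-1,) + IMAGE_SHAPE)
    one_hot = flat_batch[:, IMAGE_SIZE:]
    return images, one_hot

# Correctly preprocessed batch: Discrete(7) occupies 7 one-hot columns.
ok = np.zeros((32, IMAGE_SIZE + NUM_DISCRETE), dtype=np.float32)
images, one_hot = unpack(ok)
print(images.shape, one_hot.shape)  # (32, 84, 84, 1) (32, 7)

# Discrete counted as a single value: only 7057 columns.
too_short = np.zeros((32, IMAGE_SIZE + 1), dtype=np.float32)
try:
    unpack(too_short)
except ValueError as err:
    print(err)
```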