Manual state-in inference does not work with MultiDiscrete environment

Please consider the following notebook:

This is a full version of the Stateless CartPole example, with code taken from:

It uses the previous version of the environment, the one that contains all the code, rather than the class derived from CartPoleEnv.
This example works.
Now I have made the following small changes:


the action space is MultiDiscrete instead of Discrete;
the action is handled with a simple “and” in the step method:

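The screenshots of these changes are not preserved. As a rough, hypothetical sketch (assuming a MultiDiscrete([2, 2]) action space whose two components are combined with Python's `and` into the single 0/1 push direction that CartPole expects; `NVEC` and `to_cartpole_action` are illustrative names, not the notebook's code), the reduction in the step method might look like this:

```python
from typing import Sequence

# Hypothetical: nvec of the MultiDiscrete action space, e.g. MultiDiscrete([2, 2]).
NVEC = (2, 2)


def to_cartpole_action(action: Sequence[int]) -> int:
    """Collapse a MultiDiscrete action into CartPole's Discrete(2) action.

    Assumption: the step method combines both components with a logical "and",
    so the cart is pushed right (1) only when both components are 1.
    """
    if len(action) != len(NVEC):
        raise ValueError(f"expected {len(NVEC)} components, got {len(action)}")
    return int(bool(action[0]) and bool(action[1]))


for a in [(0, 0), (0, 1), (1, 0), (1, 1)]:
    print(a, "->", to_cartpole_action(a))
```

Only (1, 1) maps to 1 here, which is why the change "does not make sense" as a control problem but still exercises the MultiDiscrete code paths in RLlib.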

You can find the modified notebook at:

Of course, this change does not make much sense, but I think it can be useful for testing.
Training with Tune appears to work properly.
However, when I run the manual state-in inference, I get the following error:

IndexError                                Traceback (most recent call last)
/tmp/ipykernel_5980/2835217490.py in <module>
     29 
     30     while not done:
---> 31         a, state_out, _ = trainer.compute_single_action(obs, state, prev_action=prev_a, prev_reward=prev_r)
     32         obs, reward, done, _ = env.step(a)
     33         episode_reward += reward

/opt/conda/lib/python3.9/site-packages/ray/rllib/agents/trainer.py in compute_single_action(self, observation, state, prev_action, prev_reward, info, input_dict, policy_id, full_fetch, explore, timestep, episode, unsquash_action, clip_action, unsquash_actions, clip_actions, **kwargs)
   1130         # Individual args.
   1131         else:
-> 1132             action, state, extra = policy.compute_single_action(
   1133                 obs=observation,
   1134                 state=state,

/opt/conda/lib/python3.9/site-packages/ray/rllib/policy/policy.py in compute_single_action(self, obs, state, prev_action, prev_reward, info, input_dict, episode, explore, timestep, **kwargs)
    241             episodes = [episode]
    242 
--> 243         out = self.compute_actions_from_input_dict(
    244             input_dict=SampleBatch(input_dict),
    245             episodes=episodes,

/opt/conda/lib/python3.9/site-packages/ray/rllib/policy/torch_policy.py in compute_actions_from_input_dict(self, input_dict, explore, timestep, **kwargs)
    300                 if state_batches else None
    301 
--> 302             return self._compute_action_helper(input_dict, state_batches,
    303                                                seq_lens, explore, timestep)
    304 

/opt/conda/lib/python3.9/site-packages/ray/rllib/utils/threading.py in wrapper(self, *a, **k)
     19         try:
     20             with self._lock:
---> 21                 return func(self, *a, **k)
     22         except AttributeError as e:
     23             if "has no attribute '_lock'" in e.args[0]:

/opt/conda/lib/python3.9/site-packages/ray/rllib/policy/torch_policy.py in _compute_action_helper(self, input_dict, state_batches, seq_lens, explore, timestep)
    364             else:
    365                 dist_class = self.dist_class
--> 366                 dist_inputs, state_out = self.model(input_dict, state_batches,
    367                                                     seq_lens)
    368 

/opt/conda/lib/python3.9/site-packages/ray/rllib/models/modelv2.py in __call__(self, input_dict, state, seq_lens)
    241 
    242         with self.context():
--> 243             res = self.forward(restored, state or [], seq_lens)
    244 
    245         if ((not isinstance(res, list) and not isinstance(res, tuple))

/opt/conda/lib/python3.9/site-packages/ray/rllib/models/torch/recurrent_net.py in forward(self, input_dict, state, seq_lens)
    191         if self.model_config["lstm_use_prev_action"]:
    192             if isinstance(self.action_space, (Discrete, MultiDiscrete)):
--> 193                 prev_a = one_hot(input_dict[SampleBatch.PREV_ACTIONS].float(),
    194                                  self.action_space)
    195             else:

/opt/conda/lib/python3.9/site-packages/ray/rllib/utils/torch_ops.py in one_hot(x, space)
    186     elif isinstance(space, MultiDiscrete):
    187         return torch.cat(
--> 188             [
    189                 nn.functional.one_hot(x[:, i].long(), n)
    190                 for i, n in enumerate(space.nvec)

/opt/conda/lib/python3.9/site-packages/ray/rllib/utils/torch_ops.py in <listcomp>(.0)
    187         return torch.cat(
    188             [
--> 189                 nn.functional.one_hot(x[:, i].long(), n)
    190                 for i, n in enumerate(space.nvec)
    191             ],

IndexError: index 1 is out of bounds for dimension 1 with size 1
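The last frame shows what goes wrong: for a MultiDiscrete space, one_hot() reads one column x[:, i] per sub-space in space.nvec, so it expects the previous-action batch to have shape (batch, len(nvec)). An IndexError on index 1 with size 1 means the batch it received had only a single column. The NumPy sketch below mirrors that list comprehension (an illustration, not RLlib's actual code) and reproduces the failure for a one-column batch. That is what one would expect if, for example, prev_action were still passed as a scalar as in the Discrete version of the example; this is an assumption, since the inference code is not shown.

```python
import numpy as np


def one_hot_multidiscrete(x: np.ndarray, nvec) -> np.ndarray:
    """NumPy mirror of the MultiDiscrete branch of RLlib's torch one_hot().

    x has shape (batch, len(nvec)); column i is one-hot encoded with nvec[i]
    classes and the encodings are concatenated along the last axis.
    """
    return np.concatenate(
        [np.eye(n, dtype=np.float32)[x[:, i].astype(np.int64)] for i, n in enumerate(nvec)],
        axis=-1,
    )


nvec = [2, 2]

# Expected shape: one row per sample, one column per sub-action.
ok = np.array([[1, 0]])
print(one_hot_multidiscrete(ok, nvec))  # [[0. 1. 1. 0.]]

# One-column batch (e.g. from a scalar previous action): x[:, 1] does not exist.
bad = np.array([[0]])
try:
    one_hot_multidiscrete(bad, nvec)
except IndexError as e:
    print("IndexError:", e)
```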

The manual state-in inference is the following:
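The inference code itself is not included in this post. As a hedged sketch of how such a loop might look for a MultiDiscrete action space, the function below mirrors the lines visible in the traceback (compute_single_action(obs, state, prev_action=..., prev_reward=...) followed by env.step(a)) and initializes prev_action as a zero vector with one entry per sub-action instead of a scalar. The name run_episode and its parameters are hypothetical; trainer, env and the initial RNN state are assumed to be created elsewhere, as in the original notebook, and step() is assumed to return (obs, reward, done, info) as in the traceback. Whether this alone avoids the error on 1.8.0 is not confirmed in this thread.

```python
import numpy as np


def run_episode(trainer, env, initial_state, nvec):
    """Hypothetical manual state-in inference loop for a MultiDiscrete action space.

    trainer: object exposing compute_single_action(obs, state, prev_action=..., prev_reward=...)
             that returns (action, state_out, extra), as in the traceback.
    env:     gym-style env whose step() returns (obs, reward, done, info).
    initial_state: list of initial RNN state arrays for the policy's model.
    nvec:    sizes of the MultiDiscrete sub-spaces, e.g. [2, 2].
    """
    obs = env.reset()
    state = initial_state
    # One entry per sub-action, not a scalar 0 as in the Discrete version.
    prev_a = np.zeros(len(nvec), dtype=np.int64)
    prev_r = 0.0
    episode_reward = 0.0
    done = False

    while not done:
        a, state_out, _ = trainer.compute_single_action(
            obs, state, prev_action=prev_a, prev_reward=prev_r
        )
        obs, reward, done, _ = env.step(a)
        episode_reward += reward
        # Feed this step's outputs back in as the next step's inputs.
        prev_a = np.asarray(a, dtype=np.int64)
        prev_r = reward
        state = state_out

    return episode_reward
```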

I need your help!
Thanks.

Hi @mg64ve,

Without executing the code, I see no obvious mistake here. The error occurs in the function that performs one-hot encoding of torch tensors for MultiDiscrete spaces, so that case should be supported.

My suggestion is to set a breakpoint at the line in the model's forward() where one_hot() is called and print input_dict[SampleBatch.PREV_ACTIONS] to check the dimensions of the processed action batch. My guess is that the dimensions do not match somewhere.

Also, which Ray version are you using?

Thanks @Lars_Simon_Zehnder, I have tested with both 1.8.0 and the latest 2.0.0.dev0.
I am not very confident about how to set up breakpoints. My configuration consists of two Docker containers, one for each Ray version, each installed via pip. I then use Jupyter notebooks for my tests.
What is your recommended configuration for debugging with breakpoints?

Hi @mg64ve,

I usually use VSCode with a virtual environment in which I install "ray[default]". Then I can simply debug the code and see what my numbers are doing.

Thanks @Lars_Simon_Zehnder, it seems to work now in version 2.0.0dev0.
