How to flatten space when action masking?

I looked through every GitHub project I could find, searched this forum, and googled ‘python ray “flatten_space”’, but nobody seems to have posted a simple answer on how to flatten spaces when using action masking. Can someone please help? I have split this into 3 related questions:

1)
Instead of flattening in ActionMaskModel(TFModelV2), can I simply move "obs1" and "obs2" outside "actual_obs", like this?

self.observation_space = Dict({
    "action_mask": Box(0, 1, shape=(3,)),
    "actual_obs": Box(low=-np.inf, high=np.inf, shape=(10, 10)),
    "obs1": Discrete(10),
    "obs2": Discrete(10),
    # ...
})

instead of:

self.observation_space = Dict({
    "action_mask": Box(0, 1, shape=(3,)),
    "actual_obs": Dict({
        "obs_box": Box(low=-np.inf, high=np.inf, shape=(10, 10)),
        "obs1": Discrete(10),
        "obs2": Discrete(10),
        # ...
    })
})

Does it make sense? Will my network work properly? I checked it and there are no errors during execution.

2)
I read the docs on Variable-length / Parametric Action Spaces. What is the difference between ParametricActionsModel(TFModelV2) and ActionMaskModel(TFModelV2) if I use PPOTrainer() and simply need to mask actions?
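For reference, the masking step itself (as it appears in the forward() methods of the script further down) can be sketched in plain NumPy. This is only an illustration of the arithmetic, not RLlib code: the helper names mask_logits and softmax are made up for this example, and FLOAT_MIN here is a local stand-in for the constant the script imports.

```python
import numpy as np

# Stand-in for the FLOAT_MIN constant used in the script (assumption: float32 minimum).
FLOAT_MIN = np.finfo(np.float32).min


def mask_logits(logits, action_mask):
    """Leave allowed logits unchanged and push forbidden ones to a huge negative value."""
    logits = np.asarray(logits, dtype=np.float32)
    action_mask = np.asarray(action_mask, dtype=np.float32)
    # log(1) = 0 keeps a logit as-is; log(0) = -inf is clipped to FLOAT_MIN.
    with np.errstate(divide="ignore"):
        inf_mask = np.maximum(np.log(action_mask), FLOAT_MIN)
    return logits + inf_mask


def softmax(x):
    """Numerically stable softmax over a 1-D array."""
    e = np.exp(x - np.max(x))
    return e / e.sum()


logits = np.array([0.5, 1.0, -0.3, 2.0])
mask = np.array([1.0, 0.0, 1.0, 0.0])  # actions 1 and 3 are not allowed
masked = mask_logits(logits, mask)
probs = softmax(masked)  # masked actions end up with probability 0
```

Because the masked logits are so negative, the action distribution built from them should never sample a forbidden action.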

3)
I copy-pasted ActionMaskModel(TFModelV2) and TorchActionMaskModel(TorchModelV2, nn.Module) from rllib/examples/models/action_mask_model.py for simple action masking with PPO. The Dict observation space is:

self.observation_space = Dict({
    "action_mask": Box(0, 1, shape=(4,)),
    "actual_obs": Dict({
        "obs1": Discrete(5),
        "obs2": Box(low=-np.inf, high=np.inf, shape=(10, 10)),
    }),
})

I guess I need to flatten this in order to build the network correctly, so I used gym.spaces.utils.flatten_space, but I get the error ValueError: Missing data for input "observations". You passed a data dictionary with keys ['obs1', 'obs2']. Expected the following keys: ['observations']. Is this really a Keras error? I don’t have an observations key in the env’s __init__(). How do I fix that?
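To make concrete what "flattening" means for this particular "actual_obs" space, here is a minimal NumPy sketch. It assumes the usual convention of one-hot encoding a Discrete value and unrolling a Box into one dimension; the function name flatten_actual_obs is hypothetical and not part of gym or RLlib.

```python
import numpy as np


def flatten_actual_obs(obs1, obs2, n_discrete=5):
    """Turn {"obs1": Discrete(5), "obs2": Box(shape=(10, 10))} into one 1-D float vector."""
    one_hot = np.zeros(n_discrete, dtype=np.float32)
    one_hot[obs1] = 1.0                                     # Discrete -> one-hot vector
    box_flat = np.asarray(obs2, dtype=np.float32).ravel()   # Box (10, 10) -> (100,)
    return np.concatenate([one_hot, box_flat])              # 5 + 100 features


flat = flatten_actual_obs(1, np.zeros((10, 10)))
print(flat.shape)  # (105,)
```

A fully connected network built for the flattened space would therefore expect 105 input features per observation, not a dictionary with the keys obs1 and obs2.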

framework="tf" error:

(RolloutWorker pid=21627)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 586, in __init__
(RolloutWorker pid=21627)     self._build_policy_map(
(RolloutWorker pid=21627)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1577, in _build_policy_map
(RolloutWorker pid=21627)     self.policy_map.create_policy(name, orig_cls, obs_space, act_space,
(RolloutWorker pid=21627)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/policy/policy_map.py", line 133, in create_policy
(RolloutWorker pid=21627)     self[policy_id] = class_(
(RolloutWorker pid=21627)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/policy/tf_policy_template.py", line 238, in __init__
(RolloutWorker pid=21627)     DynamicTFPolicy.__init__(
(RolloutWorker pid=21627)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/policy/dynamic_tf_policy.py", line 333, in __init__
(RolloutWorker pid=21627)     dist_inputs, self._state_out = self.model(
(RolloutWorker pid=21627)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/models/modelv2.py", line 243, in __call__
(RolloutWorker pid=21627)     res = self.forward(restored, state or [], seq_lens)
(RolloutWorker pid=21627)   File "delete1.py", line 54, in forward
(RolloutWorker pid=21627)     logits, _ = self.internal_model({"obs": input_dict["obs"]["actual_obs"]})
(RolloutWorker pid=21627)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/models/modelv2.py", line 243, in __call__
(RolloutWorker pid=21627)     res = self.forward(restored, state or [], seq_lens)
(RolloutWorker pid=21627)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/models/tf/fcnet.py", line 128, in forward
(RolloutWorker pid=21627)     model_out, self._value_out = self.base_model(input_dict["obs_flat"])
(RolloutWorker pid=21627)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/keras/engine/base_layer_v1.py", line 739, in __call__
(RolloutWorker pid=21627)     input_spec.assert_input_compatibility(self.input_spec, inputs,
(RolloutWorker pid=21627)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/keras/engine/input_spec.py", line 182, in assert_input_compatibility
(RolloutWorker pid=21627)     raise ValueError(f'Missing data for input "{name}". '
(RolloutWorker pid=21627) ValueError: Missing data for input "observations". You passed a data dictionary with keys ['obs1', 'obs2']. Expected the following keys: ['observations']
(RolloutWorker pid=21632) 2022-05-04 21:08:16,328	ERROR worker.py:430 -- Exception raised in creation task: The actor died because of an error raised in its creation task, ray::RolloutWorker.__init__() (pid=21632, ip=127.0.0.1, repr=<ray.rllib.evaluation.rollout_worker.RolloutWorker object at 0x1750c6f10>)
(RolloutWorker pid=21632)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 586, in __init__
(RolloutWorker pid=21632)     self._build_policy_map(
(RolloutWorker pid=21632)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1577, in _build_policy_map
(RolloutWorker pid=21632)     self.policy_map.create_policy(name, orig_cls, obs_space, act_space,
(RolloutWorker pid=21632)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/policy/policy_map.py", line 133, in create_policy
(RolloutWorker pid=21632)     self[policy_id] = class_(
(RolloutWorker pid=21632)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/policy/tf_policy_template.py", line 238, in __init__
(RolloutWorker pid=21632)     DynamicTFPolicy.__init__(
(RolloutWorker pid=21632)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/policy/dynamic_tf_policy.py", line 333, in __init__
(RolloutWorker pid=21632)     dist_inputs, self._state_out = self.model(
(RolloutWorker pid=21632)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/models/modelv2.py", line 243, in __call__
(RolloutWorker pid=21632)     res = self.forward(restored, state or [], seq_lens)
(RolloutWorker pid=21632)   File "delete1.py", line 54, in forward
(RolloutWorker pid=21632)     logits, _ = self.internal_model({"obs": input_dict["obs"]["actual_obs"]})
(RolloutWorker pid=21632)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/models/modelv2.py", line 243, in __call__
(RolloutWorker pid=21632)     res = self.forward(restored, state or [], seq_lens)
(RolloutWorker pid=21632)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/models/tf/fcnet.py", line 128, in forward
(RolloutWorker pid=21632)     model_out, self._value_out = self.base_model(input_dict["obs_flat"])
(RolloutWorker pid=21632)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/keras/engine/base_layer_v1.py", line 739, in __call__
(RolloutWorker pid=21632)     input_spec.assert_input_compatibility(self.input_spec, inputs,
(RolloutWorker pid=21632)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/keras/engine/input_spec.py", line 182, in assert_input_compatibility
(RolloutWorker pid=21632)     raise ValueError(f'Missing data for input "{name}". '
(RolloutWorker pid=21632) ValueError: Missing data for input "observations". You passed a data dictionary with keys ['obs1', 'obs2']. Expected the following keys: ['observations']

framework="torch" error:

(RolloutWorker pid=21567)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 586, in __init__
(RolloutWorker pid=21567)     self._build_policy_map(
(RolloutWorker pid=21567)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1577, in _build_policy_map
(RolloutWorker pid=21567)     self.policy_map.create_policy(name, orig_cls, obs_space, act_space,
(RolloutWorker pid=21567)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/policy/policy_map.py", line 143, in create_policy
(RolloutWorker pid=21567)     self[policy_id] = class_(observation_space, action_space,
(RolloutWorker pid=21567)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/agents/ppo/ppo_torch_policy.py", line 50, in __init__
(RolloutWorker pid=21567)     self._initialize_loss_from_dummy_batch()
(RolloutWorker pid=21567)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/policy/policy.py", line 832, in _initialize_loss_from_dummy_batch
(RolloutWorker pid=21567)     self.compute_actions_from_input_dict(
(RolloutWorker pid=21567)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/policy/torch_policy.py", line 294, in compute_actions_from_input_dict
(RolloutWorker pid=21567)     return self._compute_action_helper(input_dict, state_batches,
(RolloutWorker pid=21567)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/utils/threading.py", line 28, in wrapper
(RolloutWorker pid=21567)     raise e
(RolloutWorker pid=21567)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/utils/threading.py", line 21, in wrapper
(RolloutWorker pid=21567)     return func(self, *a, **k)
(RolloutWorker pid=21567)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/policy/torch_policy.py", line 934, in _compute_action_helper
(RolloutWorker pid=21567)     dist_inputs, state_out = self.model(input_dict, state_batches,
(RolloutWorker pid=21567)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/models/modelv2.py", line 243, in __call__
(RolloutWorker pid=21567)     res = self.forward(restored, state or [], seq_lens)
(RolloutWorker pid=21567)   File "delete1.py", line 103, in forward
(RolloutWorker pid=21567)     logits, _ = self.internal_model({"obs": input_dict["obs"]["actual_obs"]})
(RolloutWorker pid=21567)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/models/modelv2.py", line 243, in __call__
(RolloutWorker pid=21567)     res = self.forward(restored, state or [], seq_lens)
(RolloutWorker pid=21567)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/models/torch/fcnet.py", line 122, in forward
(RolloutWorker pid=21567)     obs = input_dict["obs_flat"].float()
(RolloutWorker pid=21567) AttributeError: 'collections.OrderedDict' object has no attribute 'float'

python: 3.8.12
ray: 1.11.0
tensorflow: 2.7.0
torch: 1.10.2
OS: Mac 11.6

Reproduction script

import numpy as np
import ray
import ray.rllib.agents.ppo as ppo
from ray.tune.registry import register_env
import gym
from gym.spaces import Box, Dict, Discrete

from gym.spaces import utils
# from ray.rllib.utils.spaces.space_utils import flatten_space

from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.torch_utils import FLOAT_MIN

tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()

class ActionMaskModel(TFModelV2):
    def __init__(
        self, obs_space, action_space, num_outputs, model_config, name, **kwargs
    ):
        orig_space = getattr(obs_space, "original_space", obs_space)
        assert (
            isinstance(orig_space, Dict)
            and "action_mask" in orig_space.spaces
            and "actual_obs" in orig_space.spaces
        )
        super().__init__(obs_space, action_space, num_outputs, model_config, name)
        self.internal_model = FullyConnectedNetwork(
            utils.flatten_space(orig_space["actual_obs"]),
            action_space,
            num_outputs,
            model_config,
            name + "_internal",
        )

    def forward(self, input_dict, state, seq_lens):
        # Extract the available actions tensor from the observation.
        action_mask = input_dict["obs"]["action_mask"]

        # Compute the unmasked logits.
        logits, _ = self.internal_model({"obs": input_dict["obs"]["actual_obs"]})

        # Convert action_mask into a [0.0 || -inf]-type mask.
        inf_mask = tf.maximum(tf.math.log(action_mask), tf.float32.min)

        return logits + inf_mask, state

    def value_function(self):
        return self.internal_model.value_function()

class TorchActionMaskModel(TorchModelV2, nn.Module):
    def __init__(
        self,
        obs_space,
        action_space,
        num_outputs,
        model_config,
        name,
        **kwargs,
    ):
        orig_space = getattr(obs_space, "original_space", obs_space)
        assert (
            isinstance(orig_space, Dict)
            and "action_mask" in orig_space.spaces
            and "actual_obs" in orig_space.spaces
        )

        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name, **kwargs
        )
        nn.Module.__init__(self)

        self.internal_model = TorchFC(
            utils.flatten_space(orig_space["actual_obs"]),
            action_space,
            num_outputs,
            model_config,
            name + "_internal",
        )

    def forward(self, input_dict, state, seq_lens):
        # Extract the available actions tensor from the observation.
        action_mask = input_dict["obs"]["action_mask"]

        # Compute the unmasked logits.
        logits, _ = self.internal_model({"obs": input_dict["obs"]["actual_obs"]})

        # Convert action_mask into a [0.0 || -inf]-type mask.
        inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)

        # Return masked logits.
        return logits + inf_mask, state

    def value_function(self):
        return self.internal_model.value_function()


class MyEnv(gym.Env):
    metadata = {"render.modes": ["human"]}
    def __init__(self):
        super(MyEnv, self).__init__()

        self.action_space = Discrete(4)
        self.observation_space_dict = Dict({
            "action_mask": Box(0, 1, shape=(4,)),
            "actual_obs": Dict({
                "obs1": Discrete(5),
                "obs2": Box(low=-np.inf, high=np.inf, shape=(10, 10)),
            }),
        })

        self.observation_space = self.observation_space_dict
        #self.observation_space = utils.flatten_space(self.observation_space_dict)
    
    def reset(self):
        return self._make_obs()
    
    def step(self, action):
        return self._make_obs(), 0, False, {}

    def _make_obs(self):
        obs_dict = {
            "action_mask": np.array([1.0] * 4),
            "actual_obs": {
                "obs1": 1,
                "obs2": np.zeros((10, 10)),
            },
        }
        return obs_dict
        # return utils.flatten(self.observation_space_dict, obs_dict)

def main():
    ray.init()
    select_env = "env-v1"
    register_env(select_env, lambda config: MyEnv())
    framework = 'tf'
    config = ppo.DEFAULT_CONFIG.copy()
    config.update({
        "env": select_env,
        "framework": framework,
        "log_level": 'DEBUG',
        "model": {
            "custom_model": ActionMaskModel if framework != "torch" else TorchActionMaskModel,
        },
    })
    agent = ppo.PPOTrainer(config, env=select_env)
    agent.train()

if __name__ == "__main__":
    main()

@arturn @sven1977

Hi @sirjay, to shed a little light on how the space gets flattened, take a look at the flatten_space() function.

@Lars_Simon_Zehnder thank you for the reply. If I use flatten_space from ray.rllib.utils.spaces.space_utils, I get a list object and this error from class ActionMaskModel(TFModelV2): AttributeError: 'list' object has no attribute 'shape'.
Yes, this function returns a list, but what should I do with the list after flattening? How do I convert it to a Box (if I need to)?

@sirjay I thought your question was how to flatten the space?

There is a function in modelv2.py, restore_original_dimensions(), that can restore the original dimensions.
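Both tracebacks above show the nested "actual_obs" dictionary reaching the fully connected network as "obs_flat", while that network expects a single (batch, features) tensor. One possible workaround, sketched here with NumPy arrays standing in for TF/Torch tensors, is to concatenate the restored components before calling the internal model. The sketch assumes that the restored obs1 is already one-hot encoded with shape (batch, 5) and that obs2 has shape (batch, 10, 10); the helper name flatten_batched_obs is hypothetical.

```python
from collections import OrderedDict

import numpy as np


def flatten_batched_obs(nested_obs):
    """Concatenate a (possibly nested) dict of batched arrays into shape (batch, features)."""
    parts = []
    for key in nested_obs:  # keep the dict's own key order
        value = nested_obs[key]
        if isinstance(value, dict):
            parts.append(flatten_batched_obs(value))
        else:
            arr = np.asarray(value, dtype=np.float32)
            parts.append(arr.reshape(arr.shape[0], -1))  # keep batch dim, unroll the rest
    return np.concatenate(parts, axis=1)


# Assumed layout of input_dict["obs"]["actual_obs"] for a batch of two observations.
batch = OrderedDict(
    obs1=np.eye(5, dtype=np.float32)[[1, 3]],      # two one-hot rows, shape (2, 5)
    obs2=np.zeros((2, 10, 10), dtype=np.float32),  # shape (2, 10, 10)
)
flat_batch = flatten_batched_obs(batch)
print(flat_batch.shape)  # (2, 105)
```

Inside forward(), the same idea would use tf.concat or torch.cat on input_dict["obs"]["actual_obs"] and pass the result as {"obs": flat} to self.internal_model, which then needs to be constructed with a matching 105-dimensional Box space. This has not been tested against ray 1.11.0, so treat it as a starting point rather than a confirmed fix.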