How to flatten the observation space for action masking?

I have looked through all the GitHub projects I could find, searched this forum, and searched Google for ‘python ray “flatten_space”’, but nowhere has anyone posted a simple answer on how to flatten spaces when doing action masking. Can someone help, please? I have split this into three related questions:

1)
Instead of using flatten in ActionMaskModel(TFModelV2), can I simply move "obs1" and "obs2" outside of "actual_obs"?

self.observation_space = Dict({
    "action_mask": Box(0, 1, shape=(3,)),
    "actual_obs": Box(low=-np.inf, high=np.inf, shape=(10, 10)),
    "obs1": Discrete(10),
    "obs2": Discrete(10),
    # ...
})

instead of:

self.observation_space = Dict({
    "action_mask": Box(0, 1, shape=(3,)),
    "actual_obs": Dict({
        "obs_box": Box(low=-np.inf, high=np.inf, shape=(10, 10)),
        "obs1": Discrete(10),
        "obs2": Discrete(10),
        # ...
    })
})

Does this make sense? Will my network still work properly? I tried it and there were no errors during execution.

2)
I read the docs on Variable-length / Parametric Action Spaces. What is the difference between ParametricActionsModel(TFModelV2) and ActionMaskModel(TFModelV2) if I use PPOTrainer() and simply need to mask actions?

3)
I copy-pasted ActionMaskModel(TFModelV2) and TorchActionMaskModel(TorchModelV2, nn.Module) from rllib/examples/models/action_mask_model.py for simple action masking with PPO. My Dict observation space is:

self.observation_space = Dict({
    "action_mask": Box(0, 1, shape=(4,)),
    "actual_obs": Dict({
        "obs1": Discrete(5),
        "obs2": Box(low=-np.inf, high=np.inf, shape=(10, 10)),
    }),
})

I guess I need to flatten this in order to build a correct network model, so I used gym.spaces.utils.flatten_space, but I get the error ValueError: Missing data for input "observations". You passed a data dictionary with keys ['obs1', 'obs2']. Expected the following keys: ['observations']. Is this really a Keras error? I don't have an "observations" key anywhere in my env's __init__(). How can I fix that?
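For reference, this is what gym.spaces.utils.flatten_space gives me for the nested part of the space (a standalone check; the exact repr may differ by gym version):

import numpy as np
from gym.spaces import Box, Dict, Discrete, utils

actual_obs = Dict({
    "obs1": Discrete(5),
    "obs2": Box(low=-np.inf, high=np.inf, shape=(10, 10)),
})
# A single Box of size 5 + 100 = 105 (the Discrete is one-hot encoded):
print(utils.flatten_space(actual_obs))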

framework="tf" error:

(RolloutWorker pid=21627)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 586, in __init__
(RolloutWorker pid=21627)     self._build_policy_map(
(RolloutWorker pid=21627)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1577, in _build_policy_map
(RolloutWorker pid=21627)     self.policy_map.create_policy(name, orig_cls, obs_space, act_space,
(RolloutWorker pid=21627)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/policy/policy_map.py", line 133, in create_policy
(RolloutWorker pid=21627)     self[policy_id] = class_(
(RolloutWorker pid=21627)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/policy/tf_policy_template.py", line 238, in __init__
(RolloutWorker pid=21627)     DynamicTFPolicy.__init__(
(RolloutWorker pid=21627)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/policy/dynamic_tf_policy.py", line 333, in __init__
(RolloutWorker pid=21627)     dist_inputs, self._state_out = self.model(
(RolloutWorker pid=21627)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/models/modelv2.py", line 243, in __call__
(RolloutWorker pid=21627)     res = self.forward(restored, state or [], seq_lens)
(RolloutWorker pid=21627)   File "delete1.py", line 54, in forward
(RolloutWorker pid=21627)     logits, _ = self.internal_model({"obs": input_dict["obs"]["actual_obs"]})
(RolloutWorker pid=21627)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/models/modelv2.py", line 243, in __call__
(RolloutWorker pid=21627)     res = self.forward(restored, state or [], seq_lens)
(RolloutWorker pid=21627)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/models/tf/fcnet.py", line 128, in forward
(RolloutWorker pid=21627)     model_out, self._value_out = self.base_model(input_dict["obs_flat"])
(RolloutWorker pid=21627)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/keras/engine/base_layer_v1.py", line 739, in __call__
(RolloutWorker pid=21627)     input_spec.assert_input_compatibility(self.input_spec, inputs,
(RolloutWorker pid=21627)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/keras/engine/input_spec.py", line 182, in assert_input_compatibility
(RolloutWorker pid=21627)     raise ValueError(f'Missing data for input "{name}". '
(RolloutWorker pid=21627) ValueError: Missing data for input "observations". You passed a data dictionary with keys ['obs1', 'obs2']. Expected the following keys: ['observations']

framework="torch" error:

(RolloutWorker pid=21567)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 586, in __init__
(RolloutWorker pid=21567)     self._build_policy_map(
(RolloutWorker pid=21567)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1577, in _build_policy_map
(RolloutWorker pid=21567)     self.policy_map.create_policy(name, orig_cls, obs_space, act_space,
(RolloutWorker pid=21567)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/policy/policy_map.py", line 143, in create_policy
(RolloutWorker pid=21567)     self[policy_id] = class_(observation_space, action_space,
(RolloutWorker pid=21567)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/agents/ppo/ppo_torch_policy.py", line 50, in __init__
(RolloutWorker pid=21567)     self._initialize_loss_from_dummy_batch()
(RolloutWorker pid=21567)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/policy/policy.py", line 832, in _initialize_loss_from_dummy_batch
(RolloutWorker pid=21567)     self.compute_actions_from_input_dict(
(RolloutWorker pid=21567)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/policy/torch_policy.py", line 294, in compute_actions_from_input_dict
(RolloutWorker pid=21567)     return self._compute_action_helper(input_dict, state_batches,
(RolloutWorker pid=21567)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/utils/threading.py", line 28, in wrapper
(RolloutWorker pid=21567)     raise e
(RolloutWorker pid=21567)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/utils/threading.py", line 21, in wrapper
(RolloutWorker pid=21567)     return func(self, *a, **k)
(RolloutWorker pid=21567)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/policy/torch_policy.py", line 934, in _compute_action_helper
(RolloutWorker pid=21567)     dist_inputs, state_out = self.model(input_dict, state_batches,
(RolloutWorker pid=21567)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/models/modelv2.py", line 243, in __call__
(RolloutWorker pid=21567)     res = self.forward(restored, state or [], seq_lens)
(RolloutWorker pid=21567)   File "delete1.py", line 103, in forward
(RolloutWorker pid=21567)     logits, _ = self.internal_model({"obs": input_dict["obs"]["actual_obs"]})
(RolloutWorker pid=21567)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/models/modelv2.py", line 243, in __call__
(RolloutWorker pid=21567)     res = self.forward(restored, state or [], seq_lens)
(RolloutWorker pid=21567)   File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/models/torch/fcnet.py", line 122, in forward
(RolloutWorker pid=21567)     obs = input_dict["obs_flat"].float()
(RolloutWorker pid=21567) AttributeError: 'collections.OrderedDict' object has no attribute 'float'

python: 3.8.12
ray: 1.11.0
tensorflow: 2.7.0
torch: 1.10.2
OS: Mac 11.6

Reproduction script

import numpy as np
import ray
import ray.rllib.agents.ppo as ppo
from ray.tune.registry import register_env
import gym
from gym.spaces import Box, Dict, Discrete

from gym.spaces import utils
# from ray.rllib.utils.spaces.space_utils import flatten_space

from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.torch_utils import FLOAT_MIN

tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()

class ActionMaskModel(TFModelV2):
    def __init__(
        self, obs_space, action_space, num_outputs, model_config, name, **kwargs
    ):
        orig_space = getattr(obs_space, "original_space", obs_space)
        assert (
            isinstance(orig_space, Dict)
            and "action_mask" in orig_space.spaces
            and "actual_obs" in orig_space.spaces
        )
        super().__init__(obs_space, action_space, num_outputs, model_config, name)
        self.internal_model = FullyConnectedNetwork(
            utils.flatten_space(orig_space["actual_obs"]),
            action_space,
            num_outputs,
            model_config,
            name + "_internal",
        )

    def forward(self, input_dict, state, seq_lens):
        # Extract the available actions tensor from the observation.
        action_mask = input_dict["obs"]["action_mask"]

        # Compute the unmasked logits.
        logits, _ = self.internal_model({"obs": input_dict["obs"]["actual_obs"]})

        # Convert action_mask into a [0.0 || -inf]-type mask.
        inf_mask = tf.maximum(tf.math.log(action_mask), tf.float32.min)

        return logits + inf_mask, state

    def value_function(self):
        return self.internal_model.value_function()

class TorchActionMaskModel(TorchModelV2, nn.Module):
    def __init__(
        self,
        obs_space,
        action_space,
        num_outputs,
        model_config,
        name,
        **kwargs,
    ):
        orig_space = getattr(obs_space, "original_space", obs_space)
        assert (
            isinstance(orig_space, Dict)
            and "action_mask" in orig_space.spaces
            and "actual_obs" in orig_space.spaces
        )

        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name, **kwargs
        )
        nn.Module.__init__(self)

        self.internal_model = TorchFC(
            utils.flatten_space(orig_space["actual_obs"]),
            action_space,
            num_outputs,
            model_config,
            name + "_internal",
        )

    def forward(self, input_dict, state, seq_lens):
        # Extract the available actions tensor from the observation.
        action_mask = input_dict["obs"]["action_mask"]

        # Compute the unmasked logits.
        logits, _ = self.internal_model({"obs": input_dict["obs"]["actual_obs"]})

        # Convert action_mask into a [0.0 || -inf]-type mask.
        inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)

        # Return masked logits.
        return logits + inf_mask, state

    def value_function(self):
        return self.internal_model.value_function()


class MyEnv(gym.Env):
    metadata = {"render.modes": ["human"]}
    def __init__(self):
        super(MyEnv, self).__init__()

        self.action_space = Discrete(4)
        self.observation_space_dict = Dict({
            "action_mask": Box(0, 1, shape=(4,)),
            "actual_obs": Dict({
                "obs1": Discrete(5),
                "obs2": Box(low=-np.inf, high=np.inf, shape=(10, 10)),
            }),
        })

        self.observation_space = self.observation_space_dict
        #self.observation_space = utils.flatten_space(self.observation_space_dict)
    
    def reset(self):
        return self._make_obs()
    
    def step(self, action):
        return self._make_obs(), 0, False, {}

    def _make_obs(self):
        obs_dict = {
            "action_mask": np.array([1.0] * 4),
            "actual_obs": {
                "obs1": 1,
                "obs2": np.zeros((10, 10)),
            },
        }
        return obs_dict
        # return utils.flatten(self.observation_space_dict, obs_dict)

def main():
    ray.init()
    select_env = "env-v1"
    register_env(select_env, lambda config: MyEnv())
    framework = 'tf'
    config = ppo.DEFAULT_CONFIG.copy()
    config.update({
        "env": select_env,
        "framework": framework,
        "log_level": 'DEBUG',
        "model": {
            "custom_model": ActionMaskModel if framework != "torch" else TorchActionMaskModel,
        },
    })
    agent = ppo.PPOTrainer(config, env=select_env)
    agent.train()

if __name__ == "__main__":
    main()

@arturn @sven1977

Hi @sirjay, to shed a little light on how the space gets flattened, take a look at the flatten_space() function.

@Lars_Simon_Zehnder thank you for the reply. If I use flatten_space from ray.rllib.utils.spaces.space_utils, I get a list object, and class ActionMaskModel(TFModelV2) raises this error: AttributeError: 'list' object has no attribute 'shape'.
Yes, this function returns a list, but what do I do with the list object after applying flatten? How do I convert it to a Box (if I even need to)?
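To make the difference between the two functions concrete (a standalone check; the exact reprs vary by version):

import numpy as np
from gym.spaces import Box, Dict, Discrete, utils
from ray.rllib.utils.spaces.space_utils import flatten_space as rllib_flatten

nested = Dict({
    "obs1": Discrete(5),
    "obs2": Box(low=-np.inf, high=np.inf, shape=(10, 10)),
})
print(rllib_flatten(nested))        # a plain Python list of primitive spaces, no .shape
print(utils.flatten_space(nested))  # a single Box of total size 5 + 100 = 105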

@sirjay I thought your question was how to flatten the space?

There is a function in modelv2.py, restore_original_dimensions(), that can restore the original dimensions.
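Roughly how it is used inside a model's forward() (a sketch, not tested against your script; the string framework argument is what RLlib itself passes internally):

from ray.rllib.models.modelv2 import restore_original_dimensions

def unpack_obs(flat_obs, obs_space, framework="tf"):
    # flat_obs: the flattened observation batch (e.g. input_dict["obs_flat"]);
    # obs_space: the model's obs_space, which carries `.original_space`;
    # returns the observations re-nested into the original Dict/Tuple structure.
    return restore_original_dimensions(flat_obs, obs_space, framework)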

@Lars_Simon_Zehnder a little clarification. I am facing the same problem.

It seems to me that the original action masking example only works with flat observation spaces. As stated above, it does not work with Dicts.

In order to instantiate the internal model (for example, FullyConnectedNetwork), the observation space should be a Box. Its shape would be the total number of features in the Dict space. I guess the values should be between -1 and 1.

The trouble with flatten_space is that it converts any space into a list of spaces, effectively removing the Tuple and Dict structure. Although such a list is fairly close to a Box, there are a few steps missing.

So the question would be:

Given any observation space that contains Dicts and Tuples, how can it be mapped to a single Box space? The closest I have come is the sketch below.
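(A sketch; the -1/1 bounds and the dtype are my own guesses, not something RLlib prescribes:)

import numpy as np
from gym.spaces import Box, utils

def to_single_box(space):
    # Total number of scalar features after flattening
    # (Discretes are counted as their one-hot size).
    n = utils.flatdim(space)
    return Box(low=-1.0, high=1.0, shape=(n,), dtype=np.float32)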

Thanks for your time!

@aiguru To clarify a little how RLlib treats Dict/Tuple observation spaces: the model_catalog.py file contains all the logic for deciding which model class to use semi-automatically for the Policy. At line 962 it decides that if the observation space is a Dict/Tuple, the ComplexInputNetwork should be used, which can handle more complex nested spaces. Therefore the action masking example does not work with a complex nested input space when trying to use the FullyConnectedNetwork (which is hard-coded/user-chosen in the example).

Regarding your statement about the flatten_space function: you are right, its result is not comparable to a Box space (I do not understand your comment “[…] there are a few steps missing”, though; maybe you can elaborate a little more if you think flatten_space() should be modified). It is a list of spaces, and it is so for a reason: the ComplexInputNetwork constructs an input layer for each space in the list and, in a final step, concatenates the layers together to keep the expressiveness of the components of the original nested space.
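Schematically, that concatenation looks like this (a toy illustration with identity “encoders”, not the actual ComplexInputNetwork code):

import numpy as np

# Two component observations, matching the list [Discrete(5), Box(10, 10)]:
obs1 = np.eye(5)[np.array([1, 3])]  # batch of 2, one-hot encoded Discrete(5)
obs2 = np.zeros((2, 100))           # batch of 2, flattened 10x10 Box
# One encoder per component (identity here), concatenated at the end:
combined = np.concatenate([obs1, obs2], axis=1)
print(combined.shape)               # (2, 105)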

To summarize: complex nested observation spaces likely need the ComplexInputNetwork, and with that network the action masking should work.
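Concretely, in the model above that would mean replacing the hard-coded FullyConnectedNetwork, roughly like this (an untested sketch; in your Ray version the TF class lives in ray/rllib/models/tf/complex_input_net.py, with a torch counterpart under models/torch/):

from ray.rllib.models.tf.complex_input_net import ComplexInputNetwork

# In ActionMaskModel.__init__, instead of the FullyConnectedNetwork:
self.internal_model = ComplexInputNetwork(
    orig_space["actual_obs"],  # pass the nested Dict space directly, no flatten_space
    action_space,
    num_outputs,
    model_config,
    name + "_internal",
)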

Let me know if this answers your questions / solves your concerns.
