I have looked through every GitHub project I could find, searched this forum, and googled ‘python ray “flatten_space”’, but nowhere has anyone posted a simple answer on how to flatten nested observation spaces when doing action masking. Can someone help, please? I have split this into three related questions:
1) Instead of using flatten in ActionMaskModel(TFModelV2), can I simply move "obs1" and "obs2" outside "actual_obs"?
self.observation_space = Dict({
"action_mask": Box(0, 1, shape=(3,)),
"actual_obs": Box(low=-np.inf, high=np.inf, shape=(10, 10)),
"obs1": Discrete(10),
"obs2": Discrete(10),
# ...
})
instead of what I have now:
self.observation_space = Dict({
"action_mask": Box(0, 1, shape=(3,)),
"actual_obs": Dict({
"obs_box": Box(low=-np.inf, high=np.inf, shape=(10, 10)),
"obs1": Discrete(10),
"obs2": Discrete(10),
# ...
})
})
Does that make sense? Will my network still work properly? I tried it and there were no errors during execution.
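For context, here is how I imagine forward() would have to change in that case (a hypothetical, untested sketch with my own names): the example model only forwards "actual_obs" to the internal network, so top-level "obs1"/"obs2" would otherwise be silently ignored.
def forward(self, input_dict, state, seq_lens):
    action_mask = input_dict["obs"]["action_mask"]
    # Gather every observation key except the mask; my assumption is that
    # the internal model was built over exactly these remaining sub-spaces.
    obs_wo_mask = {
        k: v for k, v in input_dict["obs"].items() if k != "action_mask"
    }
    logits, _ = self.internal_model({"obs": obs_wo_mask})
    # Same [0.0 || -inf]-style masking as in the example model.
    inf_mask = tf.maximum(tf.math.log(action_mask), tf.float32.min)
    return logits + inf_mask, state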
2) I read the docs on Variable-length / Parametric Action Spaces. What is the difference between ParametricActionsModel(TFModelV2) and ActionMaskModel(TFModelV2) if I use PPOTrainer() and simply need to mask actions?
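From what I can tell so far (hedged; the shapes and variable names below are mine, not RLlib's), ParametricActionsModel targets variable-length or very large action sets and scores per-action embeddings against an "intent" vector, while ActionMaskModel keeps fixed-size logits and only adds a -inf mask. A toy sketch of my understanding:
import tensorflow as tf

# Toy shapes for illustration only (not actual RLlib code).
batch, num_actions, embed_dim = 2, 3, 4
model_out = tf.random.normal([batch, embed_dim])                   # network "intent" output
avail_actions = tf.random.normal([batch, num_actions, embed_dim])  # per-action embeddings
action_mask = tf.constant([[1.0, 1.0, 0.0], [1.0, 0.0, 1.0]])

# ParametricActionsModel-style: score each candidate action by a dot
# product with the intent vector, so the action set can vary in size.
intent = tf.expand_dims(model_out, 1)                          # [B, 1, D]
param_logits = tf.reduce_sum(avail_actions * intent, axis=2)   # [B, num_actions]

# ActionMaskModel-style: logits come straight from the network; the mask
# only pushes invalid actions to -inf before sampling.
inf_mask = tf.maximum(tf.math.log(action_mask), tf.float32.min)
masked_logits = param_logits + inf_mask
print(masked_logits.shape)  # (2, 3)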
3) I copy-pasted ActionMaskModel(TFModelV2) and TorchActionMaskModel(TorchModelV2, nn.Module) from rllib/examples/models/action_mask_model.py for simple action masking with PPO. The Dict observation space is:
self.observation_space = Dict({
"action_mask": Box(0, 1, shape=(4,)),
"actual_obs": Dict({
"obs1": Discrete(5),
"obs2": Box(low=-np.inf, high=np.inf, shape=(10, 10)),
}),
})
I assume I need to flatten this to build a correct network model, so I used gym.spaces.utils.flatten_space, but I get the error ValueError: Missing data for input "observations". You passed a data dictionary with keys ['obs1', 'obs2']. Expected the following keys: ['observations']. Is this really a Keras error? I don't have an "observations" key anywhere in my env's __init__(). How can I fix this? One workaround I am considering is sketched right after this; the full error traces follow.
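The workaround sketch (hedged, untested; shapes mimic my space, where obs1 Discrete(5) arrives one-hot encoded and obs2 is the (10, 10) Box): flatten the restored "actual_obs" tensors myself inside forward() before calling the internal model, since its input layer was built from a single flat Box.
import tensorflow as tf

actual_obs = {
    "obs1": tf.one_hot([1, 3], depth=5),   # [B=2, 5] one-hot Discrete(5)
    "obs2": tf.zeros([2, 10, 10]),         # [B=2, 10, 10] Box
}
parts = [
    tf.reshape(t, [tf.shape(t)[0], -1])    # flatten each leaf to [B, -1]
    for _, t in sorted(actual_obs.items()) # gym flattens Dicts in sorted key order
]
flat_obs = tf.concat(parts, axis=1)        # [B, 105]
print(flat_obs.shape)                      # (2, 105)
# In forward() this would replace:
#   logits, _ = self.internal_model({"obs": input_dict["obs"]["actual_obs"]})
# with:
#   logits, _ = self.internal_model({"obs": flat_obs})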
framework="tf" error:
(RolloutWorker pid=21627) File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 586, in __init__
(RolloutWorker pid=21627) self._build_policy_map(
(RolloutWorker pid=21627) File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1577, in _build_policy_map
(RolloutWorker pid=21627) self.policy_map.create_policy(name, orig_cls, obs_space, act_space,
(RolloutWorker pid=21627) File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/policy/policy_map.py", line 133, in create_policy
(RolloutWorker pid=21627) self[policy_id] = class_(
(RolloutWorker pid=21627) File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/policy/tf_policy_template.py", line 238, in __init__
(RolloutWorker pid=21627) DynamicTFPolicy.__init__(
(RolloutWorker pid=21627) File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/policy/dynamic_tf_policy.py", line 333, in __init__
(RolloutWorker pid=21627) dist_inputs, self._state_out = self.model(
(RolloutWorker pid=21627) File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/models/modelv2.py", line 243, in __call__
(RolloutWorker pid=21627) res = self.forward(restored, state or [], seq_lens)
(RolloutWorker pid=21627) File "delete1.py", line 54, in forward
(RolloutWorker pid=21627) logits, _ = self.internal_model({"obs": input_dict["obs"]["actual_obs"]})
(RolloutWorker pid=21627) File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/models/modelv2.py", line 243, in __call__
(RolloutWorker pid=21627) res = self.forward(restored, state or [], seq_lens)
(RolloutWorker pid=21627) File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/models/tf/fcnet.py", line 128, in forward
(RolloutWorker pid=21627) model_out, self._value_out = self.base_model(input_dict["obs_flat"])
(RolloutWorker pid=21627) File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/keras/engine/base_layer_v1.py", line 739, in __call__
(RolloutWorker pid=21627) input_spec.assert_input_compatibility(self.input_spec, inputs,
(RolloutWorker pid=21627) File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/keras/engine/input_spec.py", line 182, in assert_input_compatibility
(RolloutWorker pid=21627) raise ValueError(f'Missing data for input "{name}". '
(RolloutWorker pid=21627) ValueError: Missing data for input "observations". You passed a data dictionary with keys ['obs1', 'obs2']. Expected the following keys: ['observations']
framework="torch" error:
(RolloutWorker pid=21567) File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 586, in __init__
(RolloutWorker pid=21567) self._build_policy_map(
(RolloutWorker pid=21567) File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1577, in _build_policy_map
(RolloutWorker pid=21567) self.policy_map.create_policy(name, orig_cls, obs_space, act_space,
(RolloutWorker pid=21567) File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/policy/policy_map.py", line 143, in create_policy
(RolloutWorker pid=21567) self[policy_id] = class_(observation_space, action_space,
(RolloutWorker pid=21567) File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/agents/ppo/ppo_torch_policy.py", line 50, in __init__
(RolloutWorker pid=21567) self._initialize_loss_from_dummy_batch()
(RolloutWorker pid=21567) File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/policy/policy.py", line 832, in _initialize_loss_from_dummy_batch
(RolloutWorker pid=21567) self.compute_actions_from_input_dict(
(RolloutWorker pid=21567) File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/policy/torch_policy.py", line 294, in compute_actions_from_input_dict
(RolloutWorker pid=21567) return self._compute_action_helper(input_dict, state_batches,
(RolloutWorker pid=21567) File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/utils/threading.py", line 28, in wrapper
(RolloutWorker pid=21567) raise e
(RolloutWorker pid=21567) File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/utils/threading.py", line 21, in wrapper
(RolloutWorker pid=21567) return func(self, *a, **k)
(RolloutWorker pid=21567) File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/policy/torch_policy.py", line 934, in _compute_action_helper
(RolloutWorker pid=21567) dist_inputs, state_out = self.model(input_dict, state_batches,
(RolloutWorker pid=21567) File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/models/modelv2.py", line 243, in __call__
(RolloutWorker pid=21567) res = self.forward(restored, state or [], seq_lens)
(RolloutWorker pid=21567) File "delete1.py", line 103, in forward
(RolloutWorker pid=21567) logits, _ = self.internal_model({"obs": input_dict["obs"]["actual_obs"]})
(RolloutWorker pid=21567) File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/models/modelv2.py", line 243, in __call__
(RolloutWorker pid=21567) res = self.forward(restored, state or [], seq_lens)
(RolloutWorker pid=21567) File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/models/torch/fcnet.py", line 122, in forward
(RolloutWorker pid=21567) obs = input_dict["obs_flat"].float()
(RolloutWorker pid=21567) AttributeError: 'collections.OrderedDict' object has no attribute 'float'
python: 3.8.12
ray: 1.11.0
tensorflow: 2.7.0
torch: 1.10.2
OS: macOS 11.6
Reproduction script
import numpy as np
import ray
import ray.rllib.agents.ppo as ppo
from ray.tune.registry import register_env
import gym
from gym.spaces import Box, Dict, Discrete
from gym.spaces import utils
# from ray.rllib.utils.spaces.space_utils import flatten_space
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.torch_utils import FLOAT_MIN
from ray.rllib.utils.torch_utils import FLOAT_MIN
tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()
class ActionMaskModel(TFModelV2):
def __init__(
self, obs_space, action_space, num_outputs, model_config, name, **kwargs
):
orig_space = getattr(obs_space, "original_space", obs_space)
assert (
isinstance(orig_space, Dict)
and "action_mask" in orig_space.spaces
and "actual_obs" in orig_space.spaces
)
super().__init__(obs_space, action_space, num_outputs, model_config, name)
self.internal_model = FullyConnectedNetwork(
utils.flatten_space(orig_space["actual_obs"]),
action_space,
num_outputs,
model_config,
name + "_internal",
)
def forward(self, input_dict, state, seq_lens):
# Extract the available actions tensor from the observation.
action_mask = input_dict["obs"]["action_mask"]
# Compute the unmasked logits.
logits, _ = self.internal_model({"obs": input_dict["obs"]["actual_obs"]})
# Convert action_mask into a [0.0 || -inf]-type mask.
inf_mask = tf.maximum(tf.math.log(action_mask), tf.float32.min)
return logits + inf_mask, state
def value_function(self):
return self.internal_model.value_function()
class TorchActionMaskModel(TorchModelV2, nn.Module):
def __init__(
self,
obs_space,
action_space,
num_outputs,
model_config,
name,
**kwargs,
):
orig_space = getattr(obs_space, "original_space", obs_space)
assert (
isinstance(orig_space, Dict)
and "action_mask" in orig_space.spaces
and "actual_obs" in orig_space.spaces
)
TorchModelV2.__init__(
self, obs_space, action_space, num_outputs, model_config, name, **kwargs
)
nn.Module.__init__(self)
self.internal_model = TorchFC(
utils.flatten_space(orig_space["actual_obs"]),
action_space,
num_outputs,
model_config,
name + "_internal",
)
def forward(self, input_dict, state, seq_lens):
# Extract the available actions tensor from the observation.
action_mask = input_dict["obs"]["action_mask"]
# Compute the unmasked logits.
logits, _ = self.internal_model({"obs": input_dict["obs"]["actual_obs"]})
# Convert action_mask into a [0.0 || -inf]-type mask.
inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
# Return masked logits.
return logits + inf_mask, state
def value_function(self):
return self.internal_model.value_function()
class MyEnv(gym.Env):
metadata = {"render.modes": ["human"]}
def __init__(self):
super(MyEnv, self).__init__()
self.action_space = Discrete(4)
self.observation_space_dict = Dict({
"action_mask": Box(0, 1, shape=(4,)),
"actual_obs": Dict({
"obs1": Discrete(5),
"obs2": Box(low=-np.inf, high=np.inf, shape=(10, 10)),
}),
})
self.observation_space = self.observation_space_dict
#self.observation_space = utils.flatten_space(self.observation_space_dict)
def reset(self):
return self._make_obs()
def step(self, action):
return self._make_obs(), 0, False, {}
def _make_obs(self):
obs_dict = {
"action_mask": np.array([1.0] * 4),
"actual_obs": {
"obs1": 1,
"obs2": np.zeros((10, 10)),
},
}
return obs_dict
# return utils.flatten(self.observation_space_dict, obs_dict)
def main():
ray.init()
select_env = "env-v1"
register_env(select_env, lambda config: MyEnv())
framework = 'tf'
config = ppo.DEFAULT_CONFIG.copy()
config.update({
"env": select_env,
"framework": framework,
"log_level": 'DEBUG',
"model": {
"custom_model": ActionMaskModel if framework != "torch" else TorchActionMaskModel,
},
})
agent = ppo.PPOTrainer(config, env=select_env)
agent.train()
if __name__ == "__main__":
main()