How severely does this issue affect your experience of using Ray?
- High: It blocks me from completing my task.
Hi, I am trying to implement simple discrete action masking in RLlib. The idea is that the entire action space (182 actions) is available at the beginning of an episode, and once an action has been chosen it cannot be picked again, so the set of valid actions shrinks over the course of the episode. I have read through the related examples on GitHub, including action_masking.py and action_mask_env.py, as well as some previous posts on this forum, so I have a rough idea of how it should work. However, I keep getting the same error message when initializing the algo.
Also, is it mandatory to set "framework": "tf2" when using a custom TF model? You can find snippets of the code below. Thank you very much!
import gym
import numpy as np
from gym.spaces import Box, Dict, Discrete


class ActionMaskEnv(gym.Env):
    def __init__(self, env_config):
        super(ActionMaskEnv, self).__init__()
        # Define observation space: the actual observations plus the action mask.
        self.observation_space = Dict({
            "observations": Box(low=0, high=1, shape=(910,), dtype=np.int32),
            "action_mask": Box(0, 1, shape=(182,)),
        })
        # Define action space
        self.action_space = Discrete(182)
        # Action mask (initially all actions possible).
        # The mask is updated at every step until the entire array is zeros.
        self.mask = np.ones(182, dtype=np.float32)
    .......
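Roughly, the elided reset()/step() logic updates the mask like this (a simplified sketch only; the observation and reward values below are placeholders, not my actual logic):

    def reset(self):
        # All 182 actions are valid again at the start of an episode.
        self.mask = np.ones(182, dtype=np.float32)
        obs = np.zeros(910, dtype=np.int32)  # placeholder observation
        return {"observations": obs, "action_mask": self.mask}

    def step(self, action):
        # Once an action has been picked, it can never be chosen again.
        self.mask[action] = 0.0
        obs = np.zeros(910, dtype=np.int32)  # placeholder observation
        reward = 0.0  # placeholder reward
        done = not self.mask.any()  # episode ends once no valid actions remain
        return {"observations": obs, "action_mask": self.mask}, reward, done, {}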
from gym.spaces import Dict
from ray.rllib.models import ModelCatalog
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils.framework import try_import_tf

tf1, tf, tfv = try_import_tf()


class ActionMaskModel(TFModelV2):
    """Model that handles simple discrete action masking.

    This assumes the outputs are logits for a single Categorical action dist.
    Getting this to work with a more complex output (e.g., if the action space
    is a tuple of several distributions) is also possible but left as an
    exercise to the reader.
    """

    def __init__(
        self, obs_space, action_space, num_outputs, model_config, name, **kwargs
    ):
        orig_space = getattr(obs_space, "original_space", obs_space)
        assert (
            isinstance(orig_space, Dict)
            and "action_mask" in orig_space.spaces
            and "observations" in orig_space.spaces
        )
        super().__init__(obs_space, action_space, num_outputs, model_config, name)
        self.internal_model = FullyConnectedNetwork(
            orig_space["observations"],
            action_space,
            num_outputs,
            model_config,
            name + "_internal",
        )
        print(orig_space["observations"])  # debug: inspect the unflattened obs space
        # Disable action masking --> will likely lead to invalid actions.
        self.no_masking = model_config["custom_model_config"].get("no_masking", False)

    def forward(self, input_dict, state, seq_lens):
        # Extract the available-actions tensor from the observation.
        print(input_dict)  # debug
        action_mask = input_dict["obs"]["action_mask"]
        print(input_dict["obs"])  # debug
        print(action_mask)  # debug
        # Compute the unmasked logits.
        logits, _ = self.internal_model({"obs": input_dict["obs"]["observations"]})
        # If action masking is disabled, directly return unmasked logits.
        if self.no_masking:
            return logits, state
        # Convert action_mask into a [0.0 || -inf]-type mask.
        inf_mask = tf.maximum(tf.math.log(action_mask), tf.float32.min)
        masked_logits = logits + inf_mask
        # Return masked logits.
        return masked_logits, state

    def value_function(self):
        return self.internal_model.value_function()
ModelCatalog.register_custom_model("action_mask_model", ActionMaskModel)
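To make the masking step explicit: tf.math.log turns the zeros in the mask into -inf, which tf.maximum clamps to tf.float32.min so no NaNs propagate, and adding that to the logits drives the softmax probability of masked actions to effectively zero. A tiny standalone check of that idea (not part of my model, just for illustration):

import numpy as np
import tensorflow as tf

mask = tf.constant([1.0, 0.0, 1.0])
logits = tf.constant([0.5, 2.0, -0.3])

inf_mask = tf.maximum(tf.math.log(mask), tf.float32.min)  # [0, ~-3.4e38, 0]
masked_logits = logits + inf_mask

probs = tf.nn.softmax(masked_logits).numpy()
print(np.round(probs, 3))  # the masked action (index 1) gets ~0.0 probability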
from ray.rllib.algorithms import dqn

algo = dqn.DQN(
    env=ActionMaskEnv,
    config={
        "rollout_fragment_length": 100,
        "env_config": {},
        "hiddens": [],
        "model": {
            "custom_model": "action_mask_model",
        },
        "train_batch_size": 1000,
        "framework": "tf2",
        "horizon": 182,
        "eager_tracing": True,
        "min_train_timesteps_per_iteration": 100,
        "min_sample_timesteps_per_iteration": 2000,
    },
)
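For completeness, once the algo constructs successfully I would just train it in a simple loop, something like this (sketch):

for i in range(10):
    result = algo.train()
    print(i, result["episode_reward_mean"])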