AttentionNet with action masking

Hi all,
I am able to use AttentionNet with my own environment and it works well. I can also use an FC net with action masking, but when I try to use AttentionNet with action masking, it fails…
Here is my code. I don’t know if I have to specify state and seq_lens in the forward function below.

```python
class ParametricActionsAttentionModel(TFModelV2):
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, true_obs_shape=(10102,), **kwargs):
        super(ParametricActionsAttentionModel, self).__init__(
            obs_space, action_space, action_space.n, model_config, name)
        self.seq_length = model_config["max_seq_len"]
        self.attention_dim = model_config["attention_dim"]
        self.train_batch_size = model_config["train_batch_size"]
        model_config.pop("custom_model")
        self.model = ModelCatalog.get_model_v2(
            obs_space=Box(low=-3.0, high=3.0, shape=true_obs_shape),
            action_space=action_space,
            num_outputs=action_space.n,
            model_config=model_config,
            framework="tf2")

    def forward(self, input_dict, state, seq_lens):
        action_mask = input_dict["obs"]["action_mask"]
        action_logits, _ = self.model(
            {"obs": input_dict["obs"]["state"]}, state, seq_lens)
        # Mask out invalid actions (use tf.float32.min for stability)
        inf_mask = tf.maximum(tf.math.log(action_mask), tf.float32.min)
        return action_logits + inf_mask, state

    def value_function(self):
        return self.model.value_function()
```
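For what it's worth, the masking trick in `forward` can be checked in isolation. A small sketch with NumPy standing in for the TF ops (example logits and mask are made up) shows why adding `log(mask)` clamped at `float32.min` zeroes out invalid actions after softmax:

```python
import numpy as np

def masked_logits(logits, mask):
    # log(0) -> -inf; clamp at float32 min so inf - inf = nan can never occur
    with np.errstate(divide="ignore"):
        inf_mask = np.maximum(np.log(mask), np.finfo(np.float32).min)
    return logits + inf_mask

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

logits = np.array([2.0, 1.0, 0.5, 3.0], dtype=np.float32)
mask = np.array([1.0, 0.0, 1.0, 0.0], dtype=np.float32)  # actions 1 and 3 invalid

probs = softmax(masked_logits(logits, mask))
print(probs)  # all probability mass lands on actions 0 and 2
```

Masked actions end up with logit `float32.min`, so their softmax probability underflows to exactly zero while valid actions keep their relative weights.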

And the error I got:

```
python3.8/site-packages/ray/rllib/models/modelv2.py", line 230, in call
(pid=1300201) res = self.forward(restored, state or [], seq_lens)
(pid=1300201) File "/.virtualenvs/lib/python3.8/site-packages/ray/rllib/models/tf/attention_net.py", line 444, in forward
(pid=1300201) wrapped_out, _ = self._wrapped_forward(input_dict, [], None)
(pid=1300201) File "/git/rl/ray/tf_models.py", line 239, in forward
(pid=1300201) action_logits, _ = self.model(input_dict, state, seq_lens)
(pid=1300201) File "/.virtualenvs/lib/python3.8/site-packages/ray/rllib/models/modelv2.py", line 230, in call
(pid=1300201) res = self.forward(restored, state or [], seq_lens)
(pid=1300201) File "/.virtualenvs/lib/python3.8/site-packages/ray/rllib/models/tf/attention_net.py", line 482, in forward
(pid=1300201) self._features, memory_outs = self.gtrxl(input_dict, state, seq_lens)
(pid=1300201) File "/.virtualenvs/lib/python3.8/site-packages/ray/rllib/models/modelv2.py", line 230, in call
(pid=1300201) res = self.forward(restored, state or [], seq_lens)
(pid=1300201) File "/.virtualenvs/lib/python3.8/site-packages/ray/rllib/models/tf/attention_net.py", line 320, in forward
(pid=1300201) observations = tf.reshape(observations,
(pid=1300201) File "/.virtualenvs/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py", line 206, in wrapper
(pid=1300201) return target(*args, **kwargs)
(pid=1300201) File "/.virtualenvs/lib/python3.8/site-packages/tensorflow/python/ops/array_ops.py", line 195, in reshape
(pid=1300201) result = gen_array_ops.reshape(tensor, shape, name)
(pid=1300201) File "/.virtualenvs/lib/python3.8/site-packages/tensorflow/python/ops/gen_array_ops.py", line 8397, in reshape
(pid=1300201) _, _, _op, _outputs = _op_def_library._apply_op_helper(
(pid=1300201) File "/.virtualenvs/lib/python3.8/site-packages/tensorflow/python/framework/op_def_library.py", line 748, in _apply_op_helper
(pid=1300201) op = g._create_op_internal(op_type_name, inputs, dtypes=None,
(pid=1300201) File "/.virtualenvs/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 3557, in _create_op_internal
(pid=1300201) ret = Operation(
(pid=1300201) File "/.virtualenvs/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 2045, in init
(pid=1300201) self._traceback = tf_stack.extract_stack_for_node(self._c_op)
```

Hello! I would like to know: is action masking available off the shelf in RLlib?

Hi @yinyee,

Action masking is supported by RLlib, but you have to write the model that does the action masking yourself.

Here is an example:

Hi @Adri31,

Is there more to the stack trace? It looks like the actual error message is missing.

If you wrap your code and error message in three backticks (```), it will preserve the formatting and be easier to read.

There is an issue here:

```python
action_logits, _ = self.model(
    {"obs": input_dict["obs"]["state"]}, state, seq_lens)
```

You are not returning the updated memory states, so the wrapped model will always see the initial state.
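To illustrate in plain Python (made-up classes, not the RLlib API): a wrapper that discards the inner model's returned state keeps re-feeding the same stale memory, whereas threading the returned state through lets the attention memory advance:

```python
class InnerModel:
    """Stands in for the GTrXL: returns output plus updated memory state."""
    def forward(self, obs, state):
        new_state = [s + 1 for s in state]  # pretend memory update
        return obs * 2, new_state

class BuggyWrapper:
    def __init__(self):
        self.inner = InnerModel()
    def forward(self, obs, state):
        out, _ = self.inner.forward(obs, state)  # updated state discarded
        return out, state                        # caller re-feeds stale state

class FixedWrapper:
    def __init__(self):
        self.inner = InnerModel()
    def forward(self, obs, state):
        out, new_state = self.inner.forward(obs, state)
        return out, new_state                    # state advances each step

def rollout(model, steps=3):
    state = [0]
    for _ in range(steps):
        _, state = model.forward(1.0, state)
    return state

print(rollout(BuggyWrapper()))  # [0] -- memory never advances
print(rollout(FixedWrapper()))  # [3] -- memory advances each step
```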

This is probably not the source of your error though.

Do you have a full reproduction script you could share in Google Colab?

Thanks @mannyv for your help. I tried to reproduce the problem with a well-known environment => Google Colab

With my environment, when "use_attention" is False, it works with action masking. When "use_attention" is True, it fails with a weird error: "Received a label of 100 which is outside the valid range [0, 100)" in tf_action_dist.py:70.

In the Colab code, something is very strange: I tried several environments, and Ray does not take into account the stop criterion (iteration = 10), the custom model, or the model itself… Something must be wrong in my code…