Hi all,
I am able to use AttentionNet with my own environment, and it works well. I can also use an FC net with action masking. However, when I try to combine AttentionNet with action masking, it fails with the traceback below.
Here is my code. I don’t know if I have to specify state and seq_lens in the forward function below.
class ParametricActionsAttentionModel(TFModelV2):
    """Action-masking wrapper around an RLlib attention (GTrXL) model.

    The environment observation is a dict with two keys:
      - "action_mask": 0/1 vector over the discrete action space
        (1 = valid action) — assumed from usage; confirm against the env.
      - "state": the true observation, flattened to ``true_obs_shape``.

    ``forward`` routes the true observation through an internally built
    attention net and adds a large negative value to the logits of
    masked-out actions.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, true_obs_shape=(10102,), **kwargs):
        # BUG FIX: the original defined `init` (and called `self.init`),
        # which Python never invokes as a constructor — it must be
        # `__init__` for the model to be initialized at all.
        super(ParametricActionsAttentionModel, self).__init__(
            obs_space, action_space, action_space.n, model_config, name)
        self.seq_length = model_config["max_seq_len"]
        self.attention_dim = model_config["attention_dim"]
        self.train_batch_size = model_config["train_batch_size"]
        # Remove the custom-model key so the inner model built below is
        # the default (attention) net rather than recursing into this
        # wrapper class.
        model_config.pop("custom_model")
        # Inner model sees only the true observation, not the mask.
        # NOTE(review): the pasted traceback shows attention_net.py
        # *wrapping* this model (use_attention=True) while this model
        # also builds its own attention net — that double wrapping is
        # likely the root cause; use one or the other, not both.
        self.model = ModelCatalog.get_model_v2(
            obs_space=Box(low=-3.0, high=3.0, shape=true_obs_shape),
            action_space=action_space,
            num_outputs=action_space.n,
            model_config=model_config,
            framework="tf2")

    def forward(self, input_dict, state, seq_lens):
        """Compute masked action logits.

        `state` and `seq_lens` MUST be threaded through to the wrapped
        attention model (it reshapes the batch into [B, T, ...] using
        them), and the wrapped model's state outputs must be returned —
        returning the *input* state would freeze the attention memory.
        """
        action_mask = input_dict["obs"]["action_mask"]
        # FIXED: original line was a syntax error —
        # `{"obs:input_dict["obs"]["state"]}` (a set literal with a
        # broken string). The wrapped ModelV2 expects a dict keyed "obs".
        action_logits, state_out = self.model(
            {"obs": input_dict["obs"]["state"]}, state, seq_lens)
        # Mask out invalid actions: log(0) = -inf, clamped to
        # tf.float32.min for numerical stability.
        inf_mask = tf.maximum(tf.math.log(action_mask), tf.float32.min)
        # FIXED: propagate the inner model's memory/state outputs
        # instead of echoing the incoming `state`.
        return action_logits + inf_mask, state_out

    def get_initial_state(self):
        # Delegate so RLlib allocates state tensors matching the inner
        # attention net's memory layout.
        return self.model.get_initial_state()

    def value_function(self):
        # Value branch lives on the wrapped model.
        return self.model.value_function()
And the error I got:
python3.8/site-packages/ray/rllib/models/modelv2.py", line 230, in call
(pid=1300201) res = self.forward(restored, state or [], seq_lens)
(pid=1300201) File “/.virtualenvs/lib/python3.8/site-packages/ray/rllib/models/tf/attention_net.py”, line 444, in forward
(pid=1300201) wrapped_out, _ = self._wrapped_forward(input_dict, [], None)
(pid=1300201) File “/git/rl/ray/tf_models.py”, line 239, in forward
(pid=1300201) action_logits, _ = self.model(input_dict, state, seq_lens)
(pid=1300201) File “/.virtualenvs/lib/python3.8/site-packages/ray/rllib/models/modelv2.py”, line 230, in call
(pid=1300201) res = self.forward(restored, state or [], seq_lens)
(pid=1300201) File “/.virtualenvs/lib/python3.8/site-packages/ray/rllib/models/tf/attention_net.py”, line 482, in forward
(pid=1300201) self._features, memory_outs = self.gtrxl(input_dict, state, seq_lens)
(pid=1300201) File “/.virtualenvs/lib/python3.8/site-packages/ray/rllib/models/modelv2.py”, line 230, in call
(pid=1300201) res = self.forward(restored, state or [], seq_lens)
(pid=1300201) File “/.virtualenvs/lib/python3.8/site-packages/ray/rllib/models/tf/attention_net.py”, line 320, in forward
(pid=1300201) observations = tf.reshape(observations,
(pid=1300201) File “/.virtualenvs/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py”, line 206, in wrapper
(pid=1300201) return target(*args, **kwargs)
(pid=1300201) File “/.virtualenvs/lib/python3.8/site-packages/tensorflow/python/ops/array_ops.py”, line 195, in reshape
(pid=1300201) result = gen_array_ops.reshape(tensor, shape, name)
(pid=1300201) File “/.virtualenvs/lib/python3.8/site-packages/tensorflow/python/ops/gen_array_ops.py”, line 8397, in reshape
(pid=1300201) _, _, _op, _outputs = _op_def_library._apply_op_helper(
(pid=1300201) File “/.virtualenvs/lib/python3.8/site-packages/tensorflow/python/framework/op_def_library.py”, line 748, in _apply_op_helper
(pid=1300201) op = g._create_op_internal(op_type_name, inputs, dtypes=None,
(pid=1300201) File “/.virtualenvs/lib/python3.8/site-packages/tensorflow/python/framework/ops.py”, line 3557, in _create_op_internal
(pid=1300201) ret = Operation(
(pid=1300201) File “/.virtualenvs/lib/python3.8/site-packages/tensorflow/python/framework/ops.py”, line 2045, in init
(pid=1300201) self._traceback = tf_stack.extract_stack_for_node(self._c_op)