I am trying to reproduce the code from the blog post "Action Masking with RLlib", but I get an error.
Here is the script. Install the environment package first with `pip install or_gym`.
from or_gym.utils import create_env
from gym import spaces
from ray.rllib.utils import try_import_tf
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models import ModelCatalog
from ray import tune
from ray.rllib import agents
import ray
import or_gym
import numpy as np
# Build the Knapsack environment once with its defaults and report its size.
env = or_gym.make('Knapsack-v0')
print(f"Max weight capacity:\t{env.max_weight}kg")
print(f"Number of items:\t{env.N}")

# Small 5-item instance; 'mask': True is forwarded to the environment as-is.
env_config = {
    'N': 5,
    'max_weight': 15,
    'item_weights': np.array([1, 12, 2, 1, 4]),
    'item_values': np.array([2, 4, 2, 1, 10]),
    'mask': True,
}
env = or_gym.make('Knapsack-v0', env_config=env_config)
print(f"Max weight capacity:\t{env.max_weight}kg")
print(f"Number of items:\t{env.N}")

# Older RLlib releases return the TensorFlow module directly from
# try_import_tf(); newer ones return a (tf1, tf, tf_version) tuple.
# Binding the tuple to `tf` would break every later `tf.<op>` call, so
# accept both shapes and always end up with the module.
_tf_import = try_import_tf()
tf = _tf_import[1] if isinstance(_tf_import, tuple) else _tf_import
# tf.compat.v1.disable_eager_execution()
class KP0ActionMaskModel(TFModelV2):
    """Custom RLlib TF model that masks out unavailable actions.

    A fully connected network produces an action embedding from the
    ``"state"`` entry of the observation; the ``"action_mask"`` entry is
    then used to push the logits of masked actions towards the most
    negative float32 value.
    """

    def __init__(self, obs_space, action_space, num_outputs,
                 model_config, name, true_obs_shape=(11,),
                 action_embed_size=5, *args, **kwargs):
        """Create the wrapped embedding network.

        Args:
            obs_space, action_space, num_outputs, model_config, name:
                Forwarded unchanged to ``TFModelV2.__init__``.
            true_obs_shape: Shape of the Box space handed to the inner
                fully connected network.
            action_embed_size: Number of outputs of the inner network.
        """
        super().__init__(
            obs_space, action_space, num_outputs, model_config, name,
            *args, **kwargs)

        embed_input_space = spaces.Box(0, 1, shape=true_obs_shape)
        self.action_embed_model = FullyConnectedNetwork(
            embed_input_space,
            action_space,
            action_embed_size,
            model_config,
            name + "_action_embedding",
        )
        self.register_variables(self.action_embed_model.variables())

    def forward(self, input_dict, state, seq_lens):
        """Return masked action logits and the unchanged ``state``."""
        obs = input_dict["obs"]

        embedding, _ = self.action_embed_model({"obs": obs["state"]})
        intent = tf.expand_dims(embedding, 1)

        # NOTE(review): if obs["avail_actions"] is rank-2 (batch, n), this
        # product against the rank-3 `intent` tensor broadcasts over the
        # batch dimension before the reduction -- confirm intended shapes.
        logits = tf.math.reduce_sum(obs["avail_actions"] * intent, axis=1)

        # log(1) == 0 leaves a logit untouched; log(0) == -inf is clamped
        # to float32.min so masked actions get a finite, very low score.
        mask_penalty = tf.math.maximum(
            tf.math.log(obs["action_mask"]), tf.float32.min)
        return logits + mask_penalty, state

    def value_function(self):
        """Delegate the value estimate to the inner embedding network."""
        return self.action_embed_model.value_function()


# Make the model selectable through "custom_model": "kp_mask".
ModelCatalog.register_custom_model('kp_mask', KP0ActionMaskModel)
def register_env(env_name, env_config=None):
    """Register an or_gym environment with Tune under ``env_name``.

    Args:
        env_name: Name passed to ``create_env`` and used as the Tune
            registry key.
        env_config: Optional settings forwarded to the environment
            constructor as the ``env_config`` keyword. Defaults to an
            empty dict created per call (no shared mutable default).
    """
    if env_config is None:
        env_config = {}
    env = create_env(env_name)
    # The registered creator is called with a single argument; it is
    # passed through positionally, exactly as before, while the settings
    # captured here are supplied via the ``env_config`` keyword.
    tune.register_env(
        env_name,
        lambda creator_arg: env(creator_arg, env_config=env_config))
# Expose the environment to Tune/RLlib, then start (or reuse) Ray.
register_env('Knapsack-v0', env_config=env_config)
ray.init(ignore_reinit_error=True)

# PPO trainer that uses the custom masking model registered as "kp_mask".
trainer_config = dict(
    model={"custom_model": "kp_mask"},
    env_config=env_config,
)
trainer = agents.ppo.PPOTrainer(env='Knapsack-v0', config=trainer_config)

# Take the environment's current state and zero the first mask entry,
# then query the policy ten times on that same observation.
env = trainer.env_creator('Knapsack-v0')
state = env.state
state['action_mask'][0] = 0
sampled_actions = [trainer.compute_action(state) for _ in range(10)]
actions = np.array(sampled_actions)
print(actions)
This script works fine with Ray 0.8.7, but with Ray 1.0.1 it raises an error, apparently because
trainer.compute_action() cannot handle a dict-type observation:
Traceback (most recent call last):
File "/data2/huangcq/miniconda3/envs/majenv/lib/python3.8/site-packages/ray/rllib/models/preprocessors.py", line 60, in check_shape
if not self._obs_space.contains(observation):
File "/data2/huangcq/miniconda3/envs/majenv/lib/python3.8/site-packages/gym/spaces/box.py", line 128, in contains
return x.shape == self.shape and np.all(x >= self.low) and np.all(x <= self.high)
AttributeError: 'dict' object has no attribute 'shape'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/notebooks/projects/hanyu/ReferProject/MahjongFastPK/test.py", line 96, in <module>
actions = np.array([trainer.compute_action(state) for i in range(10)])
File "/notebooks/projects/hanyu/ReferProject/MahjongFastPK/test.py", line 96, in <listcomp>
actions = np.array([trainer.compute_action(state) for i in range(10)])
File "/data2/huangcq/miniconda3/envs/majenv/lib/python3.8/site-packages/ray/rllib/agents/trainer.py", line 819, in compute_action
preprocessed = self.workers.local_worker().preprocessors[
File "/data2/huangcq/miniconda3/envs/majenv/lib/python3.8/site-packages/ray/rllib/models/preprocessors.py", line 166, in transform
self.check_shape(observation)
File "/data2/huangcq/miniconda3/envs/majenv/lib/python3.8/site-packages/ray/rllib/models/preprocessors.py", line 66, in check_shape
raise ValueError(
ValueError: ('Observation for a Box/MultiBinary/MultiDiscrete space should be an np.array, not a Python list.', {'action_mask': array([0, 1, 1, 1, 1]), 'avail_actions': array([1., 1., 1., 1., 1.]), 'state': array([ 1, 12, 2, 1, 4, 2, 4, 2, 1, 10, 0])})