@arturn thank you for the answer!
You can not use DistributionalQTFModel with the PPO trainer, because it was built for Q-Learning algorithms, which PPO is not.
- I made a typo in my question. What is the difference between ParametricActionsModel(TFModelV2) and ActionMaskModel(TFModelV2) if I use PPOTrainer() and simply need to mask actions?
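For context, "simply masking actions" usually means pushing the logits of invalid actions to a very large negative value before the softmax, so those actions get (numerically) zero probability. Below is a minimal plain-Python sketch of that idea. It is not RLlib code: `mask_logits`, `softmax` and `FLOAT_MIN` are hypothetical names chosen for illustration, and the logits and mask are made-up values.

```python
import math

# Stand-in for the most negative float32 value (an assumption for this sketch).
FLOAT_MIN = -3.4e38


def mask_logits(logits, action_mask):
    """Add a huge negative penalty to the logits of disallowed actions."""
    masked = []
    for logit, allowed in zip(logits, action_mask):
        # Allowed actions keep their logit; disallowed ones are pushed to FLOAT_MIN.
        penalty = 0.0 if allowed else FLOAT_MIN
        masked.append(logit + penalty)
    return masked


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


# Three actions, the second one is masked out.
probs = softmax(mask_logits([1.0, 2.0, 0.5], [1, 0, 1]))
print(probs)
```

The masked action ends up with probability zero, and the remaining probability mass is shared between the allowed actions.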
The error you are referring to stems from obs_space.shape being None. This is because you use a Dict obs_space. You can flatten it first to gain a space that has a proper shape.
- I flattened the space in the model, and now the error says:
File "/miniforge3/envs/rl/lib/python3.8/site-packages/keras/engine/input_spec.py", line 182, in assert_input_compatibility
    raise ValueError(f'Missing data for input "{name}". '
ValueError: Missing data for input "observations". You passed a data dictionary with keys ['obs1', 'obs2']. Expected the following keys: ['observations']
How do I proceed from here? Something seems to be missing. I searched all over GitHub and did not find any examples of how to handle this.
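The error message suggests that only the space was flattened, while the data passed to the internal model is still a dict with the keys obs1 and obs2. Below is a minimal sketch of flattening the observation data itself, written in plain Python rather than TensorFlow. It assumes a Discrete value is one-hot encoded and a Box is unrolled row by row, which is how gym's flatten utilities treat these spaces. `one_hot` and `flatten_actual_obs` are hypothetical helpers, not part of gym or RLlib.

```python
def one_hot(index, n):
    """Encode a Discrete(n) sample as a one-hot list of length n."""
    vec = [0.0] * n
    vec[index] = 1.0
    return vec


def flatten_actual_obs(obs, n_obs1=5):
    """Turn {"obs1": int, "obs2": 10x10 nested list} into one flat list.

    n_obs1 matches Discrete(5) from the environment in this post.
    """
    flat = one_hot(obs["obs1"], n_obs1)
    for row in obs["obs2"]:
        # Unroll the 2-D box row by row.
        flat.extend(float(v) for v in row)
    return flat


# A made-up sample shaped like the "actual_obs" entry.
sample = {"obs1": 3, "obs2": [[0.0] * 10 for _ in range(10)]}
flat = flatten_actual_obs(sample)
print(len(flat))  # 5 one-hot entries + 100 box entries
```

Inside a model's forward pass the equivalent step would presumably be to reshape and concatenate the individual tensors into a single input, so that the internal network receives one "observations" tensor instead of a dict.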
from gym.spaces import utils

class ActionMaskModel(TFModelV2):
    ...
        assert (
            isinstance(orig_space, Dict)
            and "action_mask" in orig_space.spaces
            and "actual_obs" in orig_space.spaces
        )
        self.internal_model = FullyConnectedNetwork(
            utils.flatten_space(orig_space["actual_obs"]),
            action_space,
            num_outputs,
            model_config,
            name + "_internal",
        )

class MyEnv(gym.Env):
    def __init__(self):
        super(MyEnv, self).__init__()
        self.observation_space_dict = Dict({
            "action_mask": Box(0, 1, shape=(self.actions,)),
            "actual_obs": Dict({
                "obs1": Discrete(5),
                "obs2": Box(low=-np.inf, high=np.inf, shape=(10, 10), dtype=np.float32),
            }),
        })
- By the way, instead of using flatten, can I simply move "obs1" and "obs2" outside "actual_obs"?
self.observation_space = Dict({
    "action_mask": Box(0, 1, shape=(3,)),
    "actual_obs": Box(low=-np.inf, high=np.inf, shape=(10, 10)),
    "obs1": Discrete(10),
    "obs2": Discrete(10),
    # ...
})
instead of:
self.observation_space = Dict({
    "action_mask": Box(0, 1, shape=(3,)),
    "actual_obs": Dict({
        "obs_box": Box(low=-np.inf, high=np.inf, shape=(10, 10)),
        "obs1": Discrete(10),
        "obs2": Discrete(10),
        # ...
    })
})
Does it make sense? Will my network work properly? I checked it and there are no errors during execution.
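One thing worth checking before trusting the "no errors" result: in the ActionMaskModel snippet above, the internal network is built only from orig_space["actual_obs"]. Assuming the forward pass likewise feeds only "actual_obs" into that network (this is not shown in the post, so it is an assumption), keys moved to the top level would never reach it. The sketch below compares how many inputs the internal network would see under each layout. The tuple-based space description and `flat_size` are a hypothetical mini-format for illustration, not gym's API.

```python
def flat_size(space):
    """Number of values a space contributes after flattening.

    A space is ("discrete", n), ("box", shape) or a dict of sub-spaces.
    """
    if isinstance(space, dict):
        return sum(flat_size(sub) for sub in space.values())
    kind, arg = space
    if kind == "discrete":
        return arg  # one-hot vector of length n
    size = 1
    for dim in arg:  # box: product of all dimensions
        size *= dim
    return size


# Layout with obs1/obs2 moved to the top level.
layout_top_level = {
    "action_mask": ("box", (3,)),
    "actual_obs": ("box", (10, 10)),
    "obs1": ("discrete", 10),
    "obs2": ("discrete", 10),
}

# Layout with everything nested under "actual_obs".
layout_nested = {
    "action_mask": ("box", (3,)),
    "actual_obs": {
        "obs_box": ("box", (10, 10)),
        "obs1": ("discrete", 10),
        "obs2": ("discrete", 10),
    },
}

# Only "actual_obs" is assumed to be passed to the internal network.
top_level_inputs = flat_size(layout_top_level["actual_obs"])
nested_inputs = flat_size(layout_nested["actual_obs"])
print(top_level_inputs, nested_inputs)
```

Under that assumption the top-level layout runs without errors but the network only sees the 10x10 box, while the nested layout also includes the two one-hot encoded Discrete values.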