I want to implement action masking for a simple environment with a dictionary observation space, using the ApexDQN algorithm.
The input_dict passed to my custom model's forward method is sometimes a SampleBatch and sometimes a plain dict (with exactly the same info the SampleBatch would have).
Note that I have some (suboptimal) code that iterates over the rows of a SampleBatch and flattens them one at a time. This code does work for PPO, where input_dict is always a SampleBatch, but it fails for Apex DQN whenever a plain dict comes in.
What would be the best way to implement action masking for the Apex DQN model?
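For reference, the direction I've been experimenting with is to skip rows() entirely and flatten at the batch level, since the leaves under input_dict["obs"]["actual_obs"] seem to arrive as already-batched torch tensors in both cases. This is only a sketch using the same names as the full script below; I haven't verified it against every RLlib code path:

    def forward(self, input_dict, state, seq_lens):
        # Indexing with ["obs"] works the same for SampleBatch and plain dict.
        action_mask = input_dict["obs"]["action_mask"]
        actual_obs = input_dict["obs"]["actual_obs"]
        # Flatten each already-batched (B, 10, 10) leaf to (B, 100) and
        # concatenate to (B, 200); the key order has to match the order
        # flatten_space() used when the internal model was built.
        obs = torch.cat(
            [actual_obs[k].reshape(actual_obs[k].shape[0], -1) for k in actual_obs],
            dim=1,
        )
        logits, _ = self.internal_model({"obs": obs})
        inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
        return logits + inf_mask, state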
import numpy as np
import ray
from ray.rllib.algorithms.apex_dqn.apex_dqn import ApexDQNConfig
from ray.tune.registry import register_env
import gymnasium
from gymnasium.spaces import Box, Dict, Discrete
from gymnasium.spaces.utils import flatten_space, flatten
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_utils import FLOAT_MIN
torch, nn = try_import_torch()
# adapted from rllib/examples/models/action_mask_model.py
class TorchActionMaskModel(TorchModelV2, nn.Module):
    """PyTorch model that applies an action mask to the policy logits."""
def __init__(
self,
obs_space,
action_space,
num_outputs,
model_config,
name,
**kwargs,
):
orig_space = getattr(obs_space, "original_space", obs_space)
assert (isinstance(orig_space, Dict)
and "action_mask" in orig_space.spaces
and "actual_obs" in orig_space.spaces)
self.orig_state_space = orig_space["actual_obs"]
TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
model_config, name, **kwargs)
nn.Module.__init__(self)
        # The internal FC net operates on the flattened "actual_obs" space.
        self.internal_model = TorchFC(
            flatten_space(orig_space["actual_obs"]),
            action_space,
            num_outputs,
            model_config,
            name + "_internal",
        )
    def forward(self, input_dict, state, seq_lens):
        # Extract the available-actions mask from the observation.
        action_mask = input_dict["obs"]["action_mask"]
        # Suboptimal: flatten the nested observation row by row. This only
        # works when input_dict is a SampleBatch; a plain dict has no .rows().
        data = []
        for row in input_dict.rows():
            flattened_sample = flatten(self.orig_state_space,
                                       row["obs"]["actual_obs"])
            data.append(flattened_sample)
        obs = torch.tensor(np.array(data), dtype=torch.float32)
        # Compute the unmasked logits.
        logits, _ = self.internal_model({"obs": obs})
        # log(1) = 0 for valid actions; log(0) -> -inf, clamped to FLOAT_MIN,
        # for invalid ones.
        inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
        return logits + inf_mask, state
def value_function(self):
return self.internal_model.value_function()
class MyEnv(gymnasium.Env):
metadata = {"render.modes": ["human"]}
def __init__(self):
super(MyEnv, self).__init__()
self.actions = 4
self.action_space = Discrete(self.actions)
self.observation_space = Dict({
"action_mask": Box(0, 1, shape=(self.actions, )),
"actual_obs":
Dict({
"obs1": Box(low=-np.inf, high=np.inf, shape=(10, 10), dtype=np.float32),
"obs2": Box(low=-np.inf, high=np.inf, shape=(10, 10), dtype=np.float32),
}),
})
    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)
        return self._make_obs(), {}
    def step(self, action):
        # Placeholder dynamics: constant observation, zero reward, never terminates.
        return self._make_obs(), 0.0, False, False, {}
def _make_obs(self):
return {
"action_mask": np.array([1.0] * self.actions),
"actual_obs": {
"obs1": np.zeros((10, 10), dtype=np.float32),
"obs2": np.zeros((10, 10), dtype=np.float32)
},
}
def main():
ray.init()
select_env = "env-v1"
register_env(select_env, lambda config: MyEnv())
    config = ApexDQNConfig().framework("torch") \
        .training(
            model={
                "custom_model": TorchActionMaskModel,
                "no_final_linear": False,
            },
            train_batch_size=32,
            # hiddens=[] and dueling=False so the masked logits are used
            # directly as the final Q-values (no post-model layers),
            # following RLlib's action-masking example.
            hiddens=[],
            dueling=False,
        ) \
        .environment(select_env)
    algo = config.build()
for _ in range(5):
algo.train()
if __name__ == "__main__":
main()
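For completeness, here is the throwaway check I've been using to reproduce the mismatch outside of RLlib. It is a hypothetical smoke test, not part of the training script: it constructs the model directly from the env's spaces (assuming an empty model_config falls back to TorchFC's defaults) and feeds it a plain dict, the way I understand the DQN code path does. With the rows()-based forward above it dies with an AttributeError (a plain dict has no .rows()); with the batch-level sketch at the top of the post it passes:

def smoke_test_forward():
    env = MyEnv()
    model = TorchActionMaskModel(env.observation_space, env.action_space,
                                 num_outputs=env.actions, model_config={},
                                 name="smoke")
    batch = 7
    input_dict = {  # plain dict, i.e. what I see on the Apex DQN path
        "obs": {
            "action_mask": torch.ones(batch, env.actions),
            "actual_obs": {
                "obs1": torch.zeros(batch, 10, 10),
                "obs2": torch.zeros(batch, 10, 10),
            },
        }
    }
    logits, _ = model.forward(input_dict, [], None)
    assert logits.shape == (batch, env.actions)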