MeanStdFilter observation filter also normalizes the action mask

I’m attempting to use the MeanStdFilter observation filter with an environment that uses action masking, and I believe the filter is also normalizing the action mask. I’m using Ray 0.8.5 with TensorFlow 1.15.4. Here is a script that reproduces the issue:

import argparse
import random
import numpy as np
import gym
from gym.spaces import Box, Discrete, Dict, Tuple

import ray
from ray import tune
from ray.rllib.models import ModelCatalog
from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.tune.registry import register_env
from ray.rllib.utils import try_import_tf
import ray.rllib.agents.ppo as ppo

tf = try_import_tf()

class ActionMaskingCartpole(gym.Env):
    """CartPole with a Dict observation that carries a per-sub-action mask.

    The action space is a tuple of two discrete sub-actions; only the first
    one is forwarded to the wrapped CartPole environment. The mask is always
    all ones (every action is valid), so it should not affect learning.
    """

    def __init__(self):
        self.action_space = Tuple([Discrete(2), Discrete(5)])
        self.wrapped = gym.make("CartPole-v0")

        self.observation_space = Dict({
            # dtype is stated explicitly so that the masks produced by
            # _action_mask() (float32) are contained in this space; newer gym
            # versions reject float64 arrays for a float32 Box.
            "action_mask": Tuple([
                Box(0, 1, shape=(x.n,), dtype=np.float32)
                for x in self.action_space.spaces
            ]),
            "state": self.wrapped.observation_space,
        })

    def update_avail_actions(self):
        """Hook for recomputing action availability; currently a no-op."""
        pass

    def _action_mask(self):
        """Return one all-ones float32 mask per sub-action space."""
        return [np.ones(x.n, dtype=np.float32) for x in self.action_space.spaces]

    def reset(self):
        """Reset the wrapped env and return the initial Dict observation."""
        self.update_avail_actions()

        return {
            "action_mask": self._action_mask(),
            "state": self.wrapped.reset(),
        }

    def step(self, action):
        """Step the wrapped env with the first sub-action of ``action``.

        Returns the usual ``(obs, reward, done, info)`` tuple, where ``obs``
        is the Dict observation with a fresh mask and the CartPole state.
        """
        # Only the first sub-action is meaningful for CartPole; the second
        # exists purely to exercise a Tuple action space.
        actual_action = action[0]

        orig_obs, rew, done, info = self.wrapped.step(actual_action)
        self.update_avail_actions()

        obs = {
            "action_mask": self._action_mask(),
            "state": orig_obs,
        }
        return obs, rew, done, info

class ActionMaskingModel(TFModelV2):
    """TF model that masks its output logits with an observation-supplied mask.

    The observation is read as a dict: its ``"action_mask"`` entry is
    concatenated along axis 1 and used to mask the logits, and its
    ``"state"`` entry is fed to an inner fully connected network.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config, name, true_obs_shape=(4,), **kw):
        super(ActionMaskingModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name, **kw)

        # The inner network only ever sees the "state" part of the observation,
        # described as an unbounded float32 Box of shape `true_obs_shape`.
        float32_info = np.finfo(np.float32)
        state_space = Box(float32_info.min, float32_info.max, shape=true_obs_shape)
        self.action_embed_model = FullyConnectedNetwork(
            state_space, action_space, num_outputs, model_config, name)

        self.register_variables(self.action_embed_model.variables())

    def forward(self, input_dict, state, seq_lens):
        observation = input_dict["obs"]

        # Join the per-sub-action masks into a single vector aligned with
        # the logits.
        mask = tf.cast(tf.concat(observation["action_mask"], axis=1), tf.float32)

        logits, _ = self.action_embed_model({"obs": observation["state"]})

        # log(1) == 0 leaves allowed logits unchanged; log(0) == -inf is
        # clamped to the most negative finite float32 so disallowed actions
        # get (numerically) zero probability without producing NaNs.
        # NOTE: this relies on the mask arriving as raw 0/1 values; any
        # observation filter that rescales it (e.g. MeanStdFilter) distorts
        # the masking.
        logit_penalty = tf.maximum(tf.math.log(mask), tf.float32.min)
        return logits + logit_penalty, state

    def value_function(self):
        """Delegate the value estimate to the inner fully connected network."""
        return self.action_embed_model.value_function()
    
if __name__ == "__main__":
    ray.init()

    # Make the custom model and environment resolvable by name in the config.
    ModelCatalog.register_custom_model("ActionMaskingModel", ActionMaskingModel)
    register_env("ActionMaskingCartpole", lambda _: ActionMaskingCartpole())

    # "MeanStdFilter" is the setting under investigation: it appears to
    # normalize the whole observation, including "action_mask". Swap in
    # "NoFilter" to compare training curves.
    trainer_config = {
        "env": "ActionMaskingCartpole",
        "model": {"custom_model": "ActionMaskingModel"},
        "num_workers": 0,
        "observation_filter": "MeanStdFilter",
    }
    stop_criteria = {"training_iteration": 5}

    tune.run("PPO", stop=stop_criteria, config=trainer_config)

The mask is constantly all ones, so it should have no impact on training. When I remove the MeanStdFilter from the config, the model reaches a reward of ~100 after 5 iterations; when I include the filter, the reward does not improve and stays around ~20. Is there another recommended way to implement action masking (e.g., an environment wrapper that adds another layer of abstraction)?

Thanks for posting this question, @pgigioli !
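Yes, the MeanStdFilter probably normalizes the action mask as well :slight_smile: Because your mask is constant, normalizing it maps every entry to roughly 0. The model's `log(mask)` therefore clamps every logit to `float32.min`, leaving an effectively uniform (random) policy. That is consistent with the ~20 reward you see, which is about what a random CartPole policy achieves.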
I don’t see any good workarounds for this problem right now:

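  • RLlib’s observation filters are hard-coded :confused: Only the built-in names such as “NoFilter” and “MeanStdFilter” are accepted, and a filter cannot be restricted to part of the observation.
  • As a workaround, set the filter to “NoFilter” and write your own gym observation wrapper that applies mean/std normalization only to the “state” entry, leaving “action_mask” untouched.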