Observation dependent continuous action space ("Masking" continuous action space)

Hi everyone,

I have a custom Gym environment with an m-dimensional continuous observation space and an n-dimensional continuous action space.

e.g.

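```python
import gym
import numpy as np
from gym import spaces

class CustomEnv(gym.Env):

    def __init__(self):
        super(CustomEnv, self).__init__()

        # 2-dimensional continuous action space (effectively unbounded)
        self.action_space = spaces.Box(low=np.finfo(np.float32).min,
                                       high=np.finfo(np.float32).max,
                                       shape=(2,), dtype=np.float32)

        # 3-dimensional continuous observation space with per-dimension bounds
        self.observation_space = spaces.Box(low=np.array([-5, 10, 0], dtype=np.float32),
                                            high=np.array([2, 40, 22], dtype=np.float32),
                                            dtype=np.float32)
```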

The action space should depend on the current observation in each period, i.e. the RL algorithm should only sample an action between some bounds that vary with the current observation.

For example:

l1(obs_t) < a1_t < u1(obs_t)
l2(obs_t) < a2_t < u2(obs_t)

where l1 and l2 determine the lower bound and u1 and u2 determine the upper bound.
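To make the constraint concrete, here is a minimal sketch with made-up bound functions for the 3-dimensional observation and 2-dimensional action above. `lower_bounds`, `upper_bounds` and `is_valid_action` are hypothetical placeholders for l1, l2, u1, u2 and the check itself; they are not part of the original environment.

```python
import numpy as np

def lower_bounds(obs):
    # Hypothetical l1(obs_t), l2(obs_t): any function of the observation works here.
    return np.array([obs[0], 0.0], dtype=np.float32)

def upper_bounds(obs):
    # Hypothetical u1(obs_t), u2(obs_t); they must stay above the lower bounds.
    return np.array([obs[0] + 1.0, obs[2] + 1.0], dtype=np.float32)

def is_valid_action(obs, action):
    # Checks l_i(obs_t) < a_i_t < u_i(obs_t) for every action dimension i.
    lo, hi = lower_bounds(obs), upper_bounds(obs)
    return bool(np.all((lo < action) & (action < hi)))

obs = np.array([-1.0, 20.0, 5.0], dtype=np.float32)
print(lower_bounds(obs), upper_bounds(obs))          # bounds for this observation
print(is_valid_action(obs, np.array([-0.5, 3.0])))   # True
print(is_valid_action(obs, np.array([1.0, 3.0])))    # False
```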

My idea would be to write a custom action distribution, but how do I pass the current observation to my custom action distribution? Or is there a better way to implement an observation-dependent action space?

Any help is really appreciated.

I use the PPO algorithm.

Hi
For discrete actions, as mentioned in this example, you should pass the available actions and an action mask from your env to your model.
Your model then pushes the logits of the masked actions to a very large negative value (setting them to zero would not remove them), so they get essentially zero probability.
Remember: in policy-based algorithms we sample actions from a distribution (as far as I know), so once you turn off an action's logit, that action can no longer be sampled from the distribution.
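A small NumPy sketch of the logit masking described above, assuming the env provides a 1/0 `action_mask` (1 = available). It adds log(mask), clipped to the smallest float32, to the logits, so masked actions get a huge negative logit and a probability of zero after the softmax. This is similar in spirit to the RLlib parametric-actions example, but it is a standalone illustration, not that example's code; `masked_logits` and `softmax` are hypothetical helpers.

```python
import numpy as np

FLOAT_MIN = np.finfo(np.float32).min

def masked_logits(logits, action_mask):
    # log(1) = 0 leaves valid logits unchanged; log(0) = -inf is clipped to FLOAT_MIN.
    with np.errstate(divide="ignore"):
        inf_mask = np.maximum(np.log(action_mask), FLOAT_MIN)
    return logits + inf_mask

def softmax(x):
    z = x - np.max(x)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

logits = np.array([2.0, 1.0, 0.5, -1.0], dtype=np.float32)
action_mask = np.array([1, 0, 1, 0], dtype=np.float32)  # actions 1 and 3 unavailable
probs = softmax(masked_logits(logits, action_mask))
print(probs)  # masked entries are 0.0
```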
Questions I ran into:

  1. Is there any other way to do that?
    Maybe you can manipulate the sample function instead of the logits (so that it doesn't sample invalid actions).

  2. If the method from question 1 is valid, why did they manipulate the logits?
    My guess is that they started with the DQN algorithm. Critic-based algorithms don't sample actions from a distribution, so masking the outputs was the only way to exclude invalid actions there. In discrete action spaces, turning off an output neuron also removes that action from the distribution we sample from, so the same method works for policy-based algorithms too.
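For comparison, a sketch of the alternative from question 1: leave the logits alone and restrict sampling to the available actions. For a categorical distribution, renormalizing the probabilities over the valid actions gives exactly the same distribution as the logit mask, so both approaches sample identically. `sample_valid_action` is a hypothetical helper, not an RLlib function.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x))  # subtract the max for numerical stability
    return e / e.sum()

def sample_valid_action(logits, action_mask, rng):
    # Keep only the probabilities of available actions and renormalize them.
    probs = softmax(logits) * action_mask
    probs = probs / probs.sum()
    return rng.choice(len(logits), p=probs), probs

rng = np.random.default_rng(0)
logits = np.array([2.0, 1.0, 0.5, -1.0])
action_mask = np.array([1.0, 0.0, 1.0, 0.0])
action, probs = sample_valid_action(logits, action_mask, rng)

# Same distribution as a softmax over the valid logits only (i.e. logit masking).
reference = np.zeros_like(probs)
reference[action_mask == 1] = softmax(logits[action_mask == 1])
print(action, np.allclose(probs, reference))  # action is 0 or 2; distributions match
```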

**So what about continuous action spaces?**
Well, I didn't see any example for this problem, but from what I've read and understood of RLlib, you should manipulate the sample function. This is closer to the autoregressive-action-distributions example than to the parametric action space example. Maybe you could pass the available actions (from the env) to a custom 'ActionDistribution' class and specify the available action range there, e.g. actions would be sampled from the range 0 to X, where X is passed to this class.
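A framework-agnostic sketch of that idea for continuous actions, assuming the policy network outputs a Gaussian mean and log-std per action dimension and that the per-sample bounds l(obs), u(obs) are available at sampling time. It draws an unbounded Gaussian sample, squashes it with tanh into (-1, 1), and rescales it into (low, high). The log-probability includes the change-of-variables correction, since a sampler usually has to return it alongside the action. `ObsBoundedSquashedGaussian` is a hypothetical class, not an RLlib ActionDistribution, and the bounds in the usage lines are made up.

```python
import numpy as np

class ObsBoundedSquashedGaussian:
    """Hypothetical distribution whose support is (low, high), per sample."""

    def __init__(self, mean, log_std, low, high):
        # All arrays have shape (batch, action_dim); low/high come from l(obs), u(obs).
        self.mean, self.std = mean, np.exp(log_std)
        self.low, self.high = low, high

    def sample(self, rng):
        z = self.mean + self.std * rng.standard_normal(self.mean.shape)
        return self._squash(z), self.logp_from_z(z)

    def _squash(self, z):
        # tanh maps to (-1, 1); the affine map sends (-1, 1) to (low, high).
        return self.low + (np.tanh(z) + 1.0) * 0.5 * (self.high - self.low)

    def logp_from_z(self, z):
        # Gaussian log-density of the pre-squash sample z.
        gauss = (-0.5 * ((z - self.mean) / self.std) ** 2
                 - np.log(self.std) - 0.5 * np.log(2.0 * np.pi))
        # Change of variables: subtract log|da/dz| = log((high - low) / 2 * (1 - tanh(z)^2));
        # the 1e-6 avoids log(0) when tanh saturates.
        log_det = np.log(0.5 * (self.high - self.low)) + np.log(1.0 - np.tanh(z) ** 2 + 1e-6)
        # Sum over action dimensions to get one log-prob per sample.
        return np.sum(gauss - log_det, axis=-1)

rng = np.random.default_rng(0)
obs = np.array([[-1.0, 20.0, 5.0]])
low = np.array([[obs[0, 0], 0.0]])                     # hypothetical l1(obs), l2(obs)
high = np.array([[obs[0, 0] + 1.0, obs[0, 2] + 1.0]])  # hypothetical u1(obs), u2(obs)
dist = ObsBoundedSquashedGaussian(np.zeros((1, 2)), np.zeros((1, 2)), low, high)
actions, logp = dist.sample(rng)
print(actions, logp)  # each action lies within its [low, high] interval
```

Because the bounds enter only through the final affine map, they can change from one observation to the next without changing the network's output layer.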
I'm not very sure about this method, but you can try it.
Hope this helps.



Hey @thgehr, thanks for posting this. This seems like a pretty unusual case of a specific exploration behavior, very interesting!

Yeah, @hossein836 is right: you could create a new Policy (based on PPOPolicy) and override the action_sampler_fn. Let me try to come up with a quick example and post it here …


You can take a look at the Dreamer algo, which is currently the only built-in one that uses the customized action_sampler_fn: ray.rllib.agents.dreamer.dreamer_torch_policy.py::action_sampler_fn()

In case you are on tf, you need to change the function’s signature a little to:

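```python
def action_sampler_fn(policy, model, *, obs_batch, state_batches, seq_lens,
                      prev_action_batch, prev_reward_batch, explore, is_training):
    ...  # sample actions that respect the observation-dependent bounds here
    # Return the sampled actions and their log-probabilities
    # (the log-probs may be dummy 0.0s or 1.0s).
    return actions, logp
```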

Then plug this into a new Policy and Trainer like so:

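```python
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy
# On tf: from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy

# Plug your custom sampler function from above into the PPO policy
# (use PPOTFPolicy instead of PPOTorchPolicy on tf):
MyPolicy = PPOTorchPolicy.with_updates(
    action_sampler_fn=action_sampler_fn,
)

# Use this new policy class in your trainer:
class MyTrainer(PPOTrainer):
    def get_default_policy_class(self, config):
        return MyPolicy

# ppo_config: your PPO config; pass further arguments (e.g. env) as needed
trainer = MyTrainer(config=ppo_config)
trainer.train()
```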