Autoregressive Continuous Actions

Hello,

I would like to use the example in the link below for an environment that requires autoregressive continuous actions. Specifically, the agent samples two continuous actions, each from a Gaussian distribution, with the second conditioned on the first. Is there existing support for this anywhere? If not, could someone give me some tips on implementing it?

https://docs.ray.io/en/releases-2.2.0/rllib/rllib-models.html#autoregressive-action-distributions
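
For reference, the behavior I want is a policy that factorizes as pi(a1, a2 | obs) = pi(a1 | obs) * pi(a2 | a1, obs), with both factors Gaussian. As a sanity check on the idea (outside RLlib), here is a minimal plain-PyTorch sketch of that sampling chain; the class name, layer sizes, and obs_dim are made up for illustration and are not RLlib API:

import torch
import torch.nn as nn

class TwoStepGaussianPolicy(nn.Module):
    """Illustrative only: a1 ~ N(mu1(obs), std1(obs)), then a2 ~ N(mu2(obs, a1), std2(obs, a1))."""

    def __init__(self, obs_dim=2, hidden=16):
        super().__init__()
        # Each head outputs [mean, log_std] for a 1-dim Gaussian.
        self.a1_head = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.Tanh(), nn.Linear(hidden, 2)
        )
        self.a2_head = nn.Sequential(
            nn.Linear(obs_dim + 1, hidden), nn.Tanh(), nn.Linear(hidden, 2)
        )

    def forward(self, obs):
        mu1, log_std1 = self.a1_head(obs).chunk(2, dim=-1)
        a1_dist = torch.distributions.Normal(mu1, log_std1.exp())
        a1 = a1_dist.rsample()
        # The second action is conditioned on both the observation and a1.
        mu2, log_std2 = self.a2_head(torch.cat([obs, a1], dim=-1)).chunk(2, dim=-1)
        a2_dist = torch.distributions.Normal(mu2, log_std2.exp())
        a2 = a2_dist.rsample()
        logp = (a1_dist.log_prob(a1) + a2_dist.log_prob(a2)).sum(-1)
        return (a1, a2), logp

policy = TwoStepGaussianPolicy()
(a1, a2), logp = policy(torch.randn(4, 2))
print(a1.shape, a2.shape, logp.shape)  # -> [4, 1], [4, 1], [4]

In RLlib, this factorization has to be split between a custom model (which produces a context embedding of the observation) and a custom action distribution (which samples a1 from the context and then a2 conditioned on a1), which is what the script below attempts.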

Here is my attempt to adapt the example to a continuous environment. The script runs, but the agent never learns.

import argparse
import os
import random

import gymnasium as gym
import ray
from gymnasium.spaces import Discrete, Tuple, Box
from ray import tune
from ray.rllib.agents import ppo
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.misc import normc_initializer as normc_init_torch
from ray.rllib.models.torch.torch_action_dist import (
    TorchDiagGaussian,
    TorchDistributionWrapper,
)
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.test_utils import check_learning_achieved
from ray.tune.logger import pretty_print

tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()


def get_cli_args():
    """Create CLI parser and return parsed arguments"""
    parser = argparse.ArgumentParser()

    # example-specific arg: disable autoregressive action dist
    parser.add_argument(
        "--no-autoreg",
        action="store_true",
        help="Do NOT use an autoregressive action distribution but normal,"
        "independently distributed actions.",
    )

    # general args
    parser.add_argument(
        "--run", type=str, default="PPO", help="The RLlib-registered algorithm to use."
    )
    parser.add_argument(
        "--framework",
        choices=["tf", "tf2", "tfe", "torch"],
        default="torch",
        help="The DL framework specifier.",
    )
    parser.add_argument("--num-cpus", type=int, default=0)
    parser.add_argument(
        "--as-test",
        action="store_true",
        help="Whether this script should be run as a test: --stop-reward must "
        "be achieved within --stop-timesteps AND --stop-iters.",
    )
    parser.add_argument(
        "--stop-iters", type=int, default=200, help="Number of iterations to train."
    )
    parser.add_argument(
        "--stop-timesteps",
        type=int,
        default=100000,
        help="Number of timesteps to train.",
    )
    parser.add_argument(
        "--stop-reward",
        type=float,
        default=200.0,
        help="Reward at which we stop training.",
    )
    parser.add_argument(
        "--no-tune",
        action="store_true",
        help="Run without Tune using a manual train loop instead. Here,"
        "there is no TensorBoard support.",
    )
    parser.add_argument(
        "--local-mode",
        action="store_true",
        help="Init Ray in local mode for easier debugging.",
    )

    args = parser.parse_args()
    print(f"Running with following CLI args: {args}")
    return args

class CorrelatedActionsEnv(gym.Env):
    """
    Simple env in which the policy has to emit a tuple of equal actions.

    In each step, the agent observes a random number (0 or 1) and has to choose
    two actions a1 and a2.
    It gets +5 reward for matching a1 to the random obs and +5 for matching a2
    to a1. I.e., +10 at most per step.

    One way to effectively learn this is through correlated action
    distributions, e.g., in examples/autoregressive_action_dist.py

    There are 20 steps. Hence, the best score would be ~200 reward.
    """

    def __init__(self, _):
        self.observation_space = Discrete(2)
        self.action_space = Tuple([Box(0, 1), Box(0, 1)])
        self.last_observation = None

    def reset(self, *, seed=None, options=None):
        self.t = 0
        self.last_observation = random.choice([0, 1])
        return self.last_observation, {}

    def step(self, action):
        self.t += 1
        a1, a2 = action
        reward = 0
        # NOTE: Reward function modified
        # Encourage correlation between most recent observation and a1.
        reward -= abs(self.last_observation - a1).item()
        # if a1 == self.last_observation:
        #     reward += 5
        # Encourage correlation between a1 and a2.
        reward -= abs(a1 - a2).item()
        # if a1 == a2:
        #     reward += 5
        done = truncated = self.t > 20
        self.last_observation = random.choice([0, 1])
        return self.last_observation, reward, done, truncated, {}

class TorchAutoregressiveActionModel(TorchModelV2, nn.Module):
    """
    PyTorch version of the AutoregressiveActionModel above.
    """ 

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name
        )
        nn.Module.__init__(self)

        # Output of the model (normally 'logits', but for an autoregressive
        # dist this is more like a context/feature layer encoding the obs)
        self.context_layer = SlimFC(
            in_size=obs_space.shape[0],
            out_size=num_outputs,
            initializer=normc_init_torch(1.0),
            activation_fn=nn.Tanh,
        )

        # V(s)
        self.value_branch = SlimFC(
            in_size=num_outputs,
            out_size=1,
            initializer=normc_init_torch(0.01),
            activation_fn=None,
        )

        # P(a1 | obs): outputs [mean, log_std] of a 1-dim Gaussian over a1.
        self.a1_logits = SlimFC(
            in_size=num_outputs,
            out_size=2,
            activation_fn=None,
            initializer=normc_init_torch(0.01),
        )

        class _ActionModel(nn.Module):
            def __init__(self):
                nn.Module.__init__(self)
                self.a2_hidden = SlimFC(
                    in_size=1,
                    out_size=16,
                    activation_fn=nn.Tanh,
                    initializer=normc_init_torch(1.0),
                )
                self.a2_logits = SlimFC(
                    in_size=16,
                    out_size=2,
                    activation_fn=None,
                    initializer=normc_init_torch(0.01),
                )

            def forward(self_, ctx_input, a1_input):
                a1_logits = self.a1_logits(ctx_input)
                a2_logits = self_.a2_logits(self_.a2_hidden(a1_input))
                return a1_logits, a2_logits

        # P(a2 | a1)
        # --note: typically you'd want to implement P(a2 | a1, obs) by also
        # feeding the context into the a2 branch, e.g.:
        # a2_context = torch.cat([ctx_input, a1_input], dim=1)
        self.action_module = _ActionModel()

        self._context = None

    def forward(self, input_dict, state, seq_lens):
        self._context = self.context_layer(input_dict["obs"])
        return self._context, state

    def value_function(self):
        return torch.reshape(self.value_branch(self._context), [-1])

class TorchBinaryAutoregressiveDistribution(TorchDistributionWrapper):
    """Action distribution P(a1, a2) = P(a1) * P(a2 | a1)"""

    def deterministic_sample(self):
        # First, sample a1.
        a1_dist = self._a1_distribution()
        a1 = a1_dist.deterministic_sample()

        # Sample a2 conditioned on a1.
        a2_dist = self._a2_distribution(a1)
        a2 = a2_dist.deterministic_sample()
        self._action_logp = a1_dist.logp(a1) + a2_dist.logp(a2)

        # Return the action tuple.
        return (a1, a2)

    def sample(self):
        # First, sample a1.
        a1_dist = self._a1_distribution()
        a1 = a1_dist.sample()

        # Sample a2 conditioned on a1.
        a2_dist = self._a2_distribution(a1)
        a2 = a2_dist.sample()
        self._action_logp = a1_dist.logp(a1) + a2_dist.logp(a2)

        # Return the action tuple.
        return (a1, a2)

    def logp(self, actions):
        # Actions arrive flattened as a [B, 2] tensor. Keep each component as a
        # [B, 1] column so it matches the [B, 1] mean/log_std of its Gaussian
        # and avoids unintended broadcasting inside TorchDiagGaussian.logp().
        a1, a2 = actions[:, 0:1], actions[:, 1:2]
        a1_logits, a2_logits = self.model.action_module(self.inputs, a1)
        return TorchDiagGaussian(a1_logits, None).logp(a1) + TorchDiagGaussian(
            a2_logits, None
        ).logp(a2)

    def sampled_action_logp(self):
        return torch.exp(self._action_logp)

    def entropy(self):
        a1_dist = self._a1_distribution()
        a2_dist = self._a2_distribution(a1_dist.sample())
        return a1_dist.entropy() + a2_dist.entropy()

    def kl(self, other):
        a1_dist = self._a1_distribution()
        a1_terms = a1_dist.kl(other._a1_distribution())

        a1 = a1_dist.sample()
        a2_terms = self._a2_distribution(a1).kl(other._a2_distribution(a1))
        return a1_terms + a2_terms

    def _a1_distribution(self):
        BATCH = self.inputs.shape[0]
        zeros = torch.zeros((BATCH, 1)).to(self.inputs.device)
        a1_logits, _ = self.model.action_module(self.inputs, zeros)
        a1_dist = TorchDiagGaussian(a1_logits, None)
        return a1_dist

    def _a2_distribution(self, a1):
        # a1_vec = torch.unsqueeze(a1.float(), 1)
        _, a2_logits = self.model.action_module(self.inputs, a1)
        a2_dist = TorchDiagGaussian(a2_logits, None)
        return a2_dist

    @staticmethod
    def required_model_output_shape(action_space, model_config):
        return 16  # controls model output feature vector size

if __name__ == "__main__":
    args = get_cli_args()
    ray.init(num_cpus=args.num_cpus or None, local_mode=args.local_mode)

    # main part: register and configure autoregressive action model and dist
    # here, tailored to the CorrelatedActionsEnv such that a2 depends on a1
    ModelCatalog.register_custom_model(
        "autoregressive_model",
        TorchAutoregressiveActionModel
    )
    ModelCatalog.register_custom_action_dist(
        "binary_autoreg_dist",
        TorchBinaryAutoregressiveDistribution
    )

    # standard config
    config = {
        "gamma": 0.5,
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "framework": args.framework,
        # NOTE: Use 0 rollout workers for easier debugging.
        # "num_rollout_workers": 0,
        "num_rollout_workers": 10,
    }
    # use registered model and dist in config
    if not args.no_autoreg:
        config["model"] = {
            "custom_model": "autoregressive_model",
            "custom_action_dist": "binary_autoreg_dist",
        }

    # use stop conditions passed via CLI (or defaults)
    stop = {
        "training_iteration": args.stop_iters,
        "timesteps_total": args.stop_timesteps,
        "episode_reward_mean": args.stop_reward,
    }

    # manual training loop using PPO without tune.run()

    ppo_config = ppo.DEFAULT_CONFIG.copy()
    ppo_config.update(config)
    trainer = ppo.PPOTrainer(config=ppo_config, env=CorrelatedActionsEnv)
    # run manual training loop and print results after each iteration
    for _ in range(args.stop_iters):
        result = trainer.train()
        print(pretty_print(result))
        # stop training if the target train steps or reward are reached
        if (
            result["timesteps_total"] >= args.stop_timesteps
            or result["episode_reward_mean"] >= args.stop_reward
        ):
            break
        
    # run manual test loop: 1 iteration until done
    print("Finished training. Running manual test/inference loop.")
    env = CorrelatedActionsEnv(None)
    obs, _ = env.reset()
    done = False
    total_reward = 0
    while not done:
        a1, a2 = trainer.compute_single_action(obs)
        next_obs, reward, done, truncated, _ = env.step((a1, a2))
        print(f"Obs: {obs}, Action: a1={a1} a2={a2}, Reward: {reward}")
        obs = next_obs
        total_reward += reward
    print(f"Total reward in test episode: {total_reward}")

    ray.shutdown()