Best Practices for Implementing a Shared Critic?

Hello, all. I’ve been working on a multi-agent reinforcement learning project where a shared critic would be very useful. In brief, my aims are:

  • Store a shared critic module in a MultiRLModule, which calculates the value of a global observation (a rough sketch of the intended layout follows this list).
  • Modify our learner or algorithm such that only one forward and backward pass through the value network is performed, even when updating potentially dozens of policy networks.
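
To make the first aim concrete, here’s a rough sketch of the kind of layout I have in mind. Everything below is hypothetical on my end: GlobalValueCritic, SHARED_CRITIC_ID, and the 8-dim global observation space are placeholders, not existing RLlib classes or IDs.

import gymnasium as gym
import torch
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule

SHARED_CRITIC_ID = "shared_critic"  # placeholder ModuleID


class GlobalValueCritic(TorchRLModule):
    """Placeholder module mapping a global observation to a single value estimate."""

    def setup(self):
        in_dim = self.observation_space.shape[0]
        hidden = self.model_config.get("hidden_dim", 256)
        self._vf = torch.nn.Sequential(
            torch.nn.Linear(in_dim, hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, 1),
        )

    def _forward_train(self, batch, **kwargs):
        # Expects the global observation under the usual "obs" key.
        return {"vf_preds": self._vf(batch["obs"]).squeeze(-1)}


global_obs_space = gym.spaces.Box(-1.0, 1.0, (8,))  # placeholder global-obs space

spec = MultiRLModuleSpec(
    rl_module_specs={
        # ... the usual per-agent policy RLModuleSpecs ("p0", "p1", ...) ...
        SHARED_CRITIC_ID: RLModuleSpec(
            module_class=GlobalValueCritic,
            observation_space=global_obs_space,
            model_config={"hidden_dim": 256},
            learner_only=True,  # only needed on the Learner, never for sampling
        ),
    },
)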

Looking through the codebase, I found this example of a MultiRLModule with a shared encoder, which I thought would be a good starting point. Unfortunately, the included example code doesn’t seem to work. After making a few straightforward fixes (for instance, adding a line that specifies the MultiRLModule subclass in the config), I got it to what I think is the intended state (see below).

import gymnasium as gym
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole
from ray.rllib.examples.rl_modules.classes.vpg_using_shared_encoder_rlm import (
    SHARED_ENCODER_ID,
    SharedEncoder,
    VPGPolicyAfterSharedEncoder,
    VPGMultiRLModuleWithSharedEncoder,
)

single_agent_env = gym.make("CartPole-v1")

EMBEDDING_DIM = 64  # encoder output dim

config = (
    PPOConfig()
    .environment(MultiAgentCartPole, env_config={"num_agents": 2})
    .env_runners(
        num_env_runners=0,
        num_envs_per_env_runner=1,
    )
    .multi_agent(
        # Declare the two policies trained.
        policies={"p0", "p1"},
        # Agent IDs of `MultiAgentCartPole` are 0 and 1. They are mapped to
        # the two policies with ModuleIDs "p0" and "p1", respectively.
        policy_mapping_fn=lambda agent_id, episode, **kw: f"p{agent_id}"
    )
    .rl_module(
        rl_module_spec=MultiRLModuleSpec(
            multi_rl_module_class=VPGMultiRLModuleWithSharedEncoder,
            rl_module_specs={
                # Shared encoder.
                SHARED_ENCODER_ID: RLModuleSpec(
                    module_class=SharedEncoder,
                    model_config={"embedding_dim": EMBEDDING_DIM},
                    observation_space=single_agent_env.observation_space,
                ),
                # Large policy net.
                "p0": RLModuleSpec(
                    module_class=VPGPolicyAfterSharedEncoder,
                    model_config={
                        "embedding_dim": EMBEDDING_DIM,
                        "hidden_dim": 1024,
                    },
                ),
                # Small policy net.
                "p1": RLModuleSpec(
                    module_class=VPGPolicyAfterSharedEncoder,
                    model_config={
                        "embedding_dim": EMBEDDING_DIM,
                        "hidden_dim": 64,
                    },
                ),
            },
        ),
    )
)
algo = config.build_algo()
print(algo.get_module())

Even so, there appears to be a fundamental problem: get_multi_rl_module_spec in algorithm_config.py seems to expect every module specified in rl_module_specs to be associated with a policy, leading to this error:

KeyError: 'default_policy'

The strategy I’d been hoping to follow was to override compute_loss_for_module in a custom PPOTorchLearner so that each agent’s policy network is optimized on its own loss while the shared critic’s value network is optimized separately, distinguishing agents from the shared critic by module ID. I’d then override build to add a custom GAE step that uses the shared critic to compute advantages.
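
To illustrate the routing I mean, here’s a minimal hypothetical sketch. The "shared_critic" ModuleID and the "vf_preds" / "value_targets" keys are my own assumptions about what the custom GAE/connector step would put into fwd_out and the batch; none of this is an existing RLlib shared-critic API.

import torch
from ray.rllib.algorithms.ppo.torch.ppo_torch_learner import PPOTorchLearner

SHARED_CRITIC_ID = "shared_critic"  # hypothetical ModuleID for the critic


class SharedCriticPPOTorchLearner(PPOTorchLearner):
    def compute_loss_for_module(self, *, module_id, config, batch, fwd_out):
        if module_id == SHARED_CRITIC_ID:
            # Plain value regression for the single shared critic; assumes the
            # batch/forward outputs carry these (hypothetical) keys.
            return torch.mean(
                torch.square(fwd_out["vf_preds"] - batch["value_targets"])
            )
        # Every other module is a per-agent policy: fall back to PPO's loss
        # (a real version would drop its per-module value-function term).
        return super().compute_loss_for_module(
            module_id=module_id, config=config, batch=batch, fwd_out=fwd_out
        )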

I suspect, based on the error above, that this isn’t what was intended, but I’m having trouble figuring out the most elegant way to do what I’m trying to do. There’s a placeholder in /examples for a shared critic example for the new API stack, but it’s not yet available on the main repo. I’d be glad to lend a hand on that front, if I could be of any use.

Hey @MCW_Lad , thanks for raising this. This is indeed not 100% ok in our code and should be fixed (or a better error should be provided).

However, you can solve your problem and make the script work by simply specifying the action_space as well for the shared_encoder (ignore the fact that a shared encoder doesn’t need an action space for now):

So in your code, just change to:

...
rl_module_specs={
    SHARED_ENCODER_ID: RLModuleSpec(
        ...
        observation_space=single_agent_env.observation_space,
        action_space=single_agent_env.action_space,  # <---- ADD THIS HERE!!
    ),

Sorry for the late reply - I get the same error as before when I run the modified code. I missed some imports in the example code I pasted in last time, so for the sake of posterity, here’s the code I’m presently running:

# !pip install -q "ray[rllib]" pettingzoo

import gymnasium as gym
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole
from ray.rllib.examples.rl_modules.classes.vpg_using_shared_encoder_rlm import (
    SHARED_ENCODER_ID,
    SharedEncoder,
    VPGPolicyAfterSharedEncoder,
    VPGMultiRLModuleWithSharedEncoder,
)

single_agent_env = gym.make("CartPole-v1")

EMBEDDING_DIM = 64  # encoder output dim

config = (
    PPOConfig()
    .environment(MultiAgentCartPole, env_config={"num_agents": 2})
    .env_runners(
        num_env_runners=0,
        num_envs_per_env_runner=1,
    )
    .multi_agent(
        # Declare the two policies trained.
        policies={"p0", "p1"},
        # Agent IDs of `MultiAgentCartPole` are 0 and 1. They are mapped to
        # the two policies with ModuleIDs "p0" and "p1", respectively.
        policy_mapping_fn=lambda agent_id, episode, **kw: f"p{agent_id}"
    )
    .rl_module(
        rl_module_spec=MultiRLModuleSpec(
            multi_rl_module_class=VPGMultiRLModuleWithSharedEncoder,
            rl_module_specs={
                # Shared encoder.
                SHARED_ENCODER_ID: RLModuleSpec(
                    module_class=SharedEncoder,
                    model_config={"embedding_dim": EMBEDDING_DIM},
                    observation_space=single_agent_env.observation_space,
                    action_space=single_agent_env.action_space,  # <-- Added
                ),
                # Large policy net.
                "p0": RLModuleSpec(
                    module_class=VPGPolicyAfterSharedEncoder,
                    model_config={
                        "embedding_dim": EMBEDDING_DIM,
                        "hidden_dim": 1024,
                    },
                ),
                # Small policy net.
                "p1": RLModuleSpec(
                    module_class=VPGPolicyAfterSharedEncoder,
                    model_config={
                        "embedding_dim": EMBEDDING_DIM,
                        "hidden_dim": 64,
                    },
                ),
            },
        ),
    )
)
algo = config.build_algo()
print(algo.get_module())

And here’s the output:

2025-06-22 07:25:46,819	WARNING algorithm_config.py:5014 -- You are running PPO on the new API stack! This is the new default behavior for this algorithm. If you don't want to use the new API stack, set `config.api_stack(enable_rl_module_and_learner=False,enable_env_runner_and_connector_v2=False)`. For a detailed migration guide, see here: https://docs.ray.io/en/master/rllib/new-api-stack-migration-guide.html
2025-06-22 07:25:49,876	INFO worker.py:1917 -- Started a local Ray instance.
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
/tmp/ipython-input-2-2864521620.py in <cell line: 0>()
     60     )
     61 )
---> 62 algo = config.build_algo()
     63 print(algo.get_module())

9 frames
/usr/local/lib/python3.11/dist-packages/ray/rllib/algorithms/algorithm_config.py in get_multi_rl_module_spec(self, env, spaces, inference_only, policy_dict, single_agent_rl_module_spec)
   4595             policy_spec = policy_dict.get(module_id)
   4596             if policy_spec is None:
-> 4597                 policy_spec = policy_dict[DEFAULT_MODULE_ID]
   4598 
   4599             if module_spec.module_class is None:

KeyError: 'default_policy'

It should replicate on a fresh Colab instance; the commented-out line at the top installs the dependencies.


Debugging:

policy_dict arrives at get_multi_rl_module_spec looking like this:

{
    'p0': <ray.rllib.policy.policy.PolicySpec object at 0x7bb134495050>,
    'p1': <ray.rllib.policy.policy.PolicySpec object at 0x7bb134495610>,
}

It’s populated in get_multi_agent_setup, which builds it from the config’s set of policies rather than from its set of RL modules.

I should note that the only thing policy_spec is used for is grabbing the observation and action spaces, but the lookup happens before checking whether those spaces are already specified on the module spec, so the error is raised anyway. I think the immediate issue could be fixed with the following change to get_multi_rl_module_spec:

get_multi_rl_module_spec

def get_multi_rl_module_spec(
    self,
    *,
    env: Optional[EnvType] = None,
    spaces: Optional[Dict[PolicyID, Tuple[gym.Space, gym.Space]]] = None,
    inference_only: bool = False,
    # @HybridAPIStack
    policy_dict: Optional[Dict[str, PolicySpec]] = None,
    single_agent_rl_module_spec: Optional[RLModuleSpec] = None,
) -> MultiRLModuleSpec:
    """Returns the MultiRLModuleSpec based on the given env/spaces.

    Args:
        env: An optional environment instance, from which to infer the different
            spaces for the individual RLModules. If not provided, tries to infer
            from `spaces`, otherwise from `self.observation_space` and
            `self.action_space`. Raises an error, if no information on spaces can be
            inferred.
        spaces: Optional dict mapping ModuleIDs to 2-tuples of observation- and
            action space that should be used for the respective RLModule.
            These spaces are usually provided by an already instantiated remote
            EnvRunner (call `EnvRunner.get_spaces()`). If not provided, tries
            to infer from `env`, otherwise from `self.observation_space` and
            `self.action_space`. Raises an error, if no information on spaces can be
            inferred.
        inference_only: If `True`, the returned module spec is used in an
            inference-only setting (sampling) and the RLModule can thus be built in
            its light version (if available). For example, the `inference_only`
            version of an RLModule might only contain the networks required for
            computing actions, but misses additional target- or critic networks.
            Also, if `True`, the returned spec does NOT contain those (sub)
            RLModuleSpecs that have their `learner_only` flag set to True.

    Returns:
        A new MultiRLModuleSpec instance that can be used to build a MultiRLModule.
    """
    # TODO (Kourosh,sven): When we replace policy entirely there is no need for
    #  this function to map policy_dict to multi_rl_module_specs anymore. The module
    #  spec is directly given by the user or inferred from env and spaces.
    if policy_dict is None:
        policy_dict, _ = self.get_multi_agent_setup(env=env, spaces=spaces)

    # TODO (Kourosh): Raise an error if the config is not frozen
    # If the module is single-agent convert it to multi-agent spec

    # The default RLModuleSpec (might be multi-agent or single-agent).
    default_rl_module_spec = self.get_default_rl_module_spec()
    # The currently configured RLModuleSpec (might be multi-agent or single-agent).
    # If None, use the default one.
    current_rl_module_spec = self._rl_module_spec or default_rl_module_spec

    # Algorithm is currently setup as a single-agent one.
    if isinstance(current_rl_module_spec, RLModuleSpec):
        # Use either the provided `single_agent_rl_module_spec` (a
        # RLModuleSpec), the currently configured one of this
        # AlgorithmConfig object, or the default one.
        single_agent_rl_module_spec = (
            single_agent_rl_module_spec or current_rl_module_spec
        )
        single_agent_rl_module_spec.inference_only = inference_only
        # Now construct the proper MultiRLModuleSpec.
        multi_rl_module_spec = MultiRLModuleSpec(
            rl_module_specs={
                k: copy.deepcopy(single_agent_rl_module_spec)
                for k in policy_dict.keys()
            },
        )

    # Algorithm is currently setup as a multi-agent one.
    else:
        # The user currently has a MultiAgentSpec setup (either via
        # self._rl_module_spec or the default spec of this AlgorithmConfig).
        assert isinstance(current_rl_module_spec, MultiRLModuleSpec)

        # Default is single-agent but the user has provided a multi-agent spec
        # so the use-case is multi-agent.
        if isinstance(default_rl_module_spec, RLModuleSpec):
            # The individual (single-agent) module specs are defined by the user
            # in the currently setup MultiRLModuleSpec -> Use that
            # RLModuleSpec.
            if isinstance(current_rl_module_spec.rl_module_specs, RLModuleSpec):
                single_agent_spec = single_agent_rl_module_spec or (
                    current_rl_module_spec.rl_module_specs
                )
                single_agent_spec.inference_only = inference_only
                module_specs = {
                    k: copy.deepcopy(single_agent_spec) for k in policy_dict.keys()
                }

            # The individual (single-agent) module specs have not been configured
            # via this AlgorithmConfig object -> Use provided single-agent spec or
            # the default spec (which is also a RLModuleSpec in this
            # case).
            else:
                single_agent_spec = (
                    single_agent_rl_module_spec or default_rl_module_spec
                )
                single_agent_spec.inference_only = inference_only
                module_specs = {
                    k: copy.deepcopy(
                        current_rl_module_spec.rl_module_specs.get(
                            k, single_agent_spec
                        )
                    )
                    for k in (
                        policy_dict | current_rl_module_spec.rl_module_specs
                    ).keys()
                }

            # Now construct the proper MultiRLModuleSpec.
            # We need to infer the multi-agent class from `current_rl_module_spec`
            # and fill in the module_specs dict.
            multi_rl_module_spec = current_rl_module_spec.__class__(
                multi_rl_module_class=current_rl_module_spec.multi_rl_module_class,
                rl_module_specs=module_specs,
                modules_to_load=current_rl_module_spec.modules_to_load,
                load_state_path=current_rl_module_spec.load_state_path,
            )

        # Default is multi-agent and user wants to override it -> Don't use the
        # default.
        else:
            # User provided an override RLModuleSpec -> Use this to
            # construct the individual RLModules within the MultiRLModuleSpec.
            if single_agent_rl_module_spec is not None:
                pass
            # User has NOT provided an override RLModuleSpec.
            else:
                # But the currently setup multi-agent spec has a SingleAgentRLModule
                # spec defined -> Use that to construct the individual RLModules
                # within the MultiRLModuleSpec.
                if isinstance(current_rl_module_spec.rl_module_specs, RLModuleSpec):
                    # The individual module specs are not given, it is given as one
                    # RLModuleSpec to be re-used for all
                    single_agent_rl_module_spec = (
                        current_rl_module_spec.rl_module_specs
                    )
                # The currently set up multi-agent spec has NO
                # RLModuleSpec in it -> Error (there is no way we can
                # infer this information from anywhere at this point).
                else:
                    raise ValueError(
                        "We have a MultiRLModuleSpec "
                        f"({current_rl_module_spec}), but no "
                        "`RLModuleSpec`s to compile the individual "
                        "RLModules' specs! Use "
                        "`AlgorithmConfig.get_multi_rl_module_spec("
                        "policy_dict=.., rl_module_spec=..)`."
                    )

            single_agent_rl_module_spec.inference_only = inference_only

            # Now construct the proper MultiRLModuleSpec.
            multi_rl_module_spec = current_rl_module_spec.__class__(
                multi_rl_module_class=current_rl_module_spec.multi_rl_module_class,
                rl_module_specs={
                    k: copy.deepcopy(single_agent_rl_module_spec)
                    for k in policy_dict.keys()
                },
                modules_to_load=current_rl_module_spec.modules_to_load,
                load_state_path=current_rl_module_spec.load_state_path,
            )

    # Fill in the missing values from the specs that we already have. By combining
    # PolicySpecs and the default RLModuleSpec.
    for module_id in policy_dict | multi_rl_module_spec.rl_module_specs:

        # Remove/skip `learner_only=True` RLModules if `inference_only` is True.
        module_spec = multi_rl_module_spec.rl_module_specs[module_id]
        if inference_only and module_spec.learner_only:
            multi_rl_module_spec.remove_modules(module_id)
            continue
            
        # MCW: Removed code here.

        if module_spec.module_class is None:
            if isinstance(default_rl_module_spec, RLModuleSpec):
                module_spec.module_class = default_rl_module_spec.module_class
            elif isinstance(default_rl_module_spec.rl_module_specs, RLModuleSpec):
                module_class = default_rl_module_spec.rl_module_specs.module_class
                # This should be already checked in validate() but we check it
                # again here just in case
                if module_class is None:
                    raise ValueError(
                        "The default rl_module spec cannot have an empty "
                        "module_class under its RLModuleSpec."
                    )
                module_spec.module_class = module_class
            elif module_id in default_rl_module_spec.rl_module_specs:
                module_spec.module_class = default_rl_module_spec.rl_module_specs[
                    module_id
                ].module_class
            else:
                raise ValueError(
                    f"Module class for module {module_id} cannot be inferred. "
                    f"It is neither provided in the rl_module_spec that "
                    "is passed in nor in the default module spec used in "
                    "the algorithm."
                )
        if module_spec.catalog_class is None:
            if isinstance(default_rl_module_spec, RLModuleSpec):
                module_spec.catalog_class = default_rl_module_spec.catalog_class
            elif isinstance(default_rl_module_spec.rl_module_specs, RLModuleSpec):
                catalog_class = default_rl_module_spec.rl_module_specs.catalog_class
                module_spec.catalog_class = catalog_class
            elif module_id in default_rl_module_spec.rl_module_specs:
                module_spec.catalog_class = default_rl_module_spec.rl_module_specs[
                    module_id
                ].catalog_class
            else:
                raise ValueError(
                    f"Catalog class for module {module_id} cannot be inferred. "
                    f"It is neither provided in the rl_module_spec that "
                    "is passed in nor in the default module spec used in "
                    "the algorithm."
                )
        # TODO (sven): Find a good way to pack module specific parameters from
        # the algorithms into the `model_config_dict`.
        # MCW: modified the below lines
        if (
            module_spec.observation_space is None
            or module_spec.action_space is None
        ):
            policy_spec = policy_dict.get(module_id)
            if policy_spec is None:
                policy_spec = policy_dict[DEFAULT_MODULE_ID]
            module_spec.observation_space = policy_spec.observation_space
            module_spec.action_space = policy_spec.action_space
        # End of modified code
            
        # In case the `RLModuleSpec` does not have a model config dict, we use the
        # the one defined by the auto keys and the `model_config_dict` arguments in
        # `self.rl_module()`.
        if module_spec.model_config is None:
            module_spec.model_config = self.model_config
        # Otherwise we combine the two dictionaries where settings from the
        # `RLModuleSpec` have higher priority.
        else:
            module_spec.model_config = (
                self.model_config | module_spec._get_model_config()
            )

    return multi_rl_module_spec

There are some errors later on (when algo.train() is run) that I’d be willing to take a shot at documenting or sorting out if it would be of use. Alternatively, if there’s an intended strategy for getting a centralized critic to work, I’d be glad to try my hand at implementing and documenting it for the /examples section.

@sven1977 I had a bit of time to work on this, and I’ve got the shared encoder running now. I did have to change a few things in the supporting classes; I think a few files relied on deprecated paths and inheritance behavior. I also created a subclass of VPGTorchLearner that centralizes the optimizer, as described in your notes, and added a script in /examples that demonstrates the whole system working on MultiAgentCartPole.
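
For reviewers, the optimizer centralization is conceptually along these lines. This is a simplified sketch rather than the exact code in the PR (which subclasses VPGTorchLearner): it assumes a TorchLearner-based setup and that register_optimizer defaults to the ALL_MODULES key, and it registers a single optimizer over every sub-module so the shared encoder’s weights are stepped exactly once per update.

import torch
from ray.rllib.core.learner.torch.torch_learner import TorchLearner


class CentralizedOptimizerLearner(TorchLearner):
    """Sketch: one optimizer over all sub-modules (policies + shared encoder)."""

    def configure_optimizers(self):
        # Collect every trainable parameter across the MultiRLModule once.
        params = [
            p
            for module_id in self.module.keys()
            for p in self.module[module_id].parameters()
            if p.requires_grad
        ]
        # With no module_id given, this optimizer is registered globally
        # (assumption: the default module_id is ALL_MODULES), so the shared
        # encoder is updated by this one optimizer only.
        self.register_optimizer(
            optimizer=torch.optim.Adam(params, lr=self.config.lr),
            params=params,
        )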

I submitted a PR: Shared encoder working by MatthewCWeston · Pull Request #54571 · ray-project/ray. If you’ve got time, I’d be grateful if you could let me know whether the code looks right to you.