[New API stack] Epsilon Greedy exploration example

I learned from the New API stack migration guide (Ray 3.0.0.dev0) that the following handy way of configuring epsilon-greedy exploration is no longer possible in the new API stack:

"exploration_config": {
    "type": "EpsilonGreedy",
    "epsilon_timesteps": 1000,
    "final_epsilon": 0.01,
},

However, is there any reference example for exploration functions in the new RLModule syntax? It would be great if users could get an example, at least for the quite common epsilon-greedy variant.

You are correct: in the new RLlib API stack, the exploration_config setting (e.g., for EpsilonGreedy) is deprecated and cannot be used. Instead, you implement custom exploration directly by overriding the _forward_exploration() method in your RLModule subclass. This method defines how actions are sampled during exploration and is the recommended way to control exploration behavior, including epsilon-greedy, under the RLModule API. The official migration guide has no built-in epsilon-greedy example for the RLModule context; the approach is to implement the epsilon-greedy logic yourself inside _forward_exploration() of your RLModule class, using the current timestep to anneal epsilon as needed. See the migration guide for confirmation and further explanation.

Would you like a step-by-step code example of how to implement epsilon-greedy exploration in a custom RLModule?

Sources:

Hint: Mention @RunLLM in the post for followups.

@RunLLM Yes, a step-by-step code example of how to implement epsilon-greedy exploration in a custom RLModule is what I am looking for.

Certainly! To implement epsilon-greedy exploration in the new RLModule API, subclass TorchRLModule and override the _forward_exploration method. Inside this method, compute Q-values, sample a random action with probability epsilon, and otherwise select the greedy action. You can use a schedule for epsilon decay as needed. Here’s a minimal example:

import torch
import torch.nn as nn
from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule

class EpsilonGreedyDQNModule(TorchRLModule):
    def setup(self):
        input_dim = self.observation_space.shape[0]
        output_dim = self.action_space.n
        hidden_dim = self.model_config.get("fcnet_hiddens", [128])[0]
        self.q_net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )
        self.initial_epsilon = 1.0
        self.final_epsilon = 0.01
        self.epsilon_timesteps = 1000

    def _forward_exploration(self, batch, t, **kwargs):
        # `t` is the current global env timestep, passed in by the EnvRunner.
        q_values = self.q_net(batch[Columns.OBS])
        greedy_actions = torch.argmax(q_values, dim=1)
        batch_size = q_values.shape[0]
        # Linearly anneal epsilon from `initial_epsilon` down to
        # `final_epsilon` over `epsilon_timesteps` env steps.
        epsilon = max(
            self.final_epsilon,
            self.initial_epsilon
            - (self.initial_epsilon - self.final_epsilon)
            * min(t, self.epsilon_timesteps)
            / self.epsilon_timesteps,
        )
        # With probability epsilon, replace the greedy action with a random one.
        random_actions = torch.randint(0, self.action_space.n, (batch_size,))
        explore_mask = torch.rand(batch_size) < epsilon
        actions = torch.where(explore_mask, random_actions, greedy_actions)
        return {Columns.ACTIONS: actions}

This pattern is described in the RLModule documentation and is the recommended way to implement custom exploration in the new API stack.
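Stripped of the torch and RLModule specifics, the selection rule inside _forward_exploration reduces to the following plain-Python sketch (the function name is illustrative, not part of any RLlib API):

```python
import random

def epsilon_greedy_action(q_values, epsilon, rng=random):
    # With probability epsilon, pick a uniform random action index;
    # otherwise pick the index of the highest Q-value (the greedy action).
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda i: q_values[i])
```

With epsilon=0 this always returns the greedy action; with epsilon=1 it is a uniform random policy, and the anneal schedule interpolates between the two.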


@RunLLM : The links to the RLModule documentation lead to a FileNotFound error. Please provide another code example for PPO with discrete action masking, based on ray/rllib/examples/rl_modules/classes/action_masking_rlm.py at master · ray-project/ray · GitHub .

Certainly! For PPO with discrete action masking in the new RLModule API, subclass ActionMaskingTorchRLModule and override the _forward_exploration method to add epsilon-greedy exploration. The key is to mask invalid actions first, then with probability epsilon sample uniformly among the valid actions, and otherwise take the greedy action. Here’s a step-by-step code example based on action_masking_rlm.py:

import torch
from ray.rllib.core.columns import Columns
from ray.rllib.examples.rl_modules.classes.action_masking_rlm import ActionMaskingTorchRLModule

class EpsilonGreedyActionMaskingTorchRLModule(ActionMaskingTorchRLModule):
    def setup(self):
        super().setup()
        self.initial_epsilon = 1.0
        self.final_epsilon = 0.01
        self.epsilon_timesteps = 1000
        self.timestep = 0

    def _forward_exploration(self, batch, **kwargs):
        # Read the mask *before* calling super(): the parent class pops
        # "action_mask" out of the obs dict and already applies it to the
        # logits (invalid actions get ~ -inf), so we must not preprocess twice.
        action_mask = batch[Columns.OBS]["action_mask"]
        outs = super()._forward_exploration(batch, **kwargs)
        masked_logits = outs[Columns.ACTION_DIST_INPUTS]

        # Linear epsilon decay, tracked via an internal step counter.
        batch_size = masked_logits.shape[0]
        epsilon = max(
            self.final_epsilon,
            self.initial_epsilon
            - (self.initial_epsilon - self.final_epsilon)
            * min(self.timestep, self.epsilon_timesteps)
            / self.epsilon_timesteps,
        )
        self.timestep += batch_size

        greedy_actions = torch.argmax(masked_logits, dim=1)
        # Sample random actions uniformly among the *valid* actions only.
        random_actions = torch.multinomial(action_mask.float(), num_samples=1).squeeze(-1)

        explore_mask = torch.rand(batch_size) < epsilon
        actions = torch.where(explore_mask, random_actions, greedy_actions)
        outs[Columns.ACTIONS] = actions
        return outs

This code ensures only valid actions are sampled and applies epsilon-greedy exploration on top of the masked logits. One caveat: PPO is an on-policy algorithm, so forcing epsilon-greedy actions means the collected data no longer follows the policy's own action distribution exactly; use this pattern with care. For the full context, see action_masking_rlm.py.
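The only new ingredient compared to the plain epsilon-greedy case is that both the random and the greedy choice must respect the mask. As a plain-Python sketch of that selection rule (function name illustrative, not RLlib API):

```python
import random

def masked_epsilon_greedy(logits, action_mask, epsilon, rng=random):
    # Valid actions are those with a nonzero mask entry.
    valid = [i for i, m in enumerate(action_mask) if m > 0]
    if rng.random() < epsilon:
        # Explore: uniform over *valid* actions only.
        return rng.choice(valid)
    # Exploit: argmax of logits restricted to valid actions.
    return max(valid, key=lambda i: logits[i])
```

Note that even with epsilon=0, an invalid action can never be returned, which mirrors what the -inf masking of the logits guarantees in the torch version.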
