Custom Env with PPO, agent picking same action in testing

1. Severity of the issue: (select one)
High: Completely blocks me.

2. Environment:

  • Ray version: 2.41.0
  • Python version: 3.12.8
  • OS: Windows Server 2022
  • Other libs/tools (if relevant): Gymnasium

3. What happened vs. what you expected:

  • Expected: The agent picks different actions and covers the smaller circle with ellipses during inference (testing).
  • Actual: The agent does not learn and keeps picking the same action in testing.

Hi, I have created a custom environment that consists of two circles, a bigger one and a smaller one inside it. The aim of my agent is to cover the smaller circle with ellipses. The only rule is that it must not exceed the bounds of the larger circle. My observation space is:

self.smaller_grid_size = (138, 138)  # smaller grid size
gym.spaces.Box(low=0, high=255, shape=(self.smaller_grid_size[0], self.smaller_grid_size[1], 3), dtype=np.uint8)

and my action space is:

self.minimum_grid_x = 111
self.minimum_grid_y = 101

# Calculate the dimensions of the coarse grid
self.coarse_width = (self.minimum_grid_x + self.grid_spacing - 1) // self.grid_spacing
self.coarse_height = (self.minimum_grid_y + self.grid_spacing - 1) // self.grid_spacing

gym.spaces.Tuple((
    gym.spaces.Discrete(self.coarse_width),   # x coordinates limited to grid width
    gym.spaces.Discrete(self.coarse_height),  # y coordinates limited to grid height
))
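
To show how these pieces fit together, here is a stripped-down sketch of the env (the class name, canvas handling, and reward/termination logic below are simplified placeholders, not my actual EllipsePackingSimple implementation):

import gymnasium as gym
import numpy as np

class EllipsePackingSketch(gym.Env):
    # Simplified placeholder for the real packing env (no actual ellipse geometry).

    def __init__(self, grid_spacing=5):
        super().__init__()
        self.smaller_grid_size = (138, 138)
        self.grid_spacing = grid_spacing
        self.minimum_grid_x = 111
        self.minimum_grid_y = 101
        # With grid_spacing=5 this gives 23 x-bins and 21 y-bins.
        self.coarse_width = (self.minimum_grid_x + self.grid_spacing - 1) // self.grid_spacing
        self.coarse_height = (self.minimum_grid_y + self.grid_spacing - 1) // self.grid_spacing

        self.observation_space = gym.spaces.Box(
            low=0, high=255,
            shape=(self.smaller_grid_size[0], self.smaller_grid_size[1], 3),
            dtype=np.uint8,
        )
        self.action_space = gym.spaces.Tuple((
            gym.spaces.Discrete(self.coarse_width),
            gym.spaces.Discrete(self.coarse_height),
        ))

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)
        self._canvas = np.zeros(self.observation_space.shape, dtype=np.uint8)
        self._placed = set()
        return self._canvas, {}

    def step(self, action):
        x, y = int(action[0]), int(action[1])
        # Placeholder: the real env rasterizes an ellipse here and checks the circle bounds.
        repeated = (x, y) in self._placed
        self._placed.add((x, y))
        reward = -1.0
        terminated = repeated  # my rule: repeating a placement ends the episode
        return self._canvas, reward, terminated, False, {}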

Here are some visuals (a few random episodes) for reference:

[Image: multi_episode_packing]

This is my training script (PPO configuration and training loop):

import os
import ray
import gymnasium
import ellipse_image_envs
import simple_env_ellipse
from ellipse_image_envs import EllipsePackingRGB
from simple_env_ellipse import EllipsePackingSimple
import torch
import pandas as pd


# Add PyTorch CUDA memory optimization
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# Empty CUDA cache before starting
if torch.cuda.is_available():
    torch.cuda.empty_cache()

import CNN_RL
import CNN_simple
from CNN_simple import CustomCNN_simple

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.env.single_agent_env_runner import SingleAgentEnvRunner
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.core.models.catalog import Catalog

from ray.tune.logger import JsonLoggerCallback, CSVLoggerCallback, TBXLoggerCallback
from ray.rllib.examples.envs.env_rendering_and_recording import EnvRenderCallback
from ray import tune

from custom_training_metric_callback import EpisodeEndMetricsCallback


os.environ["TORCH_DISTRIBUTED_BACKEND"] = "gloo"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"


class CustomCatalog(Catalog):
    def build(self):
        return {
            "model": CustomCNN_simple(
                observation_space=self.observation_space,
                action_space=self.action_space,
                model_config=self.model_config,
            )
        }


def main():
    # Initialize Ray with explicit cluster configuration
    ray.init(address="auto")  # For cluster mode

    env = EllipsePackingSimple()
    metrics_callback = EpisodeEndMetricsCallback()

    # Define a function that returns your callback instance
    def get_callbacks():
        return metrics_callback

    spec = RLModuleSpec(
        module_class=CustomCNN_simple,
        catalog_class=CustomCatalog,
        inference_only=False,
        observation_space=env.observation_space,
        action_space=env.action_space,
        model_config={
            "conv_filters": [
                [32, [5, 5], 2, "same"],
                [64, [5, 5], 2, "same"],
                [128, [3, 3], 1, "same"],
            ],
            "fcnet_hiddens": [512],
            "fcnet_activation": "relu",
        },
    )

    config = (
        PPOConfig()
        .environment(EllipsePackingSimple)
        .callbacks(get_callbacks)
        .resources(num_gpus=1)
        .framework("torch")
        .debugging(log_level="DEBUG")
        .api_stack(
            enable_rl_module_and_learner=True,
            enable_env_runner_and_connector_v2=True
        )
        .learners(
            num_learners=0,
            num_gpus_per_learner=0.5,
            local_gpu_idx=0,
        )
        .env_runners(
            env_runner_cls=SingleAgentEnvRunner,
            num_env_runners=0,
            sample_timeout_s=256,
            rollout_fragment_length=256,
            num_gpus_per_env_runner=0.2,
            explore=True
            
        )
        .rl_module(
            # Reuse the spec defined above (it is identical to the one previously inlined here).
            rl_module_spec=spec,
        )
        .training(
            gamma=0.99,  # Increase from 0.95 to value future rewards more
            lr=0.0001,   # Keep same learning rate
            kl_coeff=0.2,  # Increase from 0.01 to prevent policy collapse
            train_batch_size_per_learner=256,  # Keep same to avoid memory issues
            num_sgd_iter=10,  # Increase from 5 for more learning per batch
            entropy_coeff=0.05,  # Add entropy coefficient to encourage exploration
            clip_param=0.2,  # Standard PPO clipping parameter
            vf_clip_param=10.0,  # Limit value function updates
            use_gae=True,  # Enable Generalized Advantage Estimation
            lambda_=0.95,  # GAE parameter
        )
              
    )


    algo = config.build_algo()
 

    checkpoint_dir = os.path.join(os.getcwd(), "checkpoints")
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    # Update the metrics DataFrame definition to include fill percentage
    metrics_df = pd.DataFrame(columns=[
        "iteration", "total_loss", "policy_loss", "vf_loss", "entropy", "mean_kl_loss",
        "cpu_util_percent", "gpu_util_percent0",
        "num_env_steps_sampled", "num_agent_steps_sampled", "num_module_steps_sampled",
        "time_since_restore", "sample"
    ])

    for i in range(3000):
        result = algo.train()
        result.pop("config")
        
        # Extract available metrics from the result
        iteration = result['training_iteration']

        
        # Learner metrics
        total_loss = result['learners']['default_policy'].get('total_loss', 0)
        policy_loss = result['learners']['default_policy'].get('policy_loss', 0)
        vf_loss = result['learners']['default_policy'].get('vf_loss', 0)
        entropy = result['learners']['default_policy'].get('entropy', 0)
        # Use 'mean_kl_loss' instead of 'kl'
        mean_kl_loss = result['learners']['default_policy'].get('mean_kl_loss', 0)
        
        # Performance metrics
        cpu_util_percent = result['perf'].get('cpu_util_percent', 0)
        gpu_util_percent0 = result['perf'].get('gpu_util_percent0', 0)
        
        # Environment runner metrics
        num_env_steps_sampled = result['env_runners'].get('num_env_steps_sampled', 0)
        num_agent_steps_sampled = result['env_runners']['num_agent_steps_sampled'].get('default_agent', 0)
        num_module_steps_sampled = result['env_runners']['num_module_steps_sampled'].get('default_policy', 0)
        
        # Other metrics
        time_since_restore = result.get('time_since_restore', 0)
        sample = result['env_runners'].get('sample', 0)

        # Create a new row with extracted metrics
        new_row = {
            "iteration": iteration,
            "total_loss": total_loss,
            "policy_loss": policy_loss,
            "vf_loss": vf_loss,
            "entropy": entropy,
            "mean_kl_loss": mean_kl_loss,  # Changed from 'kl' to 'mean_kl_loss'
            "cpu_util_percent": cpu_util_percent,
            "gpu_util_percent0": gpu_util_percent0,
            "num_env_steps_sampled": num_env_steps_sampled,
            "num_agent_steps_sampled": num_agent_steps_sampled,
            "num_module_steps_sampled": num_module_steps_sampled,
            "time_since_restore": time_since_restore,
            "sample": sample
        }

        # Add the row to the DataFrame
        metrics_df = pd.concat([metrics_df, pd.DataFrame([new_row])], ignore_index=True)
        
        # Save current metrics after each iteration (overwrites previous version)
        metrics_df.to_csv("training_metrics_simple_ver_latest.csv", index=False)
        
        # Print progress
        print(f"Completed iteration {i}, DataFrame now has {len(metrics_df)} rows")
        
        
    # Save final version once training is complete
    metrics_df.to_csv("training_metrics_simple_ver_final.csv", index=False)
    custom_metrics_df = metrics_callback.get_episode_df()
    custom_metrics_df.to_csv("final_results.csv")
    print("Training complete. Final metrics saved.")

if __name__ == "__main__":
    main()

Here is my CNN code (based on the RLlib Atari example):

from typing import Any, Dict, Optional
import ray
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import gymnasium as gym
from ray.rllib.core.columns import Columns
from ray.rllib.core.learner.utils import make_target_network
from ray.rllib.core.rl_module.apis import (
    TargetNetworkAPI,
    ValueFunctionAPI,
    #TARGET_NETWORK_ACTION_DIST_INPUTS,
)
from ray.rllib.core.rl_module.torch import TorchRLModule
from ray.rllib.models.torch.misc import (
    normc_initializer,
    same_padding,
    valid_padding,
)
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import TensorType
torch, nn = try_import_torch()




class CustomCNN_simple(TorchRLModule, ValueFunctionAPI, TargetNetworkAPI):

    @override(TorchRLModule)
    def setup(self):
        """Use this method to create all the model components that you require.

        Feel free to access the following useful properties in this class:
        - `self.model_config`: The config dict for this RLModule class,
          which should contain flexible settings, for example: {"hiddens": [256, 256]}.
        - `self.observation|action_space`: The observation and action space that
          this RLModule is subject to. Note that the observation space might not be the
          exact space from your env, but that it might have already gone through
          preprocessing through a connector pipeline (for example, flattening,
          frame-stacking, mean/std-filtering, etc.).
        """
        # print(f"Model Observation Space: {self.observation_space}")
        # print(f"Model Action Space: {self.action_space}")
        # print(f"Model running on GPU: {torch.cuda.is_available()}")
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # print(f"Using device: {self.device}")

        # Get the CNN stack config from the RLModule's `model_config` property:
        conv_filters = self.model_config.get("conv_filters")
        # Default CNN stack with 4 layers:
        if conv_filters is None:
            conv_filters = [
                [32, [5, 5], 2, "same"],
                [64, [3, 3], 2, "same"],
                [128, [3, 3], 1, "same"],
                [128, [3, 3], 1, "same"],
            ]

        # Build the CNN layers.
        layers = []

        # Add user-specified hidden convolutional layers first.
        width, height, in_depth = self.observation_space.shape
        in_size = [width, height]

        for filter_specs in conv_filters:
            if len(filter_specs) == 4:
                out_depth, kernel_size, strides, padding = filter_specs
            else:
                out_depth, kernel_size, strides = filter_specs
                padding = "same"

            # Pad like in TensorFlow's SAME mode.
            if padding == "same":
                padding_size, out_size = same_padding(in_size, kernel_size, strides)
                layers.append(nn.ZeroPad2d(padding_size))
            else:
                out_size = valid_padding(in_size, kernel_size, strides)

            layer = nn.Conv2d(in_depth, out_depth, kernel_size, strides, bias=True)
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)
            layers.append(layer)
            layers.append(nn.BatchNorm2d(out_depth))
            layers.append(nn.LeakyReLU(0.1))

            in_size = out_size
            in_depth = out_depth

        # Move the conv layers to the target device (the full module is moved once more
        # at the end of setup()).
        layers = [layer.to(self.device) for layer in layers]

        self._base_cnn_stack = nn.Sequential(*layers)
        self._base_cnn_stack = self._base_cnn_stack.to(self.device)


        # Define separate logits heads for each discrete sub-action space.
        if isinstance(self.action_space, gym.spaces.Tuple):
            discrete_action_sizes = [
                space.n if isinstance(space, gym.spaces.Discrete) else 0
                for space in self.action_space.spaces
            ]
            num_discrete_actions_1 = discrete_action_sizes[0]
            num_discrete_actions_2 = discrete_action_sizes[1]

        # Spatially aware logits: 1×1 convolutions preserve the spatial layout of the embeddings.
        self._logits1 = nn.Sequential(
            nn.ZeroPad2d(same_padding(in_size, 1, 1)[0]),
            nn.Conv2d(in_depth, num_discrete_actions_1, kernel_size=1, stride=1, bias=True),
        ).to(self.device)

        self._logits2 = nn.Sequential(
            nn.ZeroPad2d(same_padding(in_size, 1, 1)[0]),
            nn.Conv2d(in_depth, num_discrete_actions_2, kernel_size=1, stride=1, bias=True),
        ).to(self.device)

        # Initialize with smaller values for better exploration.
        for logits_layer in [self._logits1[1], self._logits2[1]]:
            nn.init.xavier_uniform_(logits_layer.weight, gain=0.01)
            nn.init.zeros_(logits_layer.bias)

        # Value function head.
        self._values = nn.Sequential(
            nn.Linear(in_depth, 256),
            nn.LeakyReLU(0.1),
            nn.Linear(256, 1),
        ).to(self.device)

        # Make sure all components are on the device (move the entire module once at the end).
        self.to(self.device)
        self._base_cnn_stack = self._base_cnn_stack.to(self.device)
        self._logits1 = self._logits1.to(self.device)
        self._logits2 = self._logits2.to(self.device)
        self._values = self._values.to(self.device)




    @override(TorchRLModule)
    def _forward(self, batch, **kwargs):
        # Compute the basic feature tensor (inputs to policy- and value-heads).
        batch = {k: v.to(self.device) for k, v in batch.items()}
        # print(f"Batch device in forward: {batch[Columns.OBS].device}")
        self._base_cnn_stack = self._base_cnn_stack.to(self.device)
        # print(next(self._base_cnn_stack.parameters()).device)

        _, logits = self._compute_embeddings_and_logits(batch)
        logits = logits.to(self.device)
        # print(f"logits device in forward: {logits.device}")
        # Return the logits as ACTION_DIST_INPUTS (categorical distribution).
        return {
            Columns.ACTION_DIST_INPUTS: logits,
        }

    @override(TorchRLModule)
    def _forward_train(self, batch, **kwargs):
        # Compute the basic feature tensor (inputs to policy- and value-heads).
        batch[Columns.OBS] = batch[Columns.OBS].to(self.device)
        # print(f"Batch device in forward train: {batch[Columns.OBS].device}")
        embeddings, logits = self._compute_embeddings_and_logits(batch)
        embeddings = embeddings.to(self.device)
        logits = logits.to(self.device)
        # print(f"logits device in forward train: {logits.device}")
        # print(f"embeddings device in forward train: {embeddings.device}")

        # Move any remaining batch tensors (or tuples of tensors) to the device.
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.to(self.device)  # Move single tensor
            elif isinstance(v, tuple):
                batch[k] = tuple(
                    item.to(self.device) if isinstance(item, torch.Tensor) else item
                    for item in v
                )

        # Return logits as ACTION_DIST_INPUTS and embeddings for the value head.
        return {
            Columns.ACTION_DIST_INPUTS: logits,
            Columns.EMBEDDINGS: embeddings,
        }

    @override(TargetNetworkAPI)
    def make_target_networks(self) -> None:
        # Create a target network for the CNN base stack ...
        self._target_base_cnn_stack = make_target_network(self._base_cnn_stack).to(self.device)

        # ... and for each individual logits head.
        self._target_logits1 = make_target_network(self._logits1).to(self.device)
        self._target_logits2 = make_target_network(self._logits2).to(self.device)

    @override(TargetNetworkAPI)
    def get_target_network_pairs(self):
        return [
            # Pair for the CNN base layers.
            (self._base_cnn_stack, self._target_base_cnn_stack),
            # Pairs for the individual logits heads.
            (self._logits1, self._target_logits1),
            (self._logits2, self._target_logits2),
        ]

    @override(TargetNetworkAPI)
    def forward_target(self, batch, **kw):
        # Normalize and permute the observations for CNN input.
        obs = batch[Columns.OBS].float().to(self.device) / 255.0  # uint8 -> float32 in [0, 1]
        obs = obs.permute(0, 3, 1, 2)  # NHWC -> NCHW

        # Pass the observations through the target CNN stack to get embeddings.
        embeddings = self._target_base_cnn_stack(obs).to(self.device)

        # Compute the logits from each target logits head.
        logits1 = self._target_logits1(embeddings).to(self.device)
        logits2 = self._target_logits2(embeddings).to(self.device)

        # Concatenate the logits from both heads.
        logits = torch.cat([logits1, logits2], dim=-1)

        # Return the logits as the target network's action distribution inputs.
        return {"target_network_action_dist_inputs": torch.squeeze(logits, dim=[-1, -2])}


    @override(ValueFunctionAPI)
    def compute_values(
        self,
        batch: Dict[str, Any],
        embeddings: Optional[Any] = None,
    ) -> TensorType:
        if embeddings is None:
            obs = batch[Columns.OBS].float().to(self.device) / 255.0
            self._base_cnn_stack = self._base_cnn_stack.to(self.device)
            obs = obs.permute(0, 3, 1, 2)
            embeddings = self._base_cnn_stack(obs).to(self.device)

        # Use channel-wise attention.
        b, c, h, w = embeddings.shape

        # Spatial feature extraction: pool height and width separately to preserve
        # directional spatial information.
        h_pooled = torch.mean(embeddings, dim=3)  # [b, c, h]
        w_pooled = torch.mean(embeddings, dim=2)  # [b, c, w]

        # Calculate attention weights for height and width.
        h_weights = F.softmax(h_pooled.mean(dim=1, keepdim=True), dim=2)  # [b, 1, h]
        w_weights = F.softmax(w_pooled.mean(dim=1, keepdim=True), dim=2)  # [b, 1, w]

        # Apply attention to get weighted features.
        h_features = torch.bmm(h_pooled, h_weights.transpose(1, 2)).squeeze(2)  # [b, c]
        w_features = torch.bmm(w_pooled, w_weights.transpose(1, 2)).squeeze(2)  # [b, c]

        # Combine features (element-wise addition).
        combined_features = h_features + w_features

        # Pass through the value function head.
        self._values = self._values.to(self.device)
        return self._values(combined_features).squeeze(-1)
 
    def _compute_embeddings_and_logits(self, batch):
        obs = batch[Columns.OBS].float().to(self.device) / 255.0  # uint8 -> float32 in [0, 1]
        self._base_cnn_stack = self._base_cnn_stack.to(self.device)
        obs = obs.permute(0, 3, 1, 2)  # NHWC -> NCHW

        # Pass the observation through the CNN layers to get the embeddings.
        embeddings = self._base_cnn_stack(obs)

        # Ensure embeddings and heads are on the correct device.
        embeddings = embeddings.to(self.device)
        self._logits1 = self._logits1.to(self.device)
        self._logits2 = self._logits2.to(self.device)

        logits1 = self._logits1(embeddings)
        logits2 = self._logits2(embeddings)

        # Reshape the logits from [batch, channels, height, width] to [batch, n_actions]
        # while preserving spatial information.
        b, c1, h, w = logits1.shape
        b, c2, h, w = logits2.shape

        x_action_size = self.action_space[0].n  # should be 23
        y_action_size = self.action_space[1].n  # should be 21

        # X coordinates: adaptive avg pooling along the width keeps horizontal position information.
        x_pooled = F.adaptive_avg_pool2d(logits1, (1, x_action_size))  # [b, c1, 1, x_action_size]
        x_pooled = x_pooled.squeeze(2)                                 # [b, c1, x_action_size]
        x_logits = torch.sum(x_pooled, dim=1)                          # [b, x_action_size]

        # Y coordinates: adaptive avg pooling along the height keeps vertical position information.
        y_pooled = F.adaptive_avg_pool2d(logits2, (y_action_size, 1))  # [b, c2, y_action_size, 1]
        y_pooled = y_pooled.squeeze(3)                                 # [b, c2, y_action_size]
        y_logits = torch.sum(y_pooled, dim=1)                          # [b, y_action_size]

        # Combine both coordinate logits.
        logits = torch.cat([x_logits, y_logits], dim=1)  # [b, x_action_size + y_action_size]

        print(f"Logits shape: {logits.shape}")
        print(f"Embeddings shape: {embeddings.shape}")
        return embeddings, logits

No matter how many iterations I train for, or how much I change the parameters of my CNN, the agent just does not learn. During training I set the explore parameter to True, but when I test I switch it to False. The agent then picks the same action again and again. One of my termination criteria is that the episode ends if an action is picked twice, so it just terminates almost immediately.
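
As far as I understand, with explore=False the module-to-env pipeline takes the deterministic (argmax) action from each categorical head, so if the heads are nearly flat or collapsed after training, the argmax lands on the same (x, y) every step. A small check I run on the ACTION_DIST_INPUTS coming out of the module (a minimal sketch; it assumes the x-head comes first with 23 bins and the y-head second with 21 bins, matching my _compute_embeddings_and_logits):

import torch

def head_entropies(logits: torch.Tensor, x_size: int, y_size: int):
    # Entropy (in nats) of the x- and y-coordinate heads of the concatenated logits vector.
    x_logits = logits[..., :x_size]
    y_logits = logits[..., x_size:x_size + y_size]
    x_dist = torch.distributions.Categorical(logits=x_logits)
    y_dist = torch.distributions.Categorical(logits=y_logits)
    return x_dist.entropy(), y_dist.entropy()

# Example with dummy logits (23 x-bins, 21 y-bins as in my action space):
dummy = torch.zeros(1, 23 + 21)                  # perfectly uniform heads
x_ent, y_ent = head_entropies(dummy, 23, 21)
print(x_ent.item(), y_ent.item())                # maxima are ln(23) ≈ 3.14 and ln(21) ≈ 3.04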

Here is my test code:

import os
import torch
import numpy as np
from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.core import (
    COMPONENT_ENV_RUNNER,
    COMPONENT_ENV_TO_MODULE_CONNECTOR,
    COMPONENT_MODULE_TO_ENV_CONNECTOR,
    COMPONENT_LEARNER_GROUP,
    COMPONENT_LEARNER,
    COMPONENT_RL_MODULE,
    DEFAULT_MODULE_ID,
)
from ray.rllib.connectors.env_to_module import EnvToModulePipeline
from ray.rllib.connectors.module_to_env import ModuleToEnvPipeline
from ray.rllib.env.single_agent_episode import SingleAgentEpisode
from simple_env_ellipse import EllipsePackingSimple

# Create environment
env = EllipsePackingSimple()
print(f"Environment created with observation shape {env.observation_space.shape} and action space {env.action_space}")

# Path to your checkpoint
checkpoint_path = r"C:\Users\scmmw\Documents\GitHub\DRL\src\drl\checkpoints"

# Load the RL Module (policy network)
print("Loading RLModule from checkpoint...", end="")
try:
    rl_module = RLModule.from_checkpoint(
        os.path.join(
            checkpoint_path,
            COMPONENT_LEARNER_GROUP,
            COMPONENT_LEARNER,
            COMPONENT_RL_MODULE,
            DEFAULT_MODULE_ID,
        )
    )
    print(" success!")
except Exception as e:
    print(f" failed: {e}")
    raise e

# Load the env-to-module pipeline
print("Loading env-to-module connector...", end="")
try:
    env_to_module = EnvToModulePipeline.from_checkpoint(
        os.path.join(
            checkpoint_path,
            COMPONENT_ENV_RUNNER,
            COMPONENT_ENV_TO_MODULE_CONNECTOR,
        )
    )
    print(" success!")
except Exception as e:
    print(f" failed: {e}")


# Load the module-to-env pipeline
print("Loading module-to-env connector...", end="")
try:
    module_to_env = ModuleToEnvPipeline.from_checkpoint(
        os.path.join(
            checkpoint_path,
            COMPONENT_ENV_RUNNER,
            COMPONENT_MODULE_TO_ENV_CONNECTOR,
        )
    )
    print(" success!")
except Exception as e:
    print(f" failed: {e}")


# Set up for episodes
num_episodes = 0
max_episodes = 1  # Number of episodes to run
exploration = False  # Whether to use exploration during inference

print("\n--- Starting episodes ---")

# Initialize episode
obs, _ = env.reset()
episode = SingleAgentEpisode(
    observations=[obs],
    observation_space=env.observation_space,
    action_space=env.action_space,
)

action_history = []  # Track action history
step_count = 0  # Track steps in current episode
total_steps = 0  # Track total steps across all episodes

while num_episodes < max_episodes:
    step_count += 1
    total_steps += 1
    
    # Process through env-to-module pipeline
    shared_data = {}
    input_dict = env_to_module(
        episodes=[episode],  # ConnectorV2 pipelines operate on lists of episodes
        rl_module=rl_module,
        explore=exploration,
        shared_data=shared_data,
    )
    
    # Forward through RLModule
    if not exploration:
        rl_module_out = rl_module.forward_inference(input_dict)
    else:
        rl_module_out = rl_module.forward_exploration(input_dict)

    # Inspect the action distribution (after the RLModule forward, before the module-to-env pipeline).
    action_probs_x = torch.nn.functional.softmax(
        rl_module_out[Columns.ACTION_DIST_INPUTS][:, :env.action_space[0].n], dim=-1
    )
    action_probs_y = torch.nn.functional.softmax(
        rl_module_out[Columns.ACTION_DIST_INPUTS][:, env.action_space[0].n:], dim=-1
    )

    # Get the top-k most likely actions for x and y
    top_k = 10
    top_x_probs, top_x_indices = torch.topk(action_probs_x[0], k=top_k)
    top_y_probs, top_y_indices = torch.topk(action_probs_y[0], k=top_k)

    print("\nTop X probabilities:")
    for i in range(top_k):
        print(f"  X={top_x_indices[i].item()}: {top_x_probs[i].item():.4f}")

    print("Top Y probabilities:")
    for i in range(top_k):
        print(f"  Y={top_y_indices[i].item()}: {top_y_probs[i].item():.4f}")

    
    
    # Process through module-to-env pipeline
    to_env = module_to_env(
        batch=rl_module_out,
        episodes=[episode],  # ConnectorV2 pipelines operate on lists of episodes
        rl_module=rl_module,
        explore=exploration,
        shared_data=shared_data,
    )
    
    # Extract action (handle batched output)
    batched_action = to_env.pop(Columns.ACTIONS)
    #print(f"Raw action output: {batched_action}")
    action = (batched_action[0][0], batched_action[1][0])
    #print(f"Action chosen: {action}")
    action_history.append(action)
    
    # Take step in environment
    obs, reward, terminated, truncated, _ = env.step(action)
    
    # Update episode with new step
    episode.add_env_step(
        obs,
        action,
        reward,
        terminated=terminated,
        truncated=truncated,
        # Extract any extra outputs from module-to-env
        extra_model_outputs={k: v[0] for k, v in to_env.items()},
    )
    
    # Print step info
    print(f"Step {step_count}: Action {action}, Reward {reward:.4f}, Return {episode.get_return():.4f}")
    
    # Count repeats of current action in history
    action_repeats = action_history.count(action)
    if action_repeats > 1:
        print(f"Warning: Action {action} has been chosen {action_repeats} times")
    
    # Render environment
    env.render()
    
    # Check if episode is done
    if episode.is_done:
        print(f"\nEpisode {num_episodes + 1} complete after {step_count} steps")
        print(f"Final return: {episode.get_return():.4f}")
        
        # Analyze action diversity for this episode
        unique_actions = set(action_history)
        print(f"Used {len(unique_actions)} unique actions out of {step_count} steps")
        if action_history:
            most_common = max(set(action_history), key=action_history.count)
            count = action_history.count(most_common)
            print(f"Most common action: {most_common} (used {count} times, {count/step_count:.1%} of steps)")
        
        # Reset for next episode
        num_episodes += 1
        if num_episodes < max_episodes:
            obs, _ = env.reset()
            episode = SingleAgentEpisode(
                observations=[obs],
                observation_space=env.observation_space,
                action_space=env.action_space,
            )
            action_history = []  # Reset action history for new episode
            step_count = 0  # Reset step count for new episode

print(f"\nCompleted {num_episodes} episodes with {total_steps} total steps")

I need the policy to be deterministic in testing, which is why the parameter is set to False. However, the agent does not seem to learn effectively. I have tried different reward/penalty strategies, but no matter what I do it fails. My current strategy is to give -1 for every ellipse placed and -10 for an invalid placement (along with termination). Any help would be appreciated!
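
For concreteness, my current reward/termination logic boils down to roughly this (simplified sketch; placed_inside and action_repeated are placeholders for the actual geometry checks and bookkeeping in my env):

def reward_and_done(placed_inside: bool, action_repeated: bool):
    # Simplified placeholder for my current scheme, not the real env code.
    if not placed_inside:
        # Ellipse exceeds the larger circle: -10 and terminate.
        return -10.0, True
    # Every ellipse placed costs -1; repeating an already-used action also terminates.
    return -1.0, action_repeated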

Hey @manisha-waterston,

My question is: do you ever give the agent any reward other than -1 and -10? I can see a world where, if an episode involves placing 15 or 20 ellipses, the least negative return the agent can collect comes from failing early and taking the -10, as opposed to collecting -15 for correctly placing 15 ellipses.

I would try giving +1 for every ellipse placed within the circle and -10 (with termination) for one placed outside. Or try +1 for every ellipse placed and -10 for each one outside, but without terminating the episode. The agent will then minimize the negative reward while maximizing the positive reward signal, and it will also get more training signal from longer episodes. Reward shaping is always a matter of massaging different techniques; a quick sketch of both variants is below.
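
Here is a rough sketch of the two variants I have in mind (hypothetical helper, assuming your env exposes a boolean like placed_inside from its geometry check):

def reward_variant_a(placed_inside: bool):
    # +1 for every ellipse placed within the circle, -10 and terminate if it lands outside.
    return (1.0, False) if placed_inside else (-10.0, True)

def reward_variant_b(placed_inside: bool):
    # +1 inside, -10 outside, but never terminate early: longer episodes give the agent
    # more chances to collect positive reward and more learning signal per episode.
    return (1.0, False) if placed_inside else (-10.0, False)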

Let me know if this helps,

Tyler