Custom Env (PPO + Action Masking) GPU and CPU mismatch error

1. Severity of the issue: (select one)
High: Completely blocks me.

2. Environment:

  • Ray version: 2.41.0
  • Python version: 3.12.8
  • OS: Windows Server 2022
  • Other libs/tools (if relevant): Gymnasium

3. What happened vs. what you expected:

  • Expected: Everything is on CUDA
  • Actual: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

I recently tried to implement action masking for my discrete action space using the example Ray provides. I seem to be having device issues. I have access to 1 GPU and I want the model to run on it. My guess is that the batch values on the learner side are on the CPU; I have tried to move them to CUDA but can't seem to manage it.

Here is a bit more of the error:


File c:\Users\''\AppData\Local\anaconda3\envs\myenvrl\Lib\site-packages\ray\rllib\core\learner\learner.py:924, in Learner.compute_losses(self, fwd_out, batch)
    916         loss = module.compute_self_supervised_loss(
    917             learner=self,
    918             module_id=module_id,
   (...)
    921             fwd_out=module_fwd_out,
    922         )
    923     else:
--> 924         loss = self.compute_loss_for_module(
    925             module_id=module_id,
    926             config=self.config.get_config_for_module(module_id),
    927             batch=module_batch,
    928             fwd_out=module_fwd_out,
    929         )
    930     loss_per_module[module_id] = loss
    932 return loss_per_module

File c:\Users\''\AppData\Local\anaconda3\envs\myenvrl\Lib\site-packages\ray\rllib\algorithms\ppo\torch\ppo_torch_learner.py:75, in PPOTorchLearner.compute_loss_for_module(self, module_id, config, batch, fwd_out)
     66 # TODO (sven): We should ideally do this in the LearnerConnector (separation of
     67 #  concerns: Only do things on the EnvRunners that are required for computing
     68 #  actions, do NOT do anything on the EnvRunners that's only required for a
     69 #   training update).
     70 prev_action_dist = action_dist_class_exploration.from_logits(
     71     batch[Columns.ACTION_DIST_INPUTS]
     72 )
     74 logp_ratio = torch.exp(
---> 75     curr_action_dist.logp(batch[Columns.ACTIONS]) - batch[Columns.ACTION_LOGP]
     76 )
     78 # Only calculate kl loss if necessary (kl-coeff > 0.0).
     79 if config.use_kl_loss:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

I have checked with print statements in my custom model, and everything seems to be on CUDA, so I don't understand why this is happening.

Here is my custom model:

from typing import Any, Dict, Optional
import ray
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import gymnasium as gym
from ray.rllib.core.columns import Columns
from ray.rllib.core.learner.utils import make_target_network
from ray.rllib.core.rl_module.apis import (
    TargetNetworkAPI,
    ValueFunctionAPI,
    #TARGET_NETWORK_ACTION_DIST_INPUTS,
)
from ray.rllib.core.rl_module.torch import TorchRLModule
from ray.rllib.models.torch.misc import (
    normc_initializer,
    same_padding,
    valid_padding,
)
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import TensorType
torch, nn = try_import_torch()




class CustomCNN_simple(TorchRLModule, ValueFunctionAPI, TargetNetworkAPI):

    @override(TorchRLModule)
    def setup(self):
        """Use this method to create all the model components that you require.

        Feel free to access the following useful properties in this class:
        - `self.model_config`: The config dict for this RLModule class,
          which should contain flexible settings, for example: {"hiddens": [256, 256]}.
        - `self.observation|action_space`: The observation and action space that
          this RLModule is subject to. Note that the observation space might not be
          the exact space from your env, but that it might have already gone through
          preprocessing through a connector pipeline (for example, flattening,
          frame-stacking, mean/std-filtering, etc.).
        """
        #print(f"Model Observation Space: {self.observation_space}")
        #print(f"Model Action Space: {self.action_space}")
        #print(f"Model running on GPU: {torch.cuda.is_available()}")
        self.device = torch.device("cuda:0")
        #print(f"Using device: {self.device}")

        # Get the CNN stack config from our RLModule's `model_config` property:
        conv_filters = self.model_config.get("conv_filters")
        # Default CNN stack with 4 layers:
        if conv_filters is None:
            conv_filters = [
                [32, [5, 5], 2, "same"],   # Stride 2
                [64, [3, 3], 2, "same"],   # Stride 2
                [128, [3, 3], 1, "same"],
                [128, [3, 3], 1, "same"],
            ]

        # Build the CNN layers.
        layers = []

        if isinstance(self.observation_space, gym.spaces.Dict):
            # Extract the actual visual observation space.
            visual_obs_space = self.observation_space["observations"]
            width, height, in_depth = visual_obs_space.shape
        else:
            # Fallback for a regular Box observation space.
            width, height, in_depth = self.observation_space.shape
        in_size = [width, height]

        for filter_specs in conv_filters:
            if len(filter_specs) == 4:
                out_depth, kernel_size, strides, padding = filter_specs
            else:
                out_depth, kernel_size, strides = filter_specs
                padding = "same"

            # Pad like in tensorflow's SAME mode.
            if padding == "same":
                padding_size, out_size = same_padding(in_size, kernel_size, strides)
                layers.append(nn.ZeroPad2d(padding_size))
            else:
                out_size = valid_padding(in_size, kernel_size, strides)

            layer = nn.Conv2d(in_depth, out_depth, kernel_size, strides, bias=True)
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)
            layers.append(layer)

            layers.append(nn.BatchNorm2d(out_depth))
            layers.append(nn.LeakyReLU(0.1))

            in_size = out_size
            in_depth = out_depth

        # Move all layers to the target device once, after the loop.
        layers = [layer.to(self.device) for layer in layers]

        self._base_cnn_stack = nn.Sequential(*layers).to(self.device)


         

        n_actions = self.action_space.n  # For Discrete action spaces.

        # Use a 1x1 conv + global avg pool for the logits head (preserves spatial info).
        self._logits = nn.Sequential(
            nn.Conv2d(in_depth, n_actions, kernel_size=1),
            nn.AdaptiveAvgPool2d((1, 1)),  # [batch, n_actions, 1, 1]
            nn.Flatten(),                  # [batch, n_actions]
        ).to(self.device)

        # Value head (uses attention pooling in `compute_values`).
        self._values = nn.Sequential(
            nn.Linear(in_depth, 256),
            nn.LeakyReLU(0.1),
            nn.Linear(256, 1),
        ).to(self.device)

        # IMPORTANT: Move the entire module (and thus all submodules) to the
        # device once at the end.
        self.to(self.device)

    # Helper to ensure proper device handling:
    def move_to_device(self, batch):
        """Ensure all tensors in the batch are on the model's device."""
        model_device = next(self.parameters()).device

        def move_tensor_to_device(obj):
            if hasattr(obj, "to") and hasattr(obj, "device"):
                return obj.to(model_device)
            elif isinstance(obj, dict):
                return {k: move_tensor_to_device(v) for k, v in obj.items()}
            elif isinstance(obj, (list, tuple)):
                return type(obj)(move_tensor_to_device(item) for item in obj)
            else:
                return obj

        return move_tensor_to_device(batch)



    @override(TorchRLModule)
    def _forward(self, batch, **kwargs):
        # Compute the basic 1D feature tensor (input to the policy- and value-heads).
        #batch = {k: v.to(self.device) for k, v in batch.items()}
        batch = self.move_to_device(batch)
        #print(f"Batch device in forward: {batch[Columns.OBS].device}")
        self._base_cnn_stack = self._base_cnn_stack.to(self.device)
        #print(next(self._base_cnn_stack.parameters()).device)

        _, logits = self._compute_embeddings_and_logits(batch)
        logits = logits.to(self.device)

        # Debug: Check that all relevant batch items are on the correct device.
        #print(f"Logits device: {logits.device}")
        #if Columns.ACTIONS in batch:
        #    print(f"Actions device: {batch[Columns.ACTIONS].device}")
        #if Columns.ACTION_LOGP in batch:
        #    print(f"Action_logp device: {batch[Columns.ACTION_LOGP].device}")

        # Ensure actions and action_logp are on the correct device.
        if Columns.ACTIONS in batch:
            batch[Columns.ACTIONS] = batch[Columns.ACTIONS].to(self.device)
        if Columns.ACTION_LOGP in batch:
            batch[Columns.ACTION_LOGP] = batch[Columns.ACTION_LOGP].to(self.device)

        # Return the logits as ACTION_DIST_INPUTS (categorical distribution).
        return {
            Columns.ACTION_DIST_INPUTS: logits,
        }
 
    @override(TorchRLModule)
    def _forward_train(self, batch, **kwargs):
        print("\n=== DEBUG: _forward_train ===")

        # CRITICAL: Ensure ALL batch tensors are on the correct device.
        batch = self.move_to_device(batch)

        # Debug batch contents after the device move.
        print("Batch after move_to_device:")
        for key, value in batch.items():
            if hasattr(value, "device"):
                print(f"  {key}: device={value.device}")
            elif isinstance(value, dict):
                print(f"  {key} (dict):")
                for k, v in value.items():
                    if hasattr(v, "device"):
                        print(f"    {k}: device={v.device}")

        # Extra safety: Move actions to the correct device explicitly.
        if Columns.ACTIONS in batch:
            actions = batch[Columns.ACTIONS]
            print(f"Actions before fix: device={actions.device}, dtype={actions.dtype}")

            # Ensure actions are on the same device as the model.
            model_device = next(self.parameters()).device
            if actions.device != model_device:
                print(f"Moving actions from {actions.device} to {model_device}")
                batch[Columns.ACTIONS] = actions.to(model_device)

            # Ensure actions are the correct dtype (long/int for categorical).
            if batch[Columns.ACTIONS].dtype not in [torch.long, torch.int]:
                print(f"Converting actions from {batch[Columns.ACTIONS].dtype} to torch.long")
                batch[Columns.ACTIONS] = batch[Columns.ACTIONS].long()

            print(f"Actions after fix: device={batch[Columns.ACTIONS].device}, dtype={batch[Columns.ACTIONS].dtype}")

        # Also ensure ACTION_LOGP is on the correct device.
        if Columns.ACTION_LOGP in batch:
            action_logp = batch[Columns.ACTION_LOGP]
            model_device = next(self.parameters()).device
            if action_logp.device != model_device:
                print(f"Moving action_logp from {action_logp.device} to {model_device}")
                batch[Columns.ACTION_LOGP] = action_logp.to(model_device)

        # Get the model outputs.
        embeddings, logits = self._compute_embeddings_and_logits(batch)

        # Ensure the logits are on the correct device.
        model_device = next(self.parameters()).device
        if logits.device != model_device:
            print(f"Moving logits from {logits.device} to {model_device}")
            logits = logits.to(model_device)

        print("Final outputs:")
        print(f"  Logits: device={logits.device}, shape={logits.shape}")
        print(f"  Embeddings: device={embeddings.device}, shape={embeddings.shape}")

        if Columns.ACTIONS in batch:
            print(f"  Actions: device={batch[Columns.ACTIONS].device}, shape={batch[Columns.ACTIONS].shape}")

        print("=== END _forward_train DEBUG ===\n")

        return {
            Columns.ACTION_DIST_INPUTS: logits,
            Columns.EMBEDDINGS: embeddings,
        }
 
    @override(TargetNetworkAPI)
    def make_target_networks(self) -> None:
        # Create a target network for the CNN base stack.
        self._target_base_cnn_stack = make_target_network(self._base_cnn_stack).to(self.device)

        # Create a target network for the logits head.
        self._target_logits = make_target_network(self._logits).to(self.device)

    @override(TargetNetworkAPI)
    def get_target_network_pairs(self):
        pairs = [
            # Pair for the CNN base layers.
            (self._base_cnn_stack, self._target_base_cnn_stack),
            # Pair for the logits head.
            (self._logits, self._target_logits),
        ]
        return pairs

    @override(TargetNetworkAPI)
    def forward_target(self, batch, **kw):
        batch = self.move_to_device(batch)

        # Normalize and permute the observations for CNN input.
        # Handle the Dict observation space from action masking.
        if isinstance(batch[Columns.OBS], dict):
            # Extract just the visual observations, ignore the action_mask.
            obs = batch[Columns.OBS]["observations"].float().to(self.device) / 255.0
        else:
            # Fallback for non-dict observations.
            obs = batch[Columns.OBS].float().to(self.device) / 255.0

        obs = obs.permute(0, 3, 1, 2)  # Change shape for the CNN (NHWC -> NCHW).

        # Pass the observations through the target CNN stack to get embeddings.
        embeddings = self._target_base_cnn_stack(obs)

        # Compute the logits from the target logits head.
        logits = self._target_logits(embeddings)

        # Return the logits as the target-network action distribution inputs.
        return {"target_network_action_dist_inputs": torch.squeeze(logits, dim=[-1, -2])}


    @override(ValueFunctionAPI)
    def compute_values(
        self,
        batch: Dict[str, Any],
        embeddings: Optional[Any] = None,
    ) -> TensorType:
        batch = self.move_to_device(batch)
        if embeddings is None:
            # Handle the Dict observation space.
            if isinstance(batch[Columns.OBS], dict):
                obs = batch[Columns.OBS]["observations"].float().to(self.device) / 255.0
            else:
                obs = batch[Columns.OBS].float().to(self.device) / 255.0

            obs = obs.permute(0, 3, 1, 2)
            embeddings = self._base_cnn_stack(obs)

        print(f"Embeddings device in compute_values: {embeddings.device}")

        # Use channel-wise attention.
        b, c, h, w = embeddings.shape

        # Spatial feature extraction - pool height and width separately.
        # This preserves directional spatial information.
        h_pooled = torch.mean(embeddings, dim=3)  # [b, c, h]
        w_pooled = torch.mean(embeddings, dim=2)  # [b, c, w]

        # Calculate attention weights for height and width.
        h_weights = F.softmax(h_pooled.mean(dim=1, keepdim=True), dim=2)  # [b, 1, h]
        w_weights = F.softmax(w_pooled.mean(dim=1, keepdim=True), dim=2)  # [b, 1, w]

        # Apply attention to get weighted features.
        h_features = torch.bmm(h_pooled, h_weights.transpose(1, 2)).squeeze(2)  # [b, c]
        w_features = torch.bmm(w_pooled, w_weights.transpose(1, 2)).squeeze(2)  # [b, c]

        # Combine the features (element-wise addition).
        combined_features = h_features + w_features

        # Pass through the value head.
        print(f"Final values device: {next(self._values.parameters()).device}")
        return self._values(combined_features).squeeze(-1)
 
    def _compute_embeddings_and_logits(self, batch):
        if isinstance(batch[Columns.OBS], dict):
            # Extract just the visual observations.
            obs = batch[Columns.OBS]["observations"].float().to(self.device) / 255.0
        else:
            # Fallback for non-dict observations.
            obs = batch[Columns.OBS].float().to(self.device) / 255.0

        obs = obs.permute(0, 3, 1, 2)

        embeddings = self._base_cnn_stack(obs)
        logits = self._logits(embeddings)

        # Handle action masking with device verification.
        if isinstance(batch[Columns.OBS], dict) and "action_mask" in batch[Columns.OBS]:
            mask = batch[Columns.OBS]["action_mask"]

            # Ensure the mask is on the same device as the logits.
            mask = mask.to(logits.device)

            # Build the additive mask: log(0) = -inf for invalid actions,
            # clamped to the smallest finite float32.
            inf_mask = torch.clamp(torch.log(mask), min=torch.finfo(torch.float32).min)

            # Apply the mask.
            logits = logits + inf_mask

        return embeddings, logits

I currently have lots of print statements in there trying to debug the issue. Here is my training code:

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.env.single_agent_env_runner import SingleAgentEnvRunner
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from pprint import pprint
from ray.rllib.core.models.catalog import Catalog
from ray.tune.logger import JsonLoggerCallback, CSVLoggerCallback, TBXLoggerCallback
from ray.rllib.examples.envs.env_rendering_and_recording import EnvRenderCallback
import os
import pandas as pd
from discrete_CNN import CustomCNN_simple  # The custom RLModule shown above.
import ray.rllib.models.torch.torch_distributions as td
from custom_training_metric_callback import EpisodeEndMetricsCallback

#env = EllipsePackingRGB()
env = EllipsePackingSimple()
#env = EllipsePacking()
metrics_callback = EpisodeEndMetricsCallback()

# Define a function that returns your callback instance
def get_callbacks():
    return metrics_callback

os.environ["TORCH_DISTRIBUTED_BACKEND"] = "gloo"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"



class CustomCatalog(Catalog):
    def build(self):
        return {
            "model": CustomCNN_simple(
                observation_space=self.observation_space,
                action_space=self.action_space,
                model_config=self.model_config,
            )
        }

spec = RLModuleSpec(
    module_class=CustomCNN_simple,
    catalog_class=CustomCatalog,
    inference_only=False,
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config={
        "conv_filters": [
            [32, [5, 5], 2, "same"],   # Stride 2
            [64, [3, 3], 2, "same"],   # Stride 2
            [128, [3, 3], 1, "same"],
            [128, [3, 3], 1, "same"],
        ],
        "fcnet_hiddens": [512],
        "fcnet_activation": "relu",
    },
)

#rl_module = spec.build()


config = (
    PPOConfig()
    .environment(EllipsePackingSimple)
    #.callbacks(get_callbacks)
    .resources(num_gpus=1)
    .framework("torch")
    .debugging(log_level="DEBUG")
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .learners(
        num_learners=0,
        num_gpus_per_learner=1,
        local_gpu_idx=0,
    )
    .env_runners(
        env_runner_cls=SingleAgentEnvRunner,
        num_env_runners=0,
        sample_timeout_s=200,
        rollout_fragment_length=128,
        num_gpus_per_env_runner=0,
    )
    .rl_module(
        # We need to explicitly specify here the RLModule to use and the
        # catalog needed to build it (same spec as defined above).
        rl_module_spec=spec,
    )
    .training(
        #gamma=0.9, lr=0.01, kl_coeff=0.3, train_batch_size_per_learner=512
        # Retry with a lower learning rate:
        gamma=0.95, lr=0.001, kl_coeff=0.3, train_batch_size_per_learner=128,
    )
)
def patched_logp(self, value, **kwargs):
    """Patched logp function that ensures device consistency"""
    
    # Print initial device states for debugging
    print(f"\n=== LOGP DEVICE DEBUG ===")
    if hasattr(value, 'device'):
        print(f"Input value device: {value.device}")
    
    # Check distribution's device
    target_device = None
    if hasattr(self, '_dist') and hasattr(self._dist, 'logits'):
        target_device = self._dist.logits.device
        print(f"Distribution logits device: {target_device}")
    elif hasattr(self, 'logits'):
        target_device = self.logits.device
        print(f"Direct logits device: {target_device}")
    else:
        print("No logits found to determine target device")
    
    # Move value to correct device if needed
    if hasattr(value, 'to') and hasattr(value, 'device') and target_device is not None:
        if value.device != target_device:
            print(f"Device mismatch! Moving value from {value.device} to {target_device}")
            value = value.to(target_device)
            print(f"Value successfully moved to: {value.device}")
        else:
            print(f"Devices already match: {value.device}")
    
    print(f"=== END LOGP DEVICE DEBUG ===\n")
    
    # Call the original log_prob method
    try:
        result = self._dist.log_prob(value, **kwargs)
        print(f"logp result device: {result.device if hasattr(result, 'device') else 'no device attr'}")
        return result
    except Exception as e:
        print(f"Error in log_prob computation: {e}")
        print(f"Final value device: {value.device if hasattr(value, 'device') else 'no device attr'}")
        print(f"Distribution device: {target_device}")
        raise

# Apply the patch
td.TorchDistribution.logp = patched_logp

algo = config.build_algo()
print(algo.get_module())


# Specify the correct path using raw string
checkpoint_dir = r'#####'
# Ensure the checkpoint directory exists before saving
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)


# Update the metrics DataFrame definition to include fill percentage
metrics_df = pd.DataFrame(columns=[
    "iteration", "total_loss", "policy_loss", "vf_loss", "entropy", "mean_kl_loss",
    "cpu_util_percent", "gpu_util_percent0",
    "num_env_steps_sampled", "num_agent_steps_sampled", "num_module_steps_sampled",
    "time_since_restore", "sample", "final_fill_percentage"
])

        




for i in range(2):
    result = algo.train()
    result.pop("config")

    # Extract the available metrics from the result.
    iteration = result['training_iteration']

    # Fill percentage from episodes that terminated in this iteration.
    final_fill_percentage = 0.0  # Default value

    # Learner metrics.
    total_loss = result['learners']['default_policy'].get('total_loss', 0)
    policy_loss = result['learners']['default_policy'].get('policy_loss', 0)
    vf_loss = result['learners']['default_policy'].get('vf_loss', 0)
    entropy = result['learners']['default_policy'].get('entropy', 0)
    # Use 'mean_kl_loss' instead of 'kl'.
    mean_kl_loss = result['learners']['default_policy'].get('mean_kl_loss', 0)

    # Performance metrics.
    cpu_util_percent = result['perf'].get('cpu_util_percent', 0)
    gpu_util_percent0 = result['perf'].get('gpu_util_percent0', 0)

    # Environment runner metrics.
    num_env_steps_sampled = result['env_runners'].get('num_env_steps_sampled', 0)
    num_agent_steps_sampled = result['env_runners']['num_agent_steps_sampled'].get('default_agent', 0)
    num_module_steps_sampled = result['env_runners']['num_module_steps_sampled'].get('default_policy', 0)

    # Other metrics.
    time_since_restore = result.get('time_since_restore', 0)
    sample = result['env_runners'].get('sample', 0)

    # Create a new row with the extracted metrics.
    new_row = {
        "iteration": iteration,
        "total_loss": total_loss,
        "policy_loss": policy_loss,
        "vf_loss": vf_loss,
        "entropy": entropy,
        "mean_kl_loss": mean_kl_loss,
        "cpu_util_percent": cpu_util_percent,
        "gpu_util_percent0": gpu_util_percent0,
        "num_env_steps_sampled": num_env_steps_sampled,
        "num_agent_steps_sampled": num_agent_steps_sampled,
        "num_module_steps_sampled": num_module_steps_sampled,
        "time_since_restore": time_since_restore,
        "sample": sample,
        "final_fill_percentage": final_fill_percentage,
    }

    # Append the row to the metrics DataFrame.
    metrics_df = pd.concat([metrics_df, pd.DataFrame([new_row])], ignore_index=True)

    # Print progress.
    print(f"Completed iteration {i}, DataFrame now has {len(metrics_df)} rows")


custom_metrics_df = metrics_callback.get_episode_df()
custom_metrics_df.to_csv("final_results.csv")




Here is the outcome:




step: 38, reward: 1.0, terminated: False, truncated: False, info: {'covered_pixels': np.int64(9690), 'total_pixels': np.int64(10913), 'fill_percentage': np.float64(88.79318244295794), 'num_ellipses_placed': 38, 'steps_taken': 38, 'total_reward': 38.0, 'valid_actions_remaining': np.float32(12.0), 'total_actions': 362, 'action_mask_ratio': np.float32(0.03314917)}

=== LOGP DEVICE DEBUG ===
Input value device: cuda:0
Distribution logits device: cuda:0
Devices already match: cuda:0
=== END LOGP DEVICE DEBUG ===

logp result device: cuda:0
ACCEPTING: All corners within PTV circle
Masked action 216 at local coords (70, 10), global coords (80, 20)
Masked action 236 at local coords (75, 15), global coords (85, 25)
Masked 2 actions due to ellipse placement at (85, 25)
Remaining valid actions: 10.0
step: 39, reward: 1.0, terminated: False, truncated: False, info: {'covered_pixels': np.int64(9839), 'total_pixels': np.int64(10913), 'fill_percentage': np.float64(90.15852652799413), 'num_ellipses_placed': 39, 'steps_taken': 39, 'total_reward': 39.0, 'valid_actions_remaining': np.float32(10.0), 'total_actions': 362, 'action_mask_ratio': np.float32(0.027624309)}
Embeddings device in compute_values: cuda:0
Final values device: cuda:0

=== DEBUG: _forward_train ===
Batch after move_to_device:
  loss_mask: device=cuda:0
  terminateds: device=cuda:0
  obs (dict):
    action_mask: device=cuda:0
    observations: device=cuda:0
  actions: device=cuda:0
  rewards: device=cuda:0
  truncateds: device=cuda:0
  action_dist_inputs: device=cuda:0
  action_logp: device=cuda:0
  weights_seq_no: device=cuda:0
  advantages: device=cuda:0
  value_targets: device=cuda:0
Actions before fix: device=cuda:0, dtype=torch.int32
Actions after fix: device=cuda:0, dtype=torch.int32
Final outputs:
  Logits: device=cuda:0, shape=torch.Size([128, 362])
  Embeddings: device=cuda:0, shape=torch.Size([128, 128, 35, 35])
  Actions: device=cuda:0, shape=torch.Size([128])
=== END _forward_train DEBUG ===


=== LOGP DEVICE DEBUG ===
Input value device: cpu
Distribution logits device: cuda:0
Device mismatch! Moving value from cpu to cuda:0
Value successfully moved to: cuda:0
=== END LOGP DEVICE DEBUG ===

logp result device: cuda:0

Hello Manisha! I’ll try to help to the best of my ability:

The error "Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!" means that, during loss computation, at least one tensor (likely from the batch) is still on the CPU while others are on CUDA. Even if your model and most batch fields are on CUDA, a single CPU tensor (often actions or action_logp ) will trigger this error. Your debug output says Input value device: cpu (under "=== LOGP DEVICE DEBUG ===
") so it means you might be using your CPU… somewhere.

There are a few ways to make sure everything stays on your GPU. RLlib has a built-in utility for device placement, convert_to_torch_tensor (in ray.rllib.utils.torch_utils), which recursively moves all batch fields to a target device. If you move tensors manually, you can easily miss nested fields or certain batch keys.

You have to make sure that all tensors involved in the computation are on the same device. For individual tensors, use tensor.to(device); for example, tensor.to('cuda') moves a tensor to the GPU.
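A rough sketch of how that utility could replace your hand-rolled move_to_device helper inside a forward method (untested, and assuming the batch is a nested dict of tensors as in your debug output):

from ray.rllib.core.columns import Columns
from ray.rllib.utils.torch_utils import convert_to_torch_tensor

def _forward_train(self, batch, **kwargs):
    # Recursively move every field of the (possibly nested) batch dict,
    # including obs["observations"] and obs["action_mask"], to the model's device.
    device = next(self.parameters()).device
    batch = convert_to_torch_tensor(batch, device=device)

    embeddings, logits = self._compute_embeddings_and_logits(batch)
    return {
        Columns.ACTION_DIST_INPUTS: logits,
        Columns.EMBEDDINGS: embeddings,
    }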


Hi @manisha-waterston,

This sounds really frustrating. I hope you sort it out soon.

I think the issue may be that you are explicitly choosing a device and moving the model and tensors to it yourself. RLlib is set up to handle device placement implicitly for you. Additionally, the device can change in different parts of the process. For example, with the configuration you shared, the learner will use a GPU but the rollout workers will not. I think the device setting is incorrect for the rollout workers with your current configuration.

One quick way to test this is to set

num_gpus_per_learner=0.1 and num_gpus_per_env_runner=0.1

That way both processes are definitely on the GPU.
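In config terms, that suggestion would look roughly like this (a sketch based on the PPOConfig you posted, not a verified setup):

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment(EllipsePackingSimple)  # the custom env from the post
    .learners(
        num_learners=0,
        num_gpus_per_learner=0.1,     # give the learner a GPU share
    )
    .env_runners(
        num_env_runners=0,
        num_gpus_per_env_runner=0.1,  # give the env runners a GPU share too
    )
)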