RLlib Batch Postprocessing has steps from other trajectories

How severely does this issue affect your experience of using Ray?

  • High: It blocks me from completing my task.

In short, this is my error:

ValueError: ('Batches sent to postprocessing must only contain steps from a single trajectory.', SampleBatch(33: ['obs', 'new_obs', 'actions', 'rewards', 'terminateds', 'truncateds', 'infos', 'eps_id', 'unroll_id', 'agent_index', 't', 'vf_preds', 'action_dist_inputs', 'action_logp']))

And I want to know how to best get to the root issue of what I’m facing and how to fix it.

I went about debugging and tried this script and my code passed with flying colors:

def _test_environment(environment: MultiAgentDroneEnvironment):
    """Validate the structure of the readings gathered from ``environment``.

    Unpacks observations, rewards, dones, truncs and infos from
    ``get_readings(environment, 1000)`` and asserts, in this order, that:

    * every reward entry is a dict with hashable agent IDs and ``float`` values;
    * every done / trunc entry is a dict of ``bool`` values, and each one
      contains the ``"__all__"`` key;
    * every info entry is a dict of per-agent dicts mapping ``str`` keys to
      ``float`` values;
    * every observation is contained in ``environment.observation_space``.

    Raises:
        AssertionError: on the first reading that violates one of the checks.
    """

    def _check_flags(flag_dicts, singular):
        # Shared validation for the dones and truncs readings; only the noun
        # used in the messages and the progress-bar label differs.
        plural = f"{singular}s"
        assert isinstance(flag_dicts[0], dict), f"Initial {plural} must be a dict"
        for flags in tqdm(flag_dicts, desc=f"Checking {plural}"):
            assert isinstance(flags, dict), f"Each {singular} must be a dict"
            for agent_id, flag in flags.items():
                assert isinstance(flag, bool), f"Value for agent_id {agent_id} is not of type bool"
        # Every entry must carry the episode-wide "__all__" flag.
        assert all("__all__" in flags for flags in flag_dicts), f"All {plural} must have a __all__ key"

    observations, rewards, dones, truncs, infos = get_readings(environment, 1000)

    # Rewards: dict per step, hashable agent IDs, float values.
    assert isinstance(rewards[0], dict), "Initial rewards must be a dict"
    for step_rewards in tqdm(rewards, desc="Checking rewards"):
        assert isinstance(step_rewards, dict), "Each reward must be a dict"
        for agent_id, amount in step_rewards.items():
            assert isinstance(agent_id, Hashable), f"Agent ID {agent_id} is not hashable"
            assert isinstance(amount, float), f"Value for agent_id {agent_id} is not of type float"

    # Dones, then truncs: bool values plus the "__all__" key.
    _check_flags(dones, "done")
    _check_flags(truncs, "trunc")

    # Infos: dict per step, holding one str -> float dict per agent.
    assert isinstance(infos[0], dict), "Initial infos must be a dict"
    for step_infos in tqdm(infos, desc="Checking infos"):
        assert isinstance(step_infos, dict), "Each info must be a dict"
        for agent_id, agent_info in step_infos.items():
            assert isinstance(agent_info, dict), f"Value for agent_id {agent_id} is not of type dict"
            for name, measurement in agent_info.items():
                assert isinstance(name, str), "Key in info must be a string"
                assert isinstance(measurement, float), "Value in info must be a float"

    # Observations: each one must lie inside the declared observation space.
    space = environment.observation_space
    for observation in tqdm(observations, desc="Checking observations"):
        assert space.contains(observation), "Observation not in observation space"


if __name__ == "__main__":
    # Run the structural checks against a drone environment with rendering off.
    _test_environment(SimpleHoverDroneEnv(render_mode=None))

As such, I am 100% confident that my issue is related to neither observation shapes nor sizes, so I moved on to debugging the episode logic here:

num_episodes = 100  # Number of episodes to run the test

# Roll out `num_episodes` episodes with randomly sampled actions and report
# how many steps each one lasted.
for episode in range(num_episodes):
    observations, log = env.reset()
    dones = {agent_id: False for agent_id in observations}
    dones["__all__"] = False
    truncs = {"__all__": False}
    step_count = 0

    # An episode ends when it is terminated OR truncated for all agents.
    # Watching only dones["__all__"] would spin forever on an episode that
    # ends by truncation alone.
    while not (dones["__all__"] or truncs["__all__"]):
        # Sample random actions for the agents present in the latest observation.
        actions = {agent_id: env.action_space[agent_id].sample() for agent_id in observations}

        next_observations, rewards, dones, truncs, infos = env.step(actions)
        # Advance to the new observation so the next actions are sampled for
        # the agents that are reported now, not the ones present at reset.
        observations = next_observations
        step_count += 1

    print(f"Episode {episode + 1} finished after {step_count} steps for all agents.")

And my environment works perfectly as intended.

When I simulate my environment with randomly sampled actions, I get exactly the results I expect from the simulations (MuJoCo sims). I tried to debug the postprocessing batch by setting num_workers = 0 and placing breakpoints inside Ray’s Python package code, but I couldn’t make sense of much of what I saw. My project can be found here. Below is the training script I’ve tried running to get a minimal example working (I don’t expect any real training to happen; I just want to establish a pipeline and see some form of reward maximization).

from typing import Dict, Union, Optional

import numpy as np
from ray.rllib import BaseEnv, Policy, RolloutWorker
from ray.rllib.algorithms import PPOConfig
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.evaluation import Episode
from ray.rllib.evaluation.episode_v2 import EpisodeV2
from ray.rllib.utils.typing import PolicyID, AgentID

from environments.SimpleDroneEnvs import SimpleHoverDroneEnv
from ray.tune.registry import register_env
import ray

import logging

logging.basicConfig(level=logging.INFO)


class BatchIntegrityChecker(DefaultCallbacks):
    """Callback that checks a postprocessed batch holds steps of one episode only.

    The check looks at the ``eps_id`` column of the postprocessed batch: more
    than one distinct value means the batch mixes several trajectories.
    """

    def on_postprocess_trajectory(
            self,
            *,
            worker,
            episode: Episode,
            agent_id: Optional[str] = None,
            policy_id: Optional[str] = None,
            policies: Dict[str, Policy],
            postprocessed_batch,
            original_batches,
            **kwargs,
    ) -> None:
        """Forward RLlib's trajectory-postprocessing hook to the integrity check.

        ``on_postprocess_trajectory`` is the hook name declared by RLlib's
        ``DefaultCallbacks`` (see the RLlib callbacks API reference); a method
        called ``on_postprocess_traj`` alone is not an override of it.
        """
        self.on_postprocess_traj(
            worker=worker,
            base_env=kwargs.pop("base_env", None),
            policies=policies,
            episode=episode,
            agent_id=agent_id,
            policy_id=policy_id,
            postprocessed_batch=postprocessed_batch,
            original_batches=original_batches,
            **kwargs,
        )

    def on_postprocess_traj(
            self,
            *,
            worker,
            base_env=None,
            policies: Dict[str, Policy],
            episode: Episode,
            agent_id: Optional[str] = None,
            policy_id: Optional[str] = None,
            postprocessed_batch,
            original_batches,
            **kwargs,
    ) -> None:
        """Verify that ``postprocessed_batch`` contains a single episode ID.

        Args:
            worker: Unused; accepted for hook compatibility.
            base_env: Unused. Optional because the trajectory hook is not
                given an environment.
            policies: Unused; accepted for hook compatibility.
            episode: Unused; accepted for hook compatibility.
            agent_id: Unused; accepted for hook compatibility.
            policy_id: Unused; accepted for hook compatibility.
            postprocessed_batch: Batch to inspect; must expose ``eps_id``.
            original_batches: Unused; accepted for hook compatibility.
            **kwargs: Ignored.

        Raises:
            ValueError: If the batch contains more than one distinct ``eps_id``.
        """
        # Extract episode IDs from the batch
        episode_ids = postprocessed_batch['eps_id']
        unique_episode_ids = np.unique(episode_ids)

        if len(unique_episode_ids) == 0:
            # Nothing to verify; also avoids indexing an empty array below.
            logging.info("Batch verification skipped: batch contains no steps.")
            return

        if len(unique_episode_ids) == 1:
            logging.info(
                "Batch verification passed: All steps belong to episode ID %s.",
                unique_episode_ids[0],
            )
            return

        # Log detailed information about the problematic batch
        logging.error(
            "Batch contains steps from multiple trajectories! Unique episode IDs: %s",
            unique_episode_ids,
        )
        logging.error("Batch size: %s", postprocessed_batch.count)
        logging.error("Actions in the batch: %s", postprocessed_batch['actions'])
        logging.error("Rewards in the batch: %s", postprocessed_batch['rewards'])
        # The batch carries 'terminateds' and 'truncateds' columns rather than
        # a 'dones' column, so reading 'dones' here would raise KeyError and
        # mask the real diagnostic. Use .get() so a missing column is logged
        # as None instead of crashing.
        logging.error("Terminateds in the batch: %s", postprocessed_batch.get('terminateds'))
        logging.error("Truncateds in the batch: %s", postprocessed_batch.get('truncateds'))
        logging.error("Episode IDs in the batch: %s", episode_ids)

        # Raise an error to halt training and investigate
        raise ValueError("Batch contains steps from multiple trajectories!")


def env_creator(env_config):
    """Build a ``SimpleHoverDroneEnv`` from ``env_config``.

    Each constructor argument is read from ``env_config`` with ``.get``; keys
    that are absent fall back to the defaults listed below. Keys in
    ``env_config`` that are not listed here are ignored.
    """
    defaults = {
        "n_agents": 5,
        "spacing": 3.0,
        "map_bounds": np.array([[-10, -10, -0.01], [10, 10, 10]]),
        "spawn_box": np.array([[-1, -1, 0.1], [5, 1, 1]]),
        "dt": 0.01,
        "render_mode": None,
        "fps": 60,
        "sim_rate": 1,
        "update_distances": True,
    }
    settings = {name: env_config.get(name, fallback) for name, fallback in defaults.items()}
    return SimpleHoverDroneEnv(**settings)


# Make the environment available to RLlib under a string name.
register_env("SimpleHoverDroneEnv", env_creator)

# Initialize Ray
ray.init(ignore_reinit_error=True)

# Configure the PPO algorithm
config = (
    PPOConfig()
    .environment(env="SimpleHoverDroneEnv")
    .framework("torch")
    .rollouts(num_rollout_workers=0)
    # Register the callback class; defining it alone does not activate it.
    .callbacks(BatchIntegrityChecker)
    .training(gamma=0.99, lr=5e-4, train_batch_size=4000)
    .resources(num_gpus=0)
    # `evaluation_num_episodes` is deprecated; the deprecation warning names
    # `evaluation_duration` + `evaluation_duration_unit` as the replacement.
    .evaluation(
        evaluation_interval=10,
        evaluation_duration=10,
        evaluation_duration_unit="episodes",
    )
    .reporting(min_time_s_per_iteration=1)
)

# Build the PPO algorithm
ppo_agent = config.build()

try:
    # Train the PPO agent
    for i in range(100):  # Number of training iterations
        result = ppo_agent.train()
        print(f"Iteration: {i}, episode_reward_mean: {result['episode_reward_mean']}")

        if i % 10 == 0:  # Save the model every 10 iterations
            checkpoint = ppo_agent.save()
            print(f"Checkpoint saved at: {checkpoint}")
finally:
    # Release the algorithm's resources and the local Ray instance even when
    # training raises.
    ppo_agent.stop()
    ray.shutdown()

When I build a similar environment but make it single-agent, the error doesn’t happen and I get workable results, so the only conclusion I’m able to reach is that either I am not using the multi-agent environment system properly or there is a bug, and I’m unable to figure out which. Below is the full error log in case it helps:

/home/engineering-geek/miniforge3/envs/RL-UAV/bin/python /home/engineering-geek/PycharmProjects/RL-UAV/a.py 
/home/engineering-geek/miniforge3/envs/RL-UAV/lib/python3.11/subprocess.py:1883: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = _fork_exec(
/home/engineering-geek/miniforge3/envs/RL-UAV/lib/python3.11/subprocess.py:1883: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = _fork_exec(
2024-03-23 12:06:03,462	INFO worker.py:1715 -- Started a local Ray instance. View the dashboard at http://127.0.0.1:8265 
2024-03-23 12:06:04,113	WARNING deprecation.py:50 -- DeprecationWarning: `AlgorithmConfig.evaluation(evaluation_num_episodes=..)` has been deprecated. Use `AlgorithmConfig.evaluation(evaluation_duration=.., evaluation_duration_unit='episodes')` instead. This will raise an error in the future!
/home/engineering-geek/miniforge3/envs/RL-UAV/lib/python3.11/site-packages/ray/rllib/algorithms/algorithm.py:483: RayDeprecationWarning: This API is deprecated and may be removed in future Ray releases. You could suppress this warning by setting env variable PYTHONWARNINGS="ignore::DeprecationWarning"
`UnifiedLogger` will be removed in Ray 2.7.
  return UnifiedLogger(config, logdir, loggers=None)
/home/engineering-geek/miniforge3/envs/RL-UAV/lib/python3.11/site-packages/ray/tune/logger/unified.py:53: RayDeprecationWarning: This API is deprecated and may be removed in future Ray releases. You could suppress this warning by setting env variable PYTHONWARNINGS="ignore::DeprecationWarning"
The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
/home/engineering-geek/miniforge3/envs/RL-UAV/lib/python3.11/site-packages/ray/tune/logger/unified.py:53: RayDeprecationWarning: This API is deprecated and may be removed in future Ray releases. You could suppress this warning by setting env variable PYTHONWARNINGS="ignore::DeprecationWarning"
The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
/home/engineering-geek/miniforge3/envs/RL-UAV/lib/python3.11/site-packages/ray/tune/logger/unified.py:53: RayDeprecationWarning: This API is deprecated and may be removed in future Ray releases. You could suppress this warning by setting env variable PYTHONWARNINGS="ignore::DeprecationWarning"
The `TBXLogger interface is deprecated in favor of the `ray.tune.tensorboardx.TBXLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
Traceback (most recent call last):
  File "/home/engineering-geek/PycharmProjects/RL-UAV/a.py", line 87, in <module>
    result = ppo_agent.train()
             ^^^^^^^^^^^^^^^^^
  File "/home/engineering-geek/miniforge3/envs/RL-UAV/lib/python3.11/site-packages/ray/tune/trainable/trainable.py", line 342, in train
    raise skipped from exception_cause(skipped)
  File "/home/engineering-geek/miniforge3/envs/RL-UAV/lib/python3.11/site-packages/ray/tune/trainable/trainable.py", line 339, in train
    result = self.step()
             ^^^^^^^^^^^
  File "/home/engineering-geek/miniforge3/envs/RL-UAV/lib/python3.11/site-packages/ray/rllib/algorithms/algorithm.py", line 852, in step
    results, train_iter_ctx = self._run_one_training_iteration()
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/engineering-geek/miniforge3/envs/RL-UAV/lib/python3.11/site-packages/ray/rllib/algorithms/algorithm.py", line 3042, in _run_one_training_iteration
    results = self.training_step()
              ^^^^^^^^^^^^^^^^^^^^
  File "/home/engineering-geek/miniforge3/envs/RL-UAV/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/ppo.py", line 407, in training_step
    train_batch = synchronous_parallel_sample(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/engineering-geek/miniforge3/envs/RL-UAV/lib/python3.11/site-packages/ray/rllib/execution/rollout_ops.py", line 80, in synchronous_parallel_sample
    sample_batches = [worker_set.local_worker().sample()]
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/engineering-geek/miniforge3/envs/RL-UAV/lib/python3.11/site-packages/ray/rllib/evaluation/rollout_worker.py", line 694, in sample
    batches = [self.input_reader.next()]
               ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/engineering-geek/miniforge3/envs/RL-UAV/lib/python3.11/site-packages/ray/rllib/evaluation/sampler.py", line 91, in next
    batches = [self.get_data()]
               ^^^^^^^^^^^^^^^
  File "/home/engineering-geek/miniforge3/envs/RL-UAV/lib/python3.11/site-packages/ray/rllib/evaluation/sampler.py", line 276, in get_data
    item = next(self._env_runner)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/engineering-geek/miniforge3/envs/RL-UAV/lib/python3.11/site-packages/ray/rllib/evaluation/env_runner_v2.py", line 344, in run
    outputs = self.step()
              ^^^^^^^^^^^
  File "/home/engineering-geek/miniforge3/envs/RL-UAV/lib/python3.11/site-packages/ray/rllib/evaluation/env_runner_v2.py", line 370, in step
    active_envs, to_eval, outputs = self._process_observations(
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/engineering-geek/miniforge3/envs/RL-UAV/lib/python3.11/site-packages/ray/rllib/evaluation/env_runner_v2.py", line 688, in _process_observations
    self._handle_done_episode(
  File "/home/engineering-geek/miniforge3/envs/RL-UAV/lib/python3.11/site-packages/ray/rllib/evaluation/env_runner_v2.py", line 838, in _handle_done_episode
    self._build_done_episode(env_id, is_done, outputs)
  File "/home/engineering-geek/miniforge3/envs/RL-UAV/lib/python3.11/site-packages/ray/rllib/evaluation/env_runner_v2.py", line 727, in _build_done_episode
    episode.postprocess_episode(
  File "/home/engineering-geek/miniforge3/envs/RL-UAV/lib/python3.11/site-packages/ray/rllib/evaluation/episode_v2.py", line 303, in postprocess_episode
    raise ValueError(
ValueError: ('Batches sent to postprocessing must only contain steps from a single trajectory.', SampleBatch(33: ['obs', 'new_obs', 'actions', 'rewards', 'terminateds', 'truncateds', 'infos', 'eps_id', 'unroll_id', 'agent_index', 't', 'vf_preds', 'action_dist_inputs', 'action_logp']))

Process finished with exit code 1