Reproducing ML-Agents Results with RLlib?

Hi all,

I have been working on an environment suite for almost 3 years, and now it is time to publish my baby (NeurIPS?). I want to use RLlib as the training framework for my baseline; however, I cannot reproduce the results I get when training with ML-Agents. By contrast, when I trained the ML-Agents example 3DBall, I was able to reproduce its results with RLlib. So I guess this must be a training configuration problem? I matched everything to the 3DBall setup, but still no luck. Any help or ideas on where else to look would be much appreciated.

3DBall Results with ML-Agents:

In this case we have 12 agents

Result trained using ML-Agents (trained with a build, not in the editor):

Config:

behaviors:
  3DBall:
    trainer_type: ppo
    hyperparameters:
      batch_size: 64
      buffer_size: 12000
      learning_rate: 0.0003
      beta: 0.001
      epsilon: 0.2
      lambd: 0.99
      num_epoch: 3
      learning_rate_schedule: linear
    network_settings:
      normalize: true
      hidden_units: 128
      num_layers: 2
      vis_encode_type: simple
    reward_signals:
      extrinsic:
        gamma: 0.99
        strength: 1.0
    keep_checkpoints: 5
    max_steps: 500000
    time_horizon: 1000
    summary_freq: 12000


→ Here we can see a mean reward of roughly 100 across the 12 agents.

3DBall Results Reproduced in RLlib:

RLlib training script:

import os
import ray
from ray import air, tune
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.env.wrappers.unity3d_env import Unity3DEnv
from dotenv import load_dotenv
import wandb
from ray.air.integrations.wandb import WandbLoggerCallback

load_dotenv()
WANDB_API_KEY = os.getenv("WANDB_API_KEY")

if __name__ == "__main__":
    wandb.login()
    ray.init()

    tune.register_env(
        "unity3d",
        lambda c: Unity3DEnv(
            file_name="src/hivex/dev_environments/3DBall/UnityEnvironment.exe",
            no_graphics=True,
            episode_horizon=1000,
        ),
    )

    # Get policies (different agent types; "behaviors" in MLAgents) and
    # the mappings from individual agents to Policies.
    policies, policy_mapping_fn = Unity3DEnv.get_policy_configs_for_game("3DBall")

    config = (
        PPOConfig()
        .environment(
            "unity3d",
            env_config={
                "file_name": "src/hivex/dev_environments/3DBall/UnityEnvironment.exe",
                "episode_horizon": 1000,
            },
        )
        .framework("torch")
        # For running in editor, force to use just one Worker (we only have
        # one Unity running)!
        .rollouts(
            num_rollout_workers=1,
            rollout_fragment_length=200,
        )
        .training(
            lr=0.0003,
            lambda_=0.99,
            gamma=0.99,
            sgd_minibatch_size=64,
            train_batch_size=12000,
            num_sgd_iter=3,
            clip_param=0.2,
            model={"fcnet_hiddens": [128, 128]},
        )
        .multi_agent(policies=policies, policy_mapping_fn=policy_mapping_fn)
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
    )

    stop = {
        "timesteps_total": 500000,  # per agent
    }

    wandb_callback = WandbLoggerCallback(
        project="3DBall_test",
        api_key=WANDB_API_KEY,
        upload_checkpoints=True,
        save_checkpoints=True,
        group=f"Test",
        log_config=True,
    )

    # Run the experiment.
    results = tune.Tuner(
        "PPO",
        param_space=config.to_dict(),
        run_config=air.RunConfig(
            stop=stop,
            verbose=1,
            checkpoint_config=air.CheckpointConfig(
                checkpoint_frequency=5,
                checkpoint_at_end=True,
            ),
            callbacks=[wandb_callback],
        ),
    ).fit()

    ray.shutdown()
    wandb.finish()

→ Here we can see a cumulative reward of roughly 1200 for the 12 agents (about 100 per agent), which I think makes sense. I would consider this “reproduced”.

Custom Environment trained with ML-Agents:

In this case we have 8 agents

Result trained using ML-Agents (trained with a build, not in the editor):

Config:

behaviors:
  WindFarmControl:
    trainer_type: ppo
    hyperparameters:
      batch_size: 256
      buffer_size: 8000
      learning_rate: 0.0003
      beta: 0.005
      epsilon: 0.2
      lambd: 0.95
      num_epoch: 3
      learning_rate_schedule: linear
    network_settings:
      normalize: true
      hidden_units: 64
      num_layers: 2
    reward_signals:
      extrinsic:
        gamma: 0.9
        strength: 1.0
    keep_checkpoints: 5
    max_steps: 4000000 # 1 arena x 8 turbines x 5000 steps = 40000; x 100
    time_horizon: 1000
    summary_freq: 8000 # 1 arena x 8 turbines x 5000 steps = 40000
    threaded: true


→ This makes sense, as we have 5000 steps per episode and the maximum reward per step is 1.0.

Custom Environment Results in RLlib:

RLlib training script:

import os

import ray
from ray import air, tune
from ray.rllib.algorithms.ppo import PPOConfig
from dotenv import load_dotenv
import wandb
from ray.air.integrations.wandb import WandbLoggerCallback
from hivex.training.baseline.unity3d_env import Unity3DEnv

load_dotenv()
WANDB_API_KEY = os.getenv("WANDB_API_KEY")

if __name__ == "__main__":
    wandb.login()
    ray.init()

    tune.register_env(
        "unity3d",
        lambda c: Unity3DEnv(
            file_name="src/hivex/dev_environments/Hivex_WindFarmControl_win_1/Hivex_WindfarmControl.exe",
            no_graphics=True,
            episode_horizon=1000,
        ),
    )

    # Get policies (different agent types; "behaviors" in MLAgents) and
    # the mappings from individual agents to Policies.
    policies, policy_mapping_fn = Unity3DEnv.get_policy_configs_for_game(
        "WindFarmControl"
    )

    config = (
        PPOConfig()
        .environment(
            "unity3d",
            env_config={
                "file_name": "src/hivex/dev_environments/Hivex_WindFarmControl_win_1/Hivex_WindfarmControl.exe",
                "episode_horizon": 1000,
            },
        )
        .framework("torch")
        # For running in editor, force to use just one Worker (we only have
        # one Unity running)!
        .rollouts(
            num_rollout_workers=1,
            rollout_fragment_length=200,
        )
        .training(
            lr=0.0003,
            lambda_=0.95,
            gamma=0.99,
            sgd_minibatch_size=256,
            # this dictates how this is logged to wandb
            # steps on X axis in wandb will be train_batch_size * agent count
            train_batch_size=8000,
            num_sgd_iter=3,
            clip_param=0.2,
            model={"fcnet_hiddens": [64, 64]},
        )
        .multi_agent(policies=policies, policy_mapping_fn=policy_mapping_fn)
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
    )

    stop = {
        "timesteps_total": 500000,
    }

    wandb_callback = WandbLoggerCallback(
        project="WFC_test_1",
        api_key=WANDB_API_KEY,
        upload_checkpoints=True,
        save_checkpoints=True,
        group=f"Test",
        log_config=True,
    )

    # Run the experiment.
    results = tune.Tuner(
        "PPO",
        param_space=config.to_dict(),
        run_config=air.RunConfig(
            stop=stop,
            verbose=1,
            checkpoint_config=air.CheckpointConfig(
                checkpoint_frequency=5,
                checkpoint_at_end=True,
            ),
            callbacks=[wandb_callback],
        ),
    ).fit()

    ray.shutdown()
    wandb.finish()
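The pairing between the ML-Agents YAML and the `PPOConfig.training()` arguments in the scripts above is done by hand. A minimal sketch of doing it mechanically, assuming the same key correspondence the scripts use (`batch_size` → `sgd_minibatch_size`, `buffer_size` → `train_batch_size`, `lambd` → `lambda_`, and so on) plus `beta` → `entropy_coeff`. The helper `mlagents_to_rllib` and the mapping table are hypothetical, not part of RLlib or ML-Agents, and the sketch does not import Ray.

```python
from typing import Any, Dict, List, Tuple

# Assumed correspondence between ML-Agents PPO `hyperparameters` keys and
# RLlib (Ray 2.x, old API stack) PPOConfig.training() keyword arguments.
# This follows the pairing used in this thread; it is not an official mapping.
MLAGENTS_TO_RLLIB = {
    "batch_size": "sgd_minibatch_size",
    "buffer_size": "train_batch_size",
    "learning_rate": "lr",
    "beta": "entropy_coeff",
    "epsilon": "clip_param",
    "lambd": "lambda_",
    "num_epoch": "num_sgd_iter",
}


def mlagents_to_rllib(behavior: Dict[str, Any]) -> Tuple[Dict[str, Any], List[str]]:
    """Translate one ML-Agents behavior block into PPOConfig.training() kwargs.

    Also returns the ML-Agents settings that were NOT translated, so they can
    be checked by hand.
    """
    hp = behavior["hyperparameters"]
    net = behavior.get("network_settings", {})
    kwargs: Dict[str, Any] = {}
    unmapped: List[str] = []

    for key, value in hp.items():
        if key in MLAGENTS_TO_RLLIB:
            kwargs[MLAGENTS_TO_RLLIB[key]] = value
        else:
            unmapped.append(f"hyperparameters.{key}")

    # In ML-Agents the discount factor lives under the extrinsic reward signal.
    kwargs["gamma"] = behavior["reward_signals"]["extrinsic"]["gamma"]

    # num_layers x hidden_units -> RLlib's fcnet_hiddens list
    # (128 units x 2 layers are the ML-Agents defaults).
    kwargs["model"] = {
        "fcnet_hiddens": [net.get("hidden_units", 128)] * net.get("num_layers", 2)
    }

    # Observation normalization is not a PPOConfig.training() argument.
    if net.get("normalize", False):
        unmapped.append("network_settings.normalize")
    return kwargs, unmapped


# The WindFarmControl behavior block from the ML-Agents config above.
wfc = {
    "hyperparameters": {
        "batch_size": 256, "buffer_size": 8000, "learning_rate": 0.0003,
        "beta": 0.005, "epsilon": 0.2, "lambd": 0.95, "num_epoch": 3,
        "learning_rate_schedule": "linear",
    },
    "network_settings": {"normalize": True, "hidden_units": 64, "num_layers": 2},
    "reward_signals": {"extrinsic": {"gamma": 0.9, "strength": 1.0}},
}
kwargs, unmapped = mlagents_to_rllib(wfc)
print(kwargs)
print(unmapped)
```

On the WindFarmControl block this yields `gamma=0.9` and `entropy_coeff=0.005` and reports `learning_rate_schedule` and `normalize` as unmapped; for comparison, the RLlib script above passes `gamma=0.99` and leaves `entropy_coeff` at its default. On Ray 2.x the closest counterparts for the unmapped settings are, as far as I know, `lr_schedule` in `.training()` and `observation_filter="MeanStdFilter"` in `.rollouts()`.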

Additionally, I added the WindFarmControl observation and action spaces to unity3d_env.py, like so:

    @staticmethod
    def get_policy_configs_for_game(
        game_name: str,
    ) -> Tuple[dict, Callable[[AgentID], PolicyID]]:
        ...
        # The RLlib server must know about the Spaces that the Client will be
        # using inside Unity3D, up-front.
        obs_spaces = {
            "WindFarmControl": Box(float("-inf"), float("inf"), (6,)),
            # 3DBall.
            "3DBall": Box(float("-inf"), float("inf"), (8,)),
            ...
        }
        action_spaces = {
            "WindFarmControl": MultiDiscrete([3]),
            # 3DBall.
            "3DBall": Box(-1.0, 1.0, (2,), dtype=np.float32),
            ...
        }
        ...

Here is the problem: I would expect a cumulative reward of roughly 32k (8 agents × ~4k), but instead we are at 14k. Dividing that by the number of agents (8) gives a mean reward of 1750 per agent, which is less than half of the ML-Agents result of around 4k mean reward.
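The comparison above is a per-agent normalization. A tiny sketch of the arithmetic, assuming (as the numbers in this thread suggest) that the RLlib curve shows the episode reward summed over all agents, while ML-Agents' Cumulative Reward is a mean per agent. On Ray 2.x's old API stack the result dict also contains `policy_reward_mean`, which as far as I know averages per-agent returns for each policy and could be compared with ML-Agents directly. Both helper functions are hypothetical.

```python
def per_agent_mean(cumulative_reward: float, num_agents: int) -> float:
    """Convert an episode reward summed over all agents into a per-agent mean."""
    return cumulative_reward / num_agents


def expected_cumulative(per_agent_reward: float, num_agents: int) -> float:
    """Sum over all agents that corresponds to a given per-agent reward."""
    return per_agent_reward * num_agents


# WindFarmControl: 8 agents, ML-Agents shows ~4k per agent, RLlib ~14k summed.
print(expected_cumulative(4_000, 8))  # 32000
print(per_agent_mean(14_000, 8))      # 1750.0
# 3DBall sanity check: 12 agents, ~1200 summed in RLlib.
print(per_agent_mean(1_200, 12))      # 100.0
```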

Just for fun, I ran 5M instead of 500k steps with RLlib to see where things converge:

This is my unity3d_env.py, in case it is relevant:

from gymnasium.spaces import Box, MultiDiscrete, Tuple as TupleSpace
import logging
import numpy as np
import random
import time
from typing import Callable, Optional, Tuple

from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils.typing import MultiAgentDict, PolicyID, AgentID

logger = logging.getLogger(__name__)


@PublicAPI
class Unity3DEnv(MultiAgentEnv):
    """A MultiAgentEnv representing a single Unity3D game instance.

    For an example on how to use this Env with a running Unity3D editor
    or with a compiled game, see:
    `rllib/examples/unity3d_env_local.py`
    For an example on how to use it inside a Unity game client, which
    connects to an RLlib Policy server, see:
    `rllib/examples/serving/unity3d_[client|server].py`

    Supports all Unity3D (MLAgents) examples, multi- or single-agent and
    gets converted automatically into an ExternalMultiAgentEnv, when used
    inside an RLlib PolicyClient for cloud/distributed training of Unity games.
    """

    # Default base port when connecting directly to the Editor
    _BASE_PORT_EDITOR = 5004
    # Default base port when connecting to a compiled environment
    _BASE_PORT_ENVIRONMENT = 5005
    # The worker_id for each environment instance
    _WORKER_ID = 0

    def __init__(
        self,
        file_name: str = None,
        port: Optional[int] = None,
        seed: int = 0,
        no_graphics: bool = False,
        timeout_wait: int = 300,
        episode_horizon: int = 1000,
    ):
        """Initializes a Unity3DEnv object.

        Args:
            file_name (Optional[str]): Name of the Unity game binary.
                If None, will assume a locally running Unity3D editor
                to be used, instead.
            port (Optional[int]): Port number to connect to Unity environment.
            seed: A random seed value to use for the Unity3D game.
            no_graphics: Whether to run the Unity3D simulator in
                no-graphics mode. Default: False.
            timeout_wait: Time (in seconds) to wait for connection from
                the Unity3D instance.
            episode_horizon: A hard horizon to abide to. After at most
                this many steps (per-agent episode `step()` calls), the
                Unity3D game is reset and will start again (finishing the
                multi-agent episode that the game represents).
                Note: The game itself may contain its own episode length
                limits, which are always obeyed (on top of this value here).
        """
        # Skip env checking as the nature of the agent IDs depends on the game
        # running in the connected Unity editor.
        self._skip_env_checking = True

        super().__init__()

        if file_name is None:
            print(
                "No game binary provided, will use a running Unity editor "
                "instead.\nMake sure you are pressing the Play (|>) button in "
                "your editor to start."
            )

        import mlagents_envs
        from mlagents_envs.environment import UnityEnvironment

        # Try connecting to the Unity3D game instance. If a port is blocked,
        # retry (with the next worker_id, i.e. the next port, for a compiled game).
        port_ = None
        while True:
            # Sleep for random time to allow for concurrent startup of many
            # environments (num_workers >> 1). Otherwise, would lead to port
            # conflicts sometimes.
            if port_ is not None:
                time.sleep(random.randint(1, 10))
            port_ = port or (
                self._BASE_PORT_ENVIRONMENT if file_name else self._BASE_PORT_EDITOR
            )
            # cache the worker_id and
            # increase it for the next environment
            worker_id_ = Unity3DEnv._WORKER_ID if file_name else 0
            Unity3DEnv._WORKER_ID += 1
            try:
                self.unity_env = UnityEnvironment(
                    file_name=file_name,
                    worker_id=worker_id_,
                    base_port=port_,
                    seed=seed,
                    no_graphics=no_graphics,
                    timeout_wait=timeout_wait,
                )
                print("Created UnityEnvironment for port {}".format(port_ + worker_id_))
            except mlagents_envs.exception.UnityWorkerInUseException:
                pass
            else:
                break

        # ML-Agents API version.
        self.api_version = self.unity_env.API_VERSION.split(".")
        self.api_version = [int(s) for s in self.api_version]

        # Reset entire env every this number of step calls.
        self.episode_horizon = episode_horizon
        # Keep track of how many times we have called `step` so far.
        self.episode_timesteps = 0

    def step(
        self, action_dict: MultiAgentDict
    ) -> Tuple[
        MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict
    ]:
        """Performs one multi-agent step through the game.

        Args:
            action_dict: Multi-agent action dict with:
                keys=agent identifier consisting of
                [MLagents behavior name, e.g. "Goalie?team=1"] + "_" +
                [Agent index, a unique MLAgent-assigned index per single agent]

        Returns:
            tuple:
                - obs: Multi-agent observation dict.
                    Only those observations for which to get new actions are
                    returned.
                - rewards: Rewards dict matching `obs`.
                - terminateds: Terminated dict with only an __all__ entry in
                    it (always False; see `_get_step_results`).
                - truncateds: Truncated dict. Once `episode_horizon` is
                    exceeded, __all__ and every acting agent are set to True.
                - infos: An (empty) info dict.
        """
        from mlagents_envs.base_env import ActionTuple

        # Set only the required actions (from the DecisionSteps) in Unity3D.
        all_agents = []
        for behavior_name in self.unity_env.behavior_specs:
            # New ML-Agents API: Set all agents actions at the same time
            # via an ActionTuple. Since API v1.4.0.
            if self.api_version[0] > 1 or (
                self.api_version[0] == 1 and self.api_version[1] >= 4
            ):
                actions = []
                for agent_id in self.unity_env.get_steps(behavior_name)[0].agent_id:
                    key = behavior_name + "_{}".format(agent_id)
                    all_agents.append(key)
                    actions.append(action_dict[key])
                if actions:
                    if actions[0].dtype == np.float32:
                        action_tuple = ActionTuple(continuous=np.array(actions))
                    else:
                        action_tuple = ActionTuple(discrete=np.array(actions))
                    self.unity_env.set_actions(behavior_name, action_tuple)
            # Old behavior: Do not use an ActionTuple and set each agent's
            # action individually.
            else:
                for agent_id in self.unity_env.get_steps(behavior_name)[
                    0
                ].agent_id_to_index.keys():
                    key = behavior_name + "_{}".format(agent_id)
                    all_agents.append(key)
                    self.unity_env.set_action_for_agent(
                        behavior_name, agent_id, action_dict[key]
                    )
        # Do the step.
        self.unity_env.step()

        obs, rewards, terminateds, truncateds, infos = self._get_step_results()

        # Global horizon reached? -> Return __all__ truncated=True, so user
        # can reset. Set all agents' individual `truncated` to True as well.
        self.episode_timesteps += 1
        if self.episode_timesteps > self.episode_horizon:
            return (
                obs,
                rewards,
                terminateds,
                dict({"__all__": True}, **{agent_id: True for agent_id in all_agents}),
                infos,
            )

        return obs, rewards, terminateds, truncateds, infos

    def reset(
        self, *, seed=None, options=None
    ) -> Tuple[MultiAgentDict, MultiAgentDict]:
        """Resets the entire Unity3D scene (a single multi-agent episode)."""
        self.episode_timesteps = 0
        self.unity_env.reset()
        obs, _, _, _, infos = self._get_step_results()
        return obs, infos

    def _get_step_results(self):
        """Collects those agents' obs/rewards that have to act in next `step`.

        Returns:
            Tuple:
                obs: Multi-agent observation dict.
                    Only those observations for which to get new actions are
                    returned.
                rewards: Rewards dict matching `obs`.
                terminateds: Terminated dict with only an __all__ entry in it
                    (always False here).
                truncateds: Truncated dict with only an __all__ entry in it
                    (always False here; `step()` handles the horizon).
                infos: An (empty) info dict.
        """
        obs = {}
        rewards = {}
        infos = {}
        for behavior_name in self.unity_env.behavior_specs:
            decision_steps, terminal_steps = self.unity_env.get_steps(behavior_name)
            # Important: Only update those sub-envs that are currently
            # available within _env_state.
            # Loop through all envs ("agents") and fill in, whatever
            # information we have.
            for agent_id, idx in decision_steps.agent_id_to_index.items():
                key = behavior_name + "_{}".format(agent_id)
                os = tuple(o[idx] for o in decision_steps.obs)
                os = os[0] if len(os) == 1 else os
                obs[key] = os
                rewards[key] = (
                    decision_steps.reward[idx] + decision_steps.group_reward[idx]
                )
            for agent_id, idx in terminal_steps.agent_id_to_index.items():
                key = behavior_name + "_{}".format(agent_id)
                # Only overwrite rewards (last reward in episode), b/c obs
                # here is the last obs (which doesn't matter anyways).
                # Unless key does not exist in obs.
                if key not in obs:
                    os = tuple(o[idx] for o in terminal_steps.obs)
                    obs[key] = os = os[0] if len(os) == 1 else os
                rewards[key] = (
                    terminal_steps.reward[idx] + terminal_steps.group_reward[idx]
                )

        # Never signal termination or truncation here; episodes end via the
        # `episode_horizon` check in `step()`.
        return obs, rewards, {"__all__": False}, {"__all__": False}, infos

    @staticmethod
    def get_policy_configs_for_game(
        game_name: str,
    ) -> Tuple[dict, Callable[[AgentID], PolicyID]]:

        # The RLlib server must know about the Spaces that the Client will be
        # using inside Unity3D, up-front.
        obs_spaces = {
            "WindFarmControl": Box(float("-inf"), float("inf"), (6,)),
            # 3DBall.
            "3DBall": Box(float("-inf"), float("inf"), (8,)),
            # 3DBallHard.
            "3DBallHard": Box(float("-inf"), float("inf"), (45,)),
            # GridFoodCollector
            "GridFoodCollector": Box(float("-inf"), float("inf"), (40, 40, 6)),
            # Pyramids.
            "Pyramids": TupleSpace(
                [
                    Box(float("-inf"), float("inf"), (56,)),
                    Box(float("-inf"), float("inf"), (56,)),
                    Box(float("-inf"), float("inf"), (56,)),
                    Box(float("-inf"), float("inf"), (4,)),
                ]
            ),
            # SoccerTwos.
            "SoccerPlayer": TupleSpace(
                [
                    Box(-1.0, 1.0, (264,)),
                    Box(-1.0, 1.0, (72,)),
                ]
            ),
            # SoccerStrikersVsGoalie.
            "Goalie": Box(float("-inf"), float("inf"), (738,)),
            "Striker": TupleSpace(
                [
                    Box(float("-inf"), float("inf"), (231,)),
                    Box(float("-inf"), float("inf"), (63,)),
                ]
            ),
            # Sorter.
            "Sorter": TupleSpace(
                [
                    Box(
                        float("-inf"),
                        float("inf"),
                        (
                            20,
                            23,
                        ),
                    ),
                    Box(float("-inf"), float("inf"), (10,)),
                    Box(float("-inf"), float("inf"), (8,)),
                ]
            ),
            # Tennis.
            "Tennis": Box(float("-inf"), float("inf"), (27,)),
            # VisualHallway.
            "VisualHallway": Box(float("-inf"), float("inf"), (84, 84, 3)),
            # Walker.
            "Walker": Box(float("-inf"), float("inf"), (212,)),
            # FoodCollector.
            "FoodCollector": TupleSpace(
                [
                    Box(float("-inf"), float("inf"), (49,)),
                    Box(float("-inf"), float("inf"), (4,)),
                ]
            ),
        }
        action_spaces = {
            "WindFarmControl": MultiDiscrete([3]),
            # 3DBall.
            "3DBall": Box(-1.0, 1.0, (2,), dtype=np.float32),
            # 3DBallHard.
            "3DBallHard": Box(-1.0, 1.0, (2,), dtype=np.float32),
            # GridFoodCollector.
            "GridFoodCollector": MultiDiscrete([3, 3, 3, 2]),
            # Pyramids.
            "Pyramids": MultiDiscrete([5]),
            # SoccerStrikersVsGoalie.
            "Goalie": MultiDiscrete([3, 3, 3]),
            "Striker": MultiDiscrete([3, 3, 3]),
            # SoccerTwos.
            "SoccerPlayer": MultiDiscrete([3, 3, 3]),
            # Sorter.
            "Sorter": MultiDiscrete([3, 3, 3]),
            # Tennis.
            "Tennis": Box(-1.0, 1.0, (3,)),
            # VisualHallway.
            "VisualHallway": MultiDiscrete([5]),
            # Walker.
            "Walker": Box(-1.0, 1.0, (39,)),
            # FoodCollector.
            "FoodCollector": MultiDiscrete([3, 3, 3, 2]),
        }

        # Policies (Unity: "behaviors") and agent-to-policy mapping fns.
        if game_name == "SoccerStrikersVsGoalie":
            policies = {
                "Goalie": PolicySpec(
                    observation_space=obs_spaces["Goalie"],
                    action_space=action_spaces["Goalie"],
                ),
                "Striker": PolicySpec(
                    observation_space=obs_spaces["Striker"],
                    action_space=action_spaces["Striker"],
                ),
            }

            def policy_mapping_fn(agent_id, episode, worker, **kwargs):
                return "Striker" if "Striker" in agent_id else "Goalie"

        elif game_name == "SoccerTwos":
            policies = {
                "PurplePlayer": PolicySpec(
                    observation_space=obs_spaces["SoccerPlayer"],
                    action_space=action_spaces["SoccerPlayer"],
                ),
                "BluePlayer": PolicySpec(
                    observation_space=obs_spaces["SoccerPlayer"],
                    action_space=action_spaces["SoccerPlayer"],
                ),
            }

            def policy_mapping_fn(agent_id, episode, worker, **kwargs):
                return "BluePlayer" if "1_" in agent_id else "PurplePlayer"

        else:
            policies = {
                game_name: PolicySpec(
                    observation_space=obs_spaces[game_name],
                    action_space=action_spaces[game_name],
                ),
            }

            def policy_mapping_fn(agent_id, episode, worker, **kwargs):
                return game_name

        return policies, policy_mapping_fn
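The `step()` docstring above describes how agent keys are formed: the ML-Agents behavior name, an underscore, and the per-agent index. A minimal sketch of building and parsing such keys, for example inside a custom `policy_mapping_fn`; both helpers are hypothetical and not part of RLlib or ML-Agents.

```python
from typing import Tuple


def make_agent_key(behavior_name: str, agent_id: int) -> str:
    """Build a multi-agent key the way Unity3DEnv does: behavior name + "_" + agent id."""
    return "{}_{}".format(behavior_name, agent_id)


def split_agent_key(key: str) -> Tuple[str, int]:
    """Invert make_agent_key(). Split on the LAST underscore, because a
    behavior name may itself contain underscores."""
    behavior_name, agent_id = key.rsplit("_", 1)
    return behavior_name, int(agent_id)


key = make_agent_key("WindFarmControl?team=0", 3)
print(key)                   # WindFarmControl?team=0_3
print(split_agent_key(key))  # ('WindFarmControl?team=0', 3)
```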

pip freeze:

absl-py==2.1.0
aiosignal==1.3.1
appdirs==1.4.4
astunparse==1.6.3
attrs==23.2.0
black==24.4.2
cachetools==5.3.3
cattrs==1.5.0
certifi==2024.2.2
charset-normalizer==3.3.2
click==8.1.7
cloudpickle==3.0.0
colorama==0.4.6
dm-tree==0.1.8
docker-pycreds==0.4.0
exceptiongroup==1.2.1
Farama-Notifications==0.0.4
filelock==3.14.0
flatbuffers==24.3.25
frozenlist==1.4.1
fsspec==2024.3.1
gast==0.4.0
gitdb==4.0.11
GitPython==3.1.43
google-auth==2.29.0
google-auth-oauthlib==1.0.0
google-pasta==0.2.0
GPUtil==1.4.0
grpcio==1.63.0
gymnasium==0.28.1
h5py==3.11.0
-e git+https://github.com/philippds/hivex@9133b0de4d05038a5b221961b1e385fa8618cbb2#egg=hivex
idna==3.7
imageio==2.34.1
importlib_metadata==7.1.0
importlib_resources==6.4.0
iniconfig==2.0.0
intel-openmp==2021.4.0
jax-jumpy==1.0.0
Jinja2==3.1.3
jsonschema==4.22.0
jsonschema-specifications==2023.12.1
keras==2.13.1
lazy_loader==0.4
libclang==18.1.1
lz4==4.3.3
Markdown==3.6
markdown-it-py==3.0.0
MarkupSafe==2.1.5
mdurl==0.1.2
mkl==2021.4.0
mlagents==0.28.0
mlagents-envs==0.28.0
mpmath==1.3.0
msgpack==1.0.8
mypy-extensions==1.0.0
networkx==3.1
numpy==1.24.3
oauthlib==3.2.2
opt-einsum==3.3.0
packaging==24.0
pandas==2.0.3
pathspec==0.12.1
pillow==10.3.0
pkgutil_resolve_name==1.3.10
platformdirs==4.2.1
pluggy==1.5.0
protobuf==3.20.3
psutil==5.9.8
pyarrow==16.0.0
pyasn1==0.6.0
pyasn1_modules==0.4.0
pydantic==1.10.14
Pygments==2.17.2
pypiwin32==223
pytest==8.2.0
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
pytz==2024.1
PyWavelets==1.4.1
pywin32==306
PyYAML==6.0.1
ray==2.10.0
referencing==0.35.1
requests==2.31.0
requests-oauthlib==2.0.0
rich==13.7.1
rpds-py==0.18.0
rsa==4.9
scikit-image==0.21.0
scipy==1.10.1
sentry-sdk==2.0.1
setproctitle==1.3.3
shellingham==1.5.4
six==1.16.0
smmap==5.0.1
sympy==1.12
tbb==2021.12.0
tensorboard==2.13.0
tensorboard-data-server==0.7.2
tensorboardX==2.6.2.2
tensorflow==2.13.0
tensorflow-estimator==2.13.0
tensorflow-intel==2.13.0
tensorflow-io-gcs-filesystem==0.31.0
termcolor==2.4.0
tifffile==2023.7.10
tomli==2.0.1
torch==2.3.0
torchvision==0.18.0
typer==0.12.3
typing_extensions==4.5.0
tzdata==2024.1
urllib3==2.2.1
wandb==0.16.6
Werkzeug==3.0.2
wrapt==1.16.0
zipp==3.18.1

python -V:

Python 3.8.1

Any help would be extremely valuable, as I’m up against a deadline at the beginning of June and I really want to publish this with RLlib. If I cannot fix this, I will have to switch frameworks, which I don’t want to do 😬

Let me know if there is any more info I need to provide,
Thank you very much for your time!!!

Hi @Philipp_D_Siedler,

My guess, given how large your rewards are and how large your returns will be, is that the difference might come from how the value loss is computed in the two RL libraries. Have you compared these values between ML-Agents and RLlib?

ML-Agents (I think): [image not preserved]

RLlib: [image not preserved]
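The two images did not survive. As a rough illustration of the kind of difference meant here, the numpy sketch below re-implements the two value losses as far as I know them: Ray 2.x's old-API-stack PPO loss, which clamps the squared value error at `vf_clip_param` (default 10.0), and ML-Agents' trust-region value loss, which instead clips the change in the value prediction to `epsilon` around the old estimate. These are illustrative re-implementations, not either library's actual code.

```python
import numpy as np


def rllib_style_vf_loss(values, value_targets, vf_clip_param=10.0):
    """Sketch of RLlib's PPO value loss (Ray 2.x, old API stack): the squared
    error is clamped to [0, vf_clip_param] before averaging, so any sample
    whose error exceeds sqrt(vf_clip_param) contributes no gradient."""
    squared_error = (values - value_targets) ** 2
    return np.mean(np.clip(squared_error, 0.0, vf_clip_param))


def mlagents_style_vf_loss(values, old_values, returns, epsilon=0.2):
    """Sketch of ML-Agents' trust-region value loss: the new prediction is
    clipped to old_values +/- epsilon, and the larger of the clipped and
    unclipped squared errors is used."""
    clipped_values = old_values + np.clip(values - old_values, -epsilon, epsilon)
    loss_unclipped = (returns - values) ** 2
    loss_clipped = (returns - clipped_values) ** 2
    return np.mean(np.maximum(loss_unclipped, loss_clipped))


# Returns in the thousands, as in WindFarmControl at reward scale 1.0.
target = np.array([4000.0])
for prediction in (0.0, 3000.0, 3999.0):
    print(prediction, rllib_style_vf_loss(np.array([prediction]), target))
# The RLlib-style loss is flat at 10.0 for the first two predictions.
print(mlagents_style_vf_loss(np.array([0.0]), np.array([0.0]), target))
```

With returns this large, the RLlib-style loss sits at the clamp for almost every prediction, which is one way large returns could hold back value learning. The RLlib documentation describes `vf_clip_param` as sensitive to the reward scale and suggests increasing it when the expected value is large.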

Hi @mannyv, thank you very much for your swift response. Your hypothesis essentially questions the reward scale of the environment, so before looking into the losses I ran a couple of experiments. I re-ran both the ML-Agents and the RLlib setups with a reward scaling of 0.02, so that the range roughly matches the reward received in the 3DBall environment (roughly 0 to 100). For better comparability, I also started tracking a metric that is independent of the reward and loss calculation, called “Individual Performance”.
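A reward scale like this is just a constant multiplier on every agent's per-step reward. A minimal sketch on the RLlib side, assuming it is applied to the multi-agent `rewards` dict that `Unity3DEnv._get_step_results()` builds; the helper `scale_rewards` and the agent key are hypothetical. In ML-Agents the same effect is available through the `strength` of the extrinsic reward signal, which multiplies the environment reward.

```python
from typing import Dict


def scale_rewards(rewards: Dict[str, float], reward_scale: float) -> Dict[str, float]:
    """Return a copy of a multi-agent rewards dict with every reward scaled.

    Hypothetical helper: it could be applied to the `rewards` dict built in
    Unity3DEnv._get_step_results() before it is returned to RLlib.
    """
    return {agent_id: reward * reward_scale for agent_id, reward in rewards.items()}


# A WindFarmControl-style episode: 5000 steps at the maximum reward of 1.0.
episode_return = 0.0
for _ in range(5000):
    step_rewards = scale_rewards({"WindFarmControl?team=0_0": 1.0}, 0.02)
    episode_return += step_rewards["WindFarmControl?team=0_0"]
print(round(episode_return, 6))  # ~100.0
```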

ML-Agents reward scale experiment:

→ Red is reward scale 1.0, reward ranging from 0 to 5000
Blue is reward scale 0.02, reward ranging from 0 to 100

RLlib reward scale experiment:

→ Blue is reward scale 1.0, reward ranging from 0 to 5000
Red is reward scale 0.02, reward ranging from 0 to 100

This seems to have no effect, but I’m not certain whether my experiments are sufficient to draw that conclusion. What’s your view on this? Do you still think this might come down to the loss calculation?

Thank you!!

PS: I changed my script to no longer use tune.Tuner, so that I have more insight into what is being logged:

# `param_space` here must be the PPOConfig object itself, not config.to_dict().
algo = param_space.build()

for i in range(100):
    result = algo.train()

I did notice that num_env_steps_trained stays at 0 while num_env_steps_sampled increases. The 0 appeared because I rolled back from ray==2.10.0 to ray==2.8.1.

...
num_env_steps_sampled: 500000
num_env_steps_sampled_this_iter: 5000
num_env_steps_sampled_throughput_per_sec: 95.09055955300937
num_env_steps_trained: 0
...
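A small, hypothetical helper for the `algo.train()` loop that flags this pattern automatically. It only reads the two counters shown in the log excerpt above and makes no claim about why the trained counter stays at 0; it could be called as `training_progress(algo.train())` inside the loop.

```python
def training_progress(result: dict) -> dict:
    """Extract the env-step counters from one `algo.train()` result dict.

    Key names are taken from the log excerpt above; missing keys count as 0.
    """
    sampled = result.get("num_env_steps_sampled", 0)
    trained = result.get("num_env_steps_trained", 0)
    return {
        "sampled": sampled,
        "trained": trained,
        # Flag the pattern from this thread: sampling advances, training counter does not.
        "trained_counter_stuck": sampled > 0 and trained == 0,
    }


# Values copied from the log excerpt above.
print(training_progress({"num_env_steps_sampled": 500000, "num_env_steps_trained": 0}))
```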

Consider taking a look at this issue: “Fix wrong returns in multi agent partially returned episodes” by drblallo, Pull Request #45057 in ray-project/ray on GitHub. It made Ray misreport every multi-agent setup I ran, until I found it and fixed it. I am not sure it is your issue, but if you are running a multi-agent setup, you may have encountered it.