Hi all,
I’ve been working on an environment suite for almost 3 years, and now it’s time to publish (NeurIPS?) my baby. I want to use RLlib as the training framework for my baselines, but I cannot reproduce the results I get when training with ML-Agents. In contrast, I trained the ML-Agents example 3DBall and was able to reproduce its results with RLlib, so I guess this must be a training-configuration problem? I matched everything the same way I did for 3DBall, but still no luck. Any help or idea where else I could look would be much appreciated.
3DBall Results with ML-Agents:
In this case we have 12 agents.
Result trained using ML-Agents (trained with a build, not in the editor).
Config:
behaviors:
3DBall:
trainer_type: ppo
hyperparameters:
batch_size: 64
buffer_size: 12000
learning_rate: 0.0003
beta: 0.001
epsilon: 0.2
lambd: 0.99
num_epoch: 3
learning_rate_schedule: linear
network_settings:
normalize: true
hidden_units: 128
num_layers: 2
vis_encode_type: simple
reward_signals:
extrinsic:
gamma: 0.99
strength: 1.0
keep_checkpoints: 5
max_steps: 500000
time_horizon: 1000
summary_freq: 12000
→ Here we can see a mean reward of roughly 100 across the 12 agents.
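For reference, this is the rough mapping I applied when translating the ML-Agents hyperparameters above into the RLlib PPOConfig below; it is my own interpretation of the two sets of names, not an official correspondence, so please flag anything that looks off:

# My rough ML-Agents -> RLlib PPO hyperparameter mapping (my own reading,
# not an official table):
ml_agents_to_rllib = {
    "batch_size": "sgd_minibatch_size",          # 64 -> 64
    "buffer_size": "train_batch_size",           # 12000 -> 12000
    "learning_rate": "lr",                       # 0.0003 -> 0.0003
    "epsilon": "clip_param",                     # 0.2 -> 0.2
    "lambd": "lambda_",                          # 0.99 -> 0.99
    "num_epoch": "num_sgd_iter",                 # 3 -> 3
    "reward_signals.extrinsic.gamma": "gamma",   # 0.99 -> 0.99
    "hidden_units x num_layers": "model={'fcnet_hiddens': [128, 128]}",
}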
3DBall Results Reproduced in RLlib:
RLlib training script:
import os
import ray
from ray import air, tune
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.env.wrappers.unity3d_env import Unity3DEnv
from dotenv import load_dotenv
import wandb
from ray.air.integrations.wandb import WandbLoggerCallback
load_dotenv()
WANDB_API_KEY = os.getenv("WANDB_API_KEY")
if __name__ == "__main__":
wandb.login()
ray.init()
tune.register_env(
"unity3d",
lambda c: Unity3DEnv(
file_name="src/hivex/dev_environments/3DBall/UnityEnvironment.exe",
no_graphics=True,
episode_horizon=1000,
),
)
# Get policies (different agent types; "behaviors" in MLAgents) and
# the mappings from individual agents to Policies.
policies, policy_mapping_fn = Unity3DEnv.get_policy_configs_for_game("3DBall")
config = (
PPOConfig()
.environment(
"unity3d",
env_config={
"file_name": "src/hivex/dev_environments/3DBall/UnityEnvironment.exe",
"episode_horizon": 1000,
},
)
.framework("torch")
# For running in editor, force to use just one Worker (we only have
# one Unity running)!
.rollouts(
num_rollout_workers=1,
rollout_fragment_length=200,
)
.training(
lr=0.0003,
lambda_=0.99,
gamma=0.99,
sgd_minibatch_size=64,
train_batch_size=12000,
num_sgd_iter=3,
clip_param=0.2,
model={"fcnet_hiddens": [128, 128]},
)
.multi_agent(policies=policies, policy_mapping_fn=policy_mapping_fn)
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
.resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
)
stop = {
"timesteps_total": 500000, # per agent
}
wandb_callback = WandbLoggerCallback(
project="3DBall_test",
api_key=WANDB_API_KEY,
upload_checkpoints=True,
save_checkpoints=True,
group=f"Test",
log_config=True,
)
# Run the experiment.
results = tune.Tuner(
"PPO",
param_space=config.to_dict(),
run_config=air.RunConfig(
stop=stop,
verbose=1,
checkpoint_config=air.CheckpointConfig(
checkpoint_frequency=5,
checkpoint_at_end=True,
),
callbacks=[wandb_callback],
),
).fit()
ray.shutdown()
wandb.finish()
→ Here we can see a cumulative reward of roughly 1200 across the 12 agents, which makes sense I think. I would consider this “reproduced”.
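To spell out why I consider this reproduced (assuming I read RLlib correctly and its episode_reward_mean is the sum over all agents of a multi-agent episode):

num_agents = 12
rllib_episode_reward_mean = 1200                          # sum over all agents, as I understand it
per_agent_mean = rllib_episode_reward_mean / num_agents   # = 100.0
# -> matches the ~100 mean reward per agent that ML-Agents reports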
Custom Environment trained with ML-Agents:
In this case we have 8 agents.
Result trained using ML-Agents (trained with a build, not in the editor).
Config:
behaviors:
WindFarmControl:
trainer_type: ppo
hyperparameters:
batch_size: 256
buffer_size: 8000
learning_rate: 0.0003
beta: 0.005
epsilon: 0.2
lambd: 0.95
num_epoch: 3
learning_rate_schedule: linear
network_settings:
normalize: true
hidden_units: 64
num_layers: 2
reward_signals:
extrinsic:
gamma: 0.9
strength: 1.0
keep_checkpoints: 5
    max_steps: 4000000 # 1 arena x 8 turbines x 5000 steps = 40000; x 100 = 4000000
time_horizon: 1000
    summary_freq: 8000 # 1 arena x 8 turbines x 5000 steps = 40000
threaded: true
→ This makes sense as we have 5000 steps per episode and the max reward per step is 1.0.
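To make explicit where my expectation for the RLlib run below comes from (plain arithmetic; the ~4k per-agent figure is roughly what the ML-Agents run reaches):

max_steps_per_episode = 5000
max_reward_per_step = 1.0
upper_bound_per_agent = max_steps_per_episode * max_reward_per_step      # = 5000
mlagents_mean_reward_per_agent = 4000        # roughly the ML-Agents result
num_agents = 8
expected_rllib_cumulative = num_agents * mlagents_mean_reward_per_agent  # = 32000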
Custom Environment Results in RLlib:
RLlib training script:
import os
import ray
from ray import air, tune
from ray.rllib.algorithms.ppo import PPOConfig
from dotenv import load_dotenv
import wandb
from ray.air.integrations.wandb import WandbLoggerCallback
from hivex.training.baseline.unity3d_env import Unity3DEnv
load_dotenv()
WANDB_API_KEY = os.getenv("WANDB_API_KEY")
if __name__ == "__main__":
wandb.login()
ray.init()
tune.register_env(
"unity3d",
lambda c: Unity3DEnv(
file_name="src/hivex/dev_environments/Hivex_WindFarmControl_win_1/Hivex_WindfarmControl.exe",
no_graphics=True,
episode_horizon=1000,
),
)
# Get policies (different agent types; "behaviors" in MLAgents) and
# the mappings from individual agents to Policies.
policies, policy_mapping_fn = Unity3DEnv.get_policy_configs_for_game(
"WindFarmControl"
)
config = (
PPOConfig()
.environment(
"unity3d",
env_config={
"file_name": "src/hivex/dev_environments/Hivex_WindFarmControl_win_1/Hivex_WindfarmControl.exe",
"episode_horizon": 1000,
},
)
.framework("torch")
# For running in editor, force to use just one Worker (we only have
# one Unity running)!
.rollouts(
num_rollout_workers=1,
rollout_fragment_length=200,
)
.training(
lr=0.0003,
lambda_=0.95,
gamma=0.99,
sgd_minibatch_size=256,
# this dictates how this is logged to wandb
# steps on X axis in wandb will be train_batch_size * agent count
train_batch_size=8000,
num_sgd_iter=3,
clip_param=0.2,
model={"fcnet_hiddens": [64, 64]},
)
.multi_agent(policies=policies, policy_mapping_fn=policy_mapping_fn)
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
.resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
)
stop = {
"timesteps_total": 500000,
}
wandb_callback = WandbLoggerCallback(
project="WFC_test_1",
api_key=WANDB_API_KEY,
upload_checkpoints=True,
save_checkpoints=True,
group=f"Test",
log_config=True,
)
# Run the experiment.
results = tune.Tuner(
"PPO",
param_space=config.to_dict(),
run_config=air.RunConfig(
stop=stop,
verbose=1,
checkpoint_config=air.CheckpointConfig(
checkpoint_frequency=5,
checkpoint_at_end=True,
),
callbacks=[wandb_callback],
),
).fit()
ray.shutdown()
wandb.finish()
Additionally, I added observation and action specs for my environment to unity3d_env.py, like so:
@staticmethod
def get_policy_configs_for_game(
game_name: str,
) -> Tuple[dict, Callable[[AgentID], PolicyID]]:
....
# The RLlib server must know about the Spaces that the Client will be
# using inside Unity3D, up-front.
obs_spaces = {
"WindFarmControl": Box(float("-inf"), float("inf"), (6,)),
# 3DBall.
"3DBall": Box(float("-inf"), float("inf"), (8,)),
...
}
action_spaces = {
"WindFarmControl": MultiDiscrete([3]),
# 3DBall.
"3DBall": Box(-1.0, 1.0, (2,), dtype=np.float32),
...
}
...
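For reference, this is roughly how I read off the (6,) observation shape and the single 3-way discrete action branch, by printing the behavior spec straight from the build (a quick sketch; observation_specs / action_spec are the mlagents_envs 0.28 attribute names as far as I know):

from mlagents_envs.environment import UnityEnvironment

env = UnityEnvironment(
    file_name="src/hivex/dev_environments/Hivex_WindFarmControl_win_1/Hivex_WindfarmControl.exe",
    no_graphics=True,
)
env.reset()
for name, spec in env.behavior_specs.items():
    print(name)                                                     # e.g. "WindFarmControl?team=0"
    print([obs_spec.shape for obs_spec in spec.observation_specs])  # expect [(6,)]
    print(spec.action_spec)                                         # expect one discrete branch of size 3
env.close()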
→ Here is the problem: the expectation would be a cumulative reward of roughly 32k, but instead we are at 14k. Dividing that by the number of agents (8) gives roughly 1750 mean reward per agent; compared to the ML-Agents results, which show something around 4k mean reward, that is less than half.
Just for fun, I also ran 5 million steps instead of 500k to see where things converge using RLlib:
This is my unity3d_env.py in case this is relevant:
from gymnasium.spaces import Box, MultiDiscrete, Tuple as TupleSpace
import logging
import numpy as np
import random
import time
from typing import Callable, Optional, Tuple
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils.typing import MultiAgentDict, PolicyID, AgentID
logger = logging.getLogger(__name__)
@PublicAPI
class Unity3DEnv(MultiAgentEnv):
"""A MultiAgentEnv representing a single Unity3D game instance.
For an example on how to use this Env with a running Unity3D editor
or with a compiled game, see:
`rllib/examples/unity3d_env_local.py`
For an example on how to use it inside a Unity game client, which
connects to an RLlib Policy server, see:
`rllib/examples/serving/unity3d_[client|server].py`
Supports all Unity3D (MLAgents) examples, multi- or single-agent and
gets converted automatically into an ExternalMultiAgentEnv, when used
inside an RLlib PolicyClient for cloud/distributed training of Unity games.
"""
# Default base port when connecting directly to the Editor
_BASE_PORT_EDITOR = 5004
# Default base port when connecting to a compiled environment
_BASE_PORT_ENVIRONMENT = 5005
# The worker_id for each environment instance
_WORKER_ID = 0
def __init__(
self,
file_name: str = None,
port: Optional[int] = None,
seed: int = 0,
no_graphics: bool = False,
timeout_wait: int = 300,
episode_horizon: int = 1000,
):
"""Initializes a Unity3DEnv object.
Args:
file_name (Optional[str]): Name of the Unity game binary.
If None, will assume a locally running Unity3D editor
to be used, instead.
port (Optional[int]): Port number to connect to Unity environment.
seed: A random seed value to use for the Unity3D game.
no_graphics: Whether to run the Unity3D simulator in
no-graphics mode. Default: False.
timeout_wait: Time (in seconds) to wait for connection from
the Unity3D instance.
episode_horizon: A hard horizon to abide to. After at most
this many steps (per-agent episode `step()` calls), the
Unity3D game is reset and will start again (finishing the
multi-agent episode that the game represents).
Note: The game itself may contain its own episode length
limits, which are always obeyed (on top of this value here).
"""
# Skip env checking as the nature of the agent IDs depends on the game
# running in the connected Unity editor.
self._skip_env_checking = True
super().__init__()
if file_name is None:
print(
"No game binary provided, will use a running Unity editor "
"instead.\nMake sure you are pressing the Play (|>) button in "
"your editor to start."
)
import mlagents_envs
from mlagents_envs.environment import UnityEnvironment
        # Try connecting to the Unity3D game instance. If a port is blocked,
        # retry with the next worker_id (which maps to the next port).
port_ = None
while True:
# Sleep for random time to allow for concurrent startup of many
# environments (num_workers >> 1). Otherwise, would lead to port
# conflicts sometimes.
if port_ is not None:
time.sleep(random.randint(1, 10))
port_ = port or (
self._BASE_PORT_ENVIRONMENT if file_name else self._BASE_PORT_EDITOR
)
# cache the worker_id and
# increase it for the next environment
worker_id_ = Unity3DEnv._WORKER_ID if file_name else 0
Unity3DEnv._WORKER_ID += 1
try:
self.unity_env = UnityEnvironment(
file_name=file_name,
worker_id=worker_id_,
base_port=port_,
seed=seed,
no_graphics=no_graphics,
timeout_wait=timeout_wait,
)
print("Created UnityEnvironment for port {}".format(port_ + worker_id_))
except mlagents_envs.exception.UnityWorkerInUseException:
pass
else:
break
# ML-Agents API version.
self.api_version = self.unity_env.API_VERSION.split(".")
self.api_version = [int(s) for s in self.api_version]
# Reset entire env every this number of step calls.
self.episode_horizon = episode_horizon
# Keep track of how many times we have called `step` so far.
self.episode_timesteps = 0
def step(
self, action_dict: MultiAgentDict
) -> Tuple[
MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict
]:
"""Performs one multi-agent step through the game.
Args:
action_dict: Multi-agent action dict with:
keys=agent identifier consisting of
[MLagents behavior name, e.g. "Goalie?team=1"] + "_" +
[Agent index, a unique MLAgent-assigned index per single agent]
Returns:
tuple:
- obs: Multi-agent observation dict.
Only those observations for which to get new actions are
returned.
- rewards: Rewards dict matching `obs`.
                - terminateds: Terminateds dict with only an __all__ entry in
                    it (always False from this wrapper; agent termination is
                    signaled by the game via terminal steps).
                - truncateds: Truncateds dict. __all__ (and every agent) is
                    set to True once `episode_horizon` steps have been taken.
                - infos: An (empty) info dict.
"""
from mlagents_envs.base_env import ActionTuple
# Set only the required actions (from the DecisionSteps) in Unity3D.
all_agents = []
for behavior_name in self.unity_env.behavior_specs:
# New ML-Agents API: Set all agents actions at the same time
# via an ActionTuple. Since API v1.4.0.
if self.api_version[0] > 1 or (
self.api_version[0] == 1 and self.api_version[1] >= 4
):
actions = []
for agent_id in self.unity_env.get_steps(behavior_name)[0].agent_id:
key = behavior_name + "_{}".format(agent_id)
all_agents.append(key)
actions.append(action_dict[key])
if actions:
if actions[0].dtype == np.float32:
action_tuple = ActionTuple(continuous=np.array(actions))
else:
action_tuple = ActionTuple(discrete=np.array(actions))
self.unity_env.set_actions(behavior_name, action_tuple)
# Old behavior: Do not use an ActionTuple and set each agent's
# action individually.
else:
for agent_id in self.unity_env.get_steps(behavior_name)[
0
].agent_id_to_index.keys():
key = behavior_name + "_{}".format(agent_id)
all_agents.append(key)
self.unity_env.set_action_for_agent(
behavior_name, agent_id, action_dict[key]
)
# Do the step.
self.unity_env.step()
obs, rewards, terminateds, truncateds, infos = self._get_step_results()
# Global horizon reached? -> Return __all__ truncated=True, so user
# can reset. Set all agents' individual `truncated` to True as well.
self.episode_timesteps += 1
if self.episode_timesteps > self.episode_horizon:
return (
obs,
rewards,
terminateds,
dict({"__all__": True}, **{agent_id: True for agent_id in all_agents}),
infos,
)
return obs, rewards, terminateds, truncateds, infos
def reset(
self, *, seed=None, options=None
) -> Tuple[MultiAgentDict, MultiAgentDict]:
"""Resets the entire Unity3D scene (a single multi-agent episode)."""
self.episode_timesteps = 0
self.unity_env.reset()
obs, _, _, _, infos = self._get_step_results()
return obs, infos
def _get_step_results(self):
"""Collects those agents' obs/rewards that have to act in next `step`.
Returns:
Tuple:
obs: Multi-agent observation dict.
Only those observations for which to get new actions are
returned.
rewards: Rewards dict matching `obs`.
                terminateds/truncateds: Dicts with only an __all__ entry,
                    both False here (horizon truncation is handled in `step()`).
infos: An (empty) info dict.
"""
obs = {}
rewards = {}
infos = {}
for behavior_name in self.unity_env.behavior_specs:
decision_steps, terminal_steps = self.unity_env.get_steps(behavior_name)
# Important: Only update those sub-envs that are currently
# available within _env_state.
# Loop through all envs ("agents") and fill in, whatever
# information we have.
for agent_id, idx in decision_steps.agent_id_to_index.items():
key = behavior_name + "_{}".format(agent_id)
os = tuple(o[idx] for o in decision_steps.obs)
os = os[0] if len(os) == 1 else os
obs[key] = os
rewards[key] = (
decision_steps.reward[idx] + decision_steps.group_reward[idx]
)
for agent_id, idx in terminal_steps.agent_id_to_index.items():
key = behavior_name + "_{}".format(agent_id)
# Only overwrite rewards (last reward in episode), b/c obs
# here is the last obs (which doesn't matter anyways).
# Unless key does not exist in obs.
if key not in obs:
os = tuple(o[idx] for o in terminal_steps.obs)
obs[key] = os = os[0] if len(os) == 1 else os
rewards[key] = (
terminal_steps.reward[idx] + terminal_steps.group_reward[idx]
)
# Only use dones if all agents are done, then we should do a reset.
return obs, rewards, {"__all__": False}, {"__all__": False}, infos
@staticmethod
def get_policy_configs_for_game(
game_name: str,
) -> Tuple[dict, Callable[[AgentID], PolicyID]]:
# The RLlib server must know about the Spaces that the Client will be
# using inside Unity3D, up-front.
obs_spaces = {
"WindFarmControl": Box(float("-inf"), float("inf"), (6,)),
# 3DBall.
"3DBall": Box(float("-inf"), float("inf"), (8,)),
# 3DBallHard.
"3DBallHard": Box(float("-inf"), float("inf"), (45,)),
# GridFoodCollector
"GridFoodCollector": Box(float("-inf"), float("inf"), (40, 40, 6)),
# Pyramids.
"Pyramids": TupleSpace(
[
Box(float("-inf"), float("inf"), (56,)),
Box(float("-inf"), float("inf"), (56,)),
Box(float("-inf"), float("inf"), (56,)),
Box(float("-inf"), float("inf"), (4,)),
]
),
# SoccerTwos.
"SoccerPlayer": TupleSpace(
[
Box(-1.0, 1.0, (264,)),
Box(-1.0, 1.0, (72,)),
]
),
# SoccerStrikersVsGoalie.
"Goalie": Box(float("-inf"), float("inf"), (738,)),
"Striker": TupleSpace(
[
Box(float("-inf"), float("inf"), (231,)),
Box(float("-inf"), float("inf"), (63,)),
]
),
# Sorter.
"Sorter": TupleSpace(
[
Box(
float("-inf"),
float("inf"),
(
20,
23,
),
),
Box(float("-inf"), float("inf"), (10,)),
Box(float("-inf"), float("inf"), (8,)),
]
),
# Tennis.
"Tennis": Box(float("-inf"), float("inf"), (27,)),
# VisualHallway.
"VisualHallway": Box(float("-inf"), float("inf"), (84, 84, 3)),
# Walker.
"Walker": Box(float("-inf"), float("inf"), (212,)),
# FoodCollector.
"FoodCollector": TupleSpace(
[
Box(float("-inf"), float("inf"), (49,)),
Box(float("-inf"), float("inf"), (4,)),
]
),
}
action_spaces = {
"WindFarmControl": MultiDiscrete([3]),
# 3DBall.
"3DBall": Box(-1.0, 1.0, (2,), dtype=np.float32),
# 3DBallHard.
"3DBallHard": Box(-1.0, 1.0, (2,), dtype=np.float32),
# GridFoodCollector.
"GridFoodCollector": MultiDiscrete([3, 3, 3, 2]),
# Pyramids.
"Pyramids": MultiDiscrete([5]),
# SoccerStrikersVsGoalie.
"Goalie": MultiDiscrete([3, 3, 3]),
"Striker": MultiDiscrete([3, 3, 3]),
# SoccerTwos.
"SoccerPlayer": MultiDiscrete([3, 3, 3]),
# Sorter.
"Sorter": MultiDiscrete([3, 3, 3]),
# Tennis.
"Tennis": Box(-1.0, 1.0, (3,)),
# VisualHallway.
"VisualHallway": MultiDiscrete([5]),
# Walker.
"Walker": Box(-1.0, 1.0, (39,)),
# FoodCollector.
"FoodCollector": MultiDiscrete([3, 3, 3, 2]),
}
# Policies (Unity: "behaviors") and agent-to-policy mapping fns.
if game_name == "SoccerStrikersVsGoalie":
policies = {
"Goalie": PolicySpec(
observation_space=obs_spaces["Goalie"],
action_space=action_spaces["Goalie"],
),
"Striker": PolicySpec(
observation_space=obs_spaces["Striker"],
action_space=action_spaces["Striker"],
),
}
def policy_mapping_fn(agent_id, episode, worker, **kwargs):
return "Striker" if "Striker" in agent_id else "Goalie"
elif game_name == "SoccerTwos":
policies = {
"PurplePlayer": PolicySpec(
observation_space=obs_spaces["SoccerPlayer"],
action_space=action_spaces["SoccerPlayer"],
),
"BluePlayer": PolicySpec(
observation_space=obs_spaces["SoccerPlayer"],
action_space=action_spaces["SoccerPlayer"],
),
}
def policy_mapping_fn(agent_id, episode, worker, **kwargs):
return "BluePlayer" if "1_" in agent_id else "PurplePlayer"
else:
policies = {
game_name: PolicySpec(
observation_space=obs_spaces[game_name],
action_space=action_spaces[game_name],
),
}
def policy_mapping_fn(agent_id, episode, worker, **kwargs):
return game_name
return policies, policy_mapping_fn
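In case it helps with debugging, this is the kind of quick smoke test I run against this wrapper before training: step with random actions and check that the spaces registered above actually match what the build sends (just a sketch, not part of the training run):

from hivex.training.baseline.unity3d_env import Unity3DEnv

env = Unity3DEnv(
    file_name="src/hivex/dev_environments/Hivex_WindFarmControl_win_1/Hivex_WindfarmControl.exe",
    no_graphics=True,
    episode_horizon=1000,
)
policies, _ = Unity3DEnv.get_policy_configs_for_game("WindFarmControl")
action_space = policies["WindFarmControl"].action_space

obs, infos = env.reset()
totals = {agent_id: 0.0 for agent_id in obs}
for _ in range(1000):
    # one random action per agent that is currently requesting a decision
    actions = {agent_id: action_space.sample() for agent_id in obs}
    obs, rewards, terminateds, truncateds, infos = env.step(actions)
    for agent_id, reward in rewards.items():
        totals[agent_id] = totals.get(agent_id, 0.0) + reward
print({agent_id: round(total, 2) for agent_id, total in totals.items()})
# the wrapper does not define close(), so close the underlying UnityEnvironment directly
env.unity_env.close()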
pip freeze:
absl-py==2.1.0
aiosignal==1.3.1
appdirs==1.4.4
astunparse==1.6.3
attrs==23.2.0
black==24.4.2
cachetools==5.3.3
cattrs==1.5.0
certifi==2024.2.2
charset-normalizer==3.3.2
click==8.1.7
cloudpickle==3.0.0
colorama==0.4.6
dm-tree==0.1.8
docker-pycreds==0.4.0
exceptiongroup==1.2.1
Farama-Notifications==0.0.4
filelock==3.14.0
flatbuffers==24.3.25
frozenlist==1.4.1
fsspec==2024.3.1
gast==0.4.0
gitdb==4.0.11
GitPython==3.1.43
google-auth==2.29.0
google-auth-oauthlib==1.0.0
google-pasta==0.2.0
GPUtil==1.4.0
grpcio==1.63.0
gymnasium==0.28.1
h5py==3.11.0
-e git+https://github.com/philippds/hivex@9133b0de4d05038a5b221961b1e385fa8618cbb2#egg=hivex
idna==3.7
imageio==2.34.1
importlib_metadata==7.1.0
importlib_resources==6.4.0
iniconfig==2.0.0
intel-openmp==2021.4.0
jax-jumpy==1.0.0
Jinja2==3.1.3
jsonschema==4.22.0
jsonschema-specifications==2023.12.1
keras==2.13.1
lazy_loader==0.4
libclang==18.1.1
lz4==4.3.3
Markdown==3.6
markdown-it-py==3.0.0
MarkupSafe==2.1.5
mdurl==0.1.2
mkl==2021.4.0
mlagents==0.28.0
mlagents-envs==0.28.0
mpmath==1.3.0
msgpack==1.0.8
mypy-extensions==1.0.0
networkx==3.1
numpy==1.24.3
oauthlib==3.2.2
opt-einsum==3.3.0
packaging==24.0
pandas==2.0.3
pathspec==0.12.1
pillow==10.3.0
pkgutil_resolve_name==1.3.10
platformdirs==4.2.1
pluggy==1.5.0
protobuf==3.20.3
psutil==5.9.8
pyarrow==16.0.0
pyasn1==0.6.0
pyasn1_modules==0.4.0
pydantic==1.10.14
Pygments==2.17.2
pypiwin32==223
pytest==8.2.0
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
pytz==2024.1
PyWavelets==1.4.1
pywin32==306
PyYAML==6.0.1
ray==2.10.0
referencing==0.35.1
requests==2.31.0
requests-oauthlib==2.0.0
rich==13.7.1
rpds-py==0.18.0
rsa==4.9
scikit-image==0.21.0
scipy==1.10.1
sentry-sdk==2.0.1
setproctitle==1.3.3
shellingham==1.5.4
six==1.16.0
smmap==5.0.1
sympy==1.12
tbb==2021.12.0
tensorboard==2.13.0
tensorboard-data-server==0.7.2
tensorboardX==2.6.2.2
tensorflow==2.13.0
tensorflow-estimator==2.13.0
tensorflow-intel==2.13.0
tensorflow-io-gcs-filesystem==0.31.0
termcolor==2.4.0
tifffile==2023.7.10
tomli==2.0.1
torch==2.3.0
torchvision==0.18.0
typer==0.12.3
typing_extensions==4.5.0
tzdata==2024.1
urllib3==2.2.1
wandb==0.16.6
Werkzeug==3.0.2
wrapt==1.16.0
zipp==3.18.1
python -V:
Python 3.8.1
Any help would be extremely valuable, as I’m running against a beginning-of-June deadline and I really want to publish this with RLlib. If I cannot fix this, I will have to switch frameworks, which I don’t want to do.
Let me know if there is any more info I need to provide,
Thank you very much for your time!!!