Here is a minimal example:
import gym
import numpy
import os
import ray
import ray.tune
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.env.multi_agent_env import MultiAgentEnv

class Sim(MultiAgentEnv):
    def __init__(self, config):
        # Filled in by Callbacks.on_episode_start(); stays None until then.
        self.policy = None
        self.observation_space = gym.spaces.Box(
            numpy.float32("-inf"), numpy.float32("inf"),
            shape=(2,), dtype=numpy.float32)
        self.action_space = gym.spaces.MultiDiscrete([3] * 2)
    def reset(self):
        observations = {
            "robot0": numpy.array([0, 0], dtype=numpy.float32),
            "robot1": numpy.array([0, 1], dtype=numpy.float32),
            "robot2": numpy.array([1, 0], dtype=numpy.float32),
            "robot3": numpy.array([1, 1], dtype=numpy.float32),
        }
        # on_episode_start() seems to be called after reset(), so the policy
        # is still None on the very first pass (a callback-side workaround is
        # sketched after the script).
        if self.policy is not None:
            obs_batch = numpy.stack([observations["robot0"], observations["robot1"],
                                     observations["robot2"], observations["robot3"]])
            self.policy.compute_actions_from_input_dict({"obs": obs_batch})
            print(self.policy.model.value_function())
        return observations
    def step(self, action):
        observations = {
            "robot0": numpy.array([0, 0], dtype=numpy.float32),
            "robot1": numpy.array([0, 1], dtype=numpy.float32),
            "robot2": numpy.array([1, 0], dtype=numpy.float32),
            "robot3": numpy.array([1, 1], dtype=numpy.float32),
        }
        rewards = {"robot0": 0, "robot1": 0, "robot2": 0, "robot3": 0}
        # End every episode after a single step.
        dones = {"__all__": True}
        return observations, rewards, dones, {}

class Callbacks(DefaultCallbacks):
    def on_episode_start(self, *, worker, base_env, policies, episode, **kwargs):
        # Hand the policy to the sub-environment that just started an episode.
        env = base_env.get_unwrapped()[episode.env_id]
        if env.policy is None:
            env.policy = policies["default_policy"]

SCRIPT_DIRECTORY = os.path.dirname(os.path.realpath(__file__))
ray.init()
ray.tune.run(
    "APPO",
    config={
        "env": Sim,
        "env_config": {},
        "callbacks": Callbacks,
        "model": {"fcnet_hiddens": [256, 256], "fcnet_activation": "relu"},
        "num_workers": 0,
        "num_envs_per_worker": 1,
        "num_gpus": 0.0,
    },
    local_dir=SCRIPT_DIRECTORY,
    name="Output",
)
Running the original script prints output like:
(pid=146255) Tensor("Reshape:0", shape=(?,), dtype=float32)
(pid=146255) Tensor("Reshape_1:0", shape=(?,), dtype=float32)
(pid=146255) Tensor("Reshape_2:0", shape=(?,), dtype=float32)
(pid=146255) Tensor("Reshape_3:0", shape=(?,), dtype=float32)
(pid=146255) Tensor("Reshape_4:0", shape=(?,), dtype=float32)