Hello everyone,
Is my toy example below a correct and workable way to use RLlib's ExternalEnv API?
test.py:
from typing import Dict
import ray
from ray.rllib.agents.ppo.ppo import PPOTrainer
from ray.tune.logger import pretty_print
from ray.tune.registry import register_env
from gym.spaces import Discrete, Box
from toy_external_env import ToyExternalEnv
def env_creator(config: Dict):
    """Create the toy external environment for RLlib's env registry.

    Args:
        config: Env config handed over by RLlib when it builds the env;
            this creator does not read it.

    Returns:
        A ``ToyExternalEnv`` with 2 discrete actions and a 4-dimensional
        ``Box`` observation space bounded to [-10, 10].
    """
    # Shapes mirror CartPole-v0, the simulator ToyExternalEnv wraps.
    # NOTE(review): CartPole-v0 itself declares unbounded velocity
    # components; confirm [-10, 10] is wide enough for your simulator.
    obs_space = Box(-10, 10, (4,))
    act_space = Discrete(2)
    return ToyExternalEnv(action_space=act_space, observation_space=obs_space)
if __name__ == "__main__":
    ray.init()
    # Register the creator so the trainer can resolve env="TEE" by name.
    register_env("TEE", env_creator)
    trainer = PPOTrainer(config={"framework": "tf2"}, env="TEE")
    try:
        results = trainer.train()
        # pretty_print() only formats the result dict into a string; it does
        # not write anything itself, so the string must be printed explicitly.
        print(pretty_print(results))
    finally:
        # Release the trainer's workers and shut Ray down even if training
        # raised, so the script does not leave processes behind.
        trainer.stop()
        ray.shutdown()
    print("Training end")
toy_external_env.py:
from typing import Any, Tuple
from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.utils.annotations import override
import gym
import numpy as np
class ToyExternalEnv(ExternalEnv):
    """Toy ``ExternalEnv`` that drives a local CartPole-v0 simulator.

    ``run()`` loops forever: it asks the policy for actions with
    ``get_action``, steps the simulator, reports rewards with
    ``log_returns`` and starts a new episode whenever one ends.
    """

    def __init__(self, action_space: gym.Space, observation_space: gym.Space,
                 max_concurrent: int = 100, render: bool = True):
        """
        Args:
            action_space: Action space the policy acts in; should match the
                simulator's action space.
            observation_space: Observation space the policy sees; should
                match the simulator's observations.
            max_concurrent: Forwarded to ``ExternalEnv`` as the maximum
                number of concurrent episodes.
            render: Whether to call ``simulator.render()`` after every reset
                and step. Pass ``False`` on headless machines / rollout
                workers without a display.
        """
        super().__init__(action_space, observation_space,
                         max_concurrent=max_concurrent)
        self._render = render
        # TODO: replace with the real discrete-event simulator (DES()).
        self.simulator = gym.make("CartPole-v0")

    def _cast_obs(self, obs):
        """Return ``obs`` converted to the declared ``Box`` dtype.

        gym simulators may emit float64 observations, while a ``Box``
        created without an explicit dtype is float32; RLlib's observation
        validation may reject such a mismatch. Non-Box observations are
        returned unchanged.
        """
        if isinstance(self.observation_space, gym.spaces.Box):
            return np.asarray(obs, dtype=self.observation_space.dtype)
        return obs

    def _maybe_render(self):
        """Render the simulator only when rendering was requested."""
        if self._render:
            self.simulator.render()

    def _reset_simulator(self):
        """Reset the simulator and return its (cast) initial observation."""
        obs = self._cast_obs(self.simulator.reset())
        self._maybe_render()
        return obs

    @override(ExternalEnv)
    def run(self):
        """Play simulator episodes forever, querying the policy each step."""
        obs = self._reset_simulator()
        eid = self.start_episode()
        while True:
            action = self.get_action(eid, obs)
            obs, reward, done, info = self.simulator.step(action)
            obs = self._cast_obs(obs)
            self._maybe_render()
            self.log_returns(eid, reward, info)
            if done:
                # Close the finished episode with its final observation,
                # then immediately begin a fresh one.
                self.end_episode(eid, obs)
                obs = self._reset_simulator()
                eid = self.start_episode()