@kapibarek Thanks for posting.

The observation space above is a Discrete(3) and therefore contains int observations, but your env returns a list for the observations.
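For instance (a quick check, independent of your env):

```python
from gymnasium import spaces

obs_space = spaces.Discrete(3)
print(obs_space.sample())       # an int in {0, 1, 2}
print(obs_space.contains(1))    # True
print(obs_space.contains([1]))  # False -- a list is not a valid observation
```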
Furthermore, your environment does not use the gymnasium API, i.e. it still uses a single done flag instead of terminated and truncated (see Handling Time Limits - Gymnasium Documentation).
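Schematically, the step() return value changes like this:

```python
# Old gym API: a single boolean done flag
# obs, reward, done, info = env.step(action)

# Gymnasium API: done is split into two flags
# obs, reward, terminated, truncated, info = env.step(action)
# terminated: the episode reached a terminal state
# truncated:  the episode was cut short, e.g. by a time limit
```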
The below code runs for me:

```python
import gymnasium as gym
from gymnasium import spaces
import numpy as np
OBS_START = [0]
OBS_GROWL_LEFT = [1]
OBS_GROWL_RIGHT = [2]
OBS_MAP = {
    OBS_START[0]: 'START',
    OBS_GROWL_LEFT[0]: 'GROWL_LEFT',
    OBS_GROWL_RIGHT[0]: 'GROWL_RIGHT',
}
ACTION_NONE = -1
ACTION_OPEN_LEFT = 0
ACTION_OPEN_RIGHT = 1
ACTION_LISTEN = 2
ACTION_MAP = {
    ACTION_OPEN_LEFT: 'OPEN_LEFT',
    ACTION_OPEN_RIGHT: 'OPEN_RIGHT',
    ACTION_LISTEN: 'LISTEN',
    ACTION_NONE: 'NONE',
}
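
# The classic Tiger problem: a tiger sits behind one of two doors. LISTEN
# returns a noisy growl from the tiger's side; opening a door yields either
# the gold reward or the tiger penalty.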
class TigerEnv(gym.Env):
    metadata = {
        'render.modes': ['human'],
        'render_modes': ['human'],
    }

    def __init__(self, reward_tiger=-100, reward_gold=10, reward_listen=-1,
                 obs_accuracy=.85, max_steps_per_episode=100):
        self.reward_tiger = reward_tiger
        self.reward_gold = reward_gold
        self.reward_listen = reward_listen
        self.obs_accuracy = obs_accuracy
        self.max_steps_per_episode = max_steps_per_episode
        self.curr_episode = -1  # Set to -1 b/c reset() adds 1 to episode
        self.action_episode_memory = []
        self.observation_episode_memory = []
        self.reward_episode_memory = []
        self.curr_step = 0
        self.reset()
        # LISTEN, OPEN_LEFT, OPEN_RIGHT
        self.action_space = spaces.Discrete(3)
        # GROWL_LEFT, GROWL_RIGHT, START
        self.observation_space = spaces.Discrete(3)
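
    # Gymnasium-style step: returns the 5-tuple
    # (obs, reward, terminated, truncated, info).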
    def step(self, action):
        terminated = False
        truncated = self.curr_step >= self.max_steps_per_episode
        if truncated or terminated:
            raise RuntimeError("Episode is done")
        self.curr_step += 1
        should_reset = self.take_action(action)
        truncated = self.curr_step >= self.max_steps_per_episode
        reward = self.get_reward()
        self.action_episode_memory[self.curr_episode].append(action)
        obs = self.get_obs()
        self.observation_episode_memory[self.curr_episode].append(obs)
        self.reward_episode_memory[self.curr_episode].append(reward)
        if should_reset:
            self.step_reset()
        infos = {}
        return obs, reward, terminated, truncated, infos

    def reset(self, *, seed=None, options=None):
        if seed is not None:
            np.random.seed(seed)
        self.curr_step = 0
        self.curr_episode += 1
        self.left_door_open = False
        self.right_door_open = False
        self.tiger_left = np.random.randint(0, 2)
        self.tiger_right = 1 - self.tiger_left
        initial_obs = OBS_START[0]
        self.action_episode_memory.append([ACTION_NONE])
        self.observation_episode_memory.append([initial_obs])
        self.reward_episode_memory.append([0])
        infos = {}
        return initial_obs, infos

    def render(self, mode='human'):
        return

    def close(self):
        pass

    def translate_obs(self, obs):
        # Observations are plain ints, so they can be looked up directly.
        if obs not in OBS_MAP:
            raise ValueError('Invalid observation: {}'.format(obs))
        return OBS_MAP[obs]

    def translate_action(self, action):
        return ACTION_MAP[action]

    def take_action(self, action):
        should_reset = False
        if action == ACTION_OPEN_LEFT:
            self.left_door_open = True
            should_reset = True
        elif action == ACTION_OPEN_RIGHT:
            self.right_door_open = True
            should_reset = True
        elif action == ACTION_LISTEN:
            pass
        else:
            raise ValueError('Invalid action: {}'.format(action))
        return should_reset

    def get_reward(self):
        if not (self.left_door_open or self.right_door_open):
            return self.reward_listen
        if self.left_door_open:
            if self.tiger_left:
                return self.reward_tiger
            else:
                return self.reward_gold
        if self.right_door_open:
            if self.tiger_right:
                return self.reward_tiger
            else:
                return self.reward_gold
        raise ValueError('Unreachable state reached.')

    def get_obs(self):
        last_action = self.action_episode_memory[self.curr_episode][-1]
        if last_action != ACTION_LISTEN:
            # Return accurate observation, but this won't be informative, since
            # the tiger will be reset afterwards.
            if self.tiger_left:
                return OBS_GROWL_LEFT[0]
            else:
                return OBS_GROWL_RIGHT[0]
        # Return accurate observation
        if np.random.rand() < self.obs_accuracy:
            if self.tiger_left:
                return OBS_GROWL_LEFT[0]
            else:
                return OBS_GROWL_RIGHT[0]
        # Return inaccurate observation
        else:
            if self.tiger_left:
                return OBS_GROWL_RIGHT[0]
            else:
                return OBS_GROWL_LEFT[0]
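
    # Opening a door ends the round: close both doors and re-randomize the
    # tiger's position without starting a new episode.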
    def step_reset(self):
        # Make sure doors are closed
        self.left_door_open = False
        self.right_door_open = False
        self.tiger_left = np.random.randint(0, 2)
        self.tiger_right = 1 - self.tiger_left

from ray.rllib.algorithms.ppo import PPOConfig
from ray import tune

# Make the custom env known to RLlib under the id 'Tiger-v0'.
tune.register_env('Tiger-v0', lambda config: TigerEnv())

config = PPOConfig()
config = config.training(gamma=0.99, lr=0.01, kl_coeff=0.3, train_batch_size=128)
config = config.resources(num_gpus=0)
config = config.rollouts(num_rollout_workers=1)

import ray
ray.init(local_mode=True)

algo = config.build(env='Tiger-v0')
result = algo.train()
```