Training on multiple environments

Hi,
I was wondering if anybody has a suggestion/comment on the following problem:
I am trying to train an agent in multiple environments at the same time. In other words, at the end of each episode I need to switch between the environments.
Except for some environmental parameters, the environments are similar, so I am going to switch the environmental parameters when the agent "resets" the environment. This is fine when there is no parallel computing, but now that I am using RLlib, I do not know whether it causes an issue, e.g. one process looking at the old environment while another process looks at the new one.
Any suggestions are highly appreciated.
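For context, here is a rough sketch of what I mean by switching parameters at reset (the class name, parameter names, and values are just placeholders, not my real setup):

import gym
from gym import spaces

class ParamSwitchingEnv(gym.Env):
    """Illustrative only: an env that switches its own parameters on reset."""

    # Hypothetical parameter sets; the real values depend on the application.
    PARAM_SETS = [{"length": 1.0}, {"length": 2.0}]

    def __init__(self, env_config=None):
        self.action_space = spaces.Discrete(2)
        self.observation_space = spaces.Discrete(2)
        self._task = 0
        self.params = self.PARAM_SETS[self._task]

    def reset(self):
        # Each worker-local copy of the env switches its own parameters,
        # so no state would be shared between parallel rollout workers.
        self._task = (self._task + 1) % len(self.PARAM_SETS)
        self.params = self.PARAM_SETS[self._task]
        return 0

    def step(self, action):
        # Dummy dynamics; a real env would use self.params here.
        return 0, 0.0, True, {}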

Hey @Erfan_Asaadi , great question.
You could use one of our callbacks for this task. However, for this to work, you should design your environment class such that it can represent all the different tasks you would like to simulate within a single gym.Env (basically, have a switch_task() method inside your env):

from typing import Dict, Union

from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.episode_v2 import EpisodeV2
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.typing import PolicyID


class MyCallbacks(DefaultCallbacks):
    def on_episode_created(
        self,
        *,
        worker: RolloutWorker,
        base_env: BaseEnv,
        policies: Dict[PolicyID, Policy],
        env_index: int,
        episode: Union[Episode, EpisodeV2],
        **kwargs,
    ):
        # Get the respective sub-environment that is about to be reset.
        sub_environment_to_change = base_env.get_sub_environments()[env_index]
        # Switch its task.
        sub_environment_to_change.switch_task(task=...)


# Then, in your AlgorithmConfig object:
config = PPOConfig().callbacks(MyCallbacks)

Make sure your gym.Env subclass implements the switch_task method, taking whatever arguments are needed to switch between tasks.

import gym
from gym import spaces

import ray

ray.init(ignore_reinit_error=True, log_to_driver=False)


class ChainEnv(gym.Env):
    def __init__(self, env_config=None):
        env_config = env_config or {}
        print(env_config)
        self.env_config = env_config  # keep the config around for switch_task()
        self.n = env_config.get("n", 6)
        self.reward_ratio = 0.1
        self.small_reward = env_config.get("small", 2 * self.reward_ratio)  # payout for 'backwards' action
        self.large_reward = env_config.get("large", 10 * self.reward_ratio)  # payout at end of chain for 'forwards' action
        self._horizon = self.n
        self._counter = 0  # for terminating the episode
        self._setup_spaces()
        self.cnt = 0
        self.task_num = 0
        self.state = 0

    def _setup_spaces(self):
        self.action_space = spaces.Discrete(2)
        # Observation is the current chain position (0 .. n-1).
        self.observation_space = spaces.Discrete(self.n)

    def step(self, action):
        assert self.action_space.contains(action)
        if action == 1:  # 'backwards': go back to the beginning, get small reward
            reward = self.small_reward
            self.state = 0
        elif self.state < self.n - 1:  # 'forwards': go up along the chain
            reward = 0
            self.state += 1
        else:  # 'forwards': stay at the end of the chain, collect large reward
            reward = self.large_reward
        self._counter += 1
        done = self._counter >= self._horizon
        return self.state, reward, done, {}

    def reset(self):
        self.cnt += 1
        self.state = 0
        self._counter = 0
        return self.state

    def switch_task(self):
        # Alternate reward_ratio between 0.1 and 1.0 every other call.
        if self.task_num % 2 == 0:
            self.reward_ratio = 0.1
        else:
            self.reward_ratio = 1.0
        self.task_num += 1
        self.small_reward = self.env_config.get("small", 2 * self.reward_ratio)  # payout for 'backwards' action
        self.large_reward = self.env_config.get("large", 10 * self.reward_ratio)  # payout at end of chain for 'forwards' action
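As a quick standalone sanity check (no RLlib involved), this hypothetical snippet exercises switch_task directly; note that with the alternating logic above, the very first call leaves reward_ratio at 0.1 and the second call raises it to 1.0:

env = ChainEnv({"n": 6})
env.reset()
print(env.small_reward, env.large_reward)  # 0.2 1.0 (reward_ratio = 0.1)
env.switch_task()                          # task_num 0 -> reward_ratio stays 0.1
env.switch_task()                          # task_num 1 -> reward_ratio becomes 1.0
print(env.small_reward, env.large_reward)  # 2.0 10.0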

from ray.rllib.algorithms.dqn import DQNConfig
from ray.rllib.algorithms.callbacks import DefaultCallbacks


class MyCallbacks(DefaultCallbacks):
    def on_episode_created(
        self, *, worker, base_env, policies, env_index, episode, **kwargs
    ):
        # Get the respective sub-environment that is about to be reset.
        sub_environment_to_change = base_env.get_sub_environments()[env_index]
        # Switch its task.
        sub_environment_to_change.switch_task()

config = DQNConfig()
config = config.rollouts(num_rollout_workers=4).callbacks(MyCallbacks)
simple_trainer = config.environment(env=ChainEnv).build()
for _ in range(1):
    A = simple_trainer.train()
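
Regarding the original concern about parallelism: the callback runs inside each rollout worker and only touches that worker's own sub-environment copies, so every worker switches tasks on its local envs and no worker ever sees another worker's (old or new) environment parameters.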