How severely does this issue affect your experience of using Ray?
- High: It blocks me from completing my task.
I am quite new to reinforcement learning and can't figure this out: I am unable to update the environment configuration with new batch data while training with PPO.
I am using a custom-defined Gym environment and want to train it with PPO on external data that I load through a torch DataLoader.
I am using Python 3.11 and Ray 2.40.0. Following is the relevant code:
```python
import ray
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.registry import register_env
from torch.utils.data import DataLoader

train_dataset = MultimodalDataset(
    csv_file=config.TRAIN_CSV_PATH, max_images=config.MAX_IMAGES_RL
)
train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True)

# Define PPO configuration
ppo_config = (
    PPOConfig()
    .training(gamma=0.9, lr=0.01)
    .environment(env="MultimodalSummarizationEnv", env_config=default_env_config)
    .framework("torch")
    .resources(num_gpus=0, num_cpus_per_worker=1)
)

# Create PPO trainer
trainer = ppo_config.build()

# Function to update worker environments
def update_env_config_and_reset(worker, new_env_config):
    worker.foreach_env(lambda env: env.reset(env_config=new_env_config))

# Training loop
for batch_idx, batch in enumerate(train_loader):
    # Prepare batch-specific env_config
    new_env_config = {
        # new data for the batch_idx
    }
    # Update and reset environments for all workers
    trainer.workers.foreach_worker(
        lambda worker: update_env_config_and_reset(worker, new_env_config)
    )
    # Train PPO
    result = trainer.train()

ray.shutdown()
```
However, when running the code I get the following error on the foreach_worker call:
```
AttributeError: 'function' object has no attribute 'foreach_worker'
```
Please help me identify where I am getting this wrong.
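For context, the environment is registered with RLlib before the config is built, roughly along these lines (a simplified sketch; the real `MultimodalSummarizationEnv` takes more construction arguments than shown here):

```python
from ray.tune.registry import register_env

def make_summarization_env(env_config):
    # env_config is where I intend to inject the per-batch data
    # (construction details simplified for this post)
    return MultimodalSummarizationEnv(env_config)

register_env("MultimodalSummarizationEnv", make_summarization_env)
```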
EDIT: Following is a reproducible MWE:
```python
import gymnasium as gym
import numpy as np
import torch
from gymnasium import spaces
from torch.utils.data import DataLoader

import ray
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.registry import register_env


class CustomENv(gym.Env):
    def __init__(self, device, max_steps=5):
        super().__init__()
        self.device = device
        self.max_steps = max_steps
        self.current_step = 0
        self.done = False
        # Placeholder observation until real batch data is pushed in via reset()
        self.data = np.zeros((1,), dtype=np.float32)
        # Define action and observation space
        self.action_space = spaces.Discrete(2)  # Select or discard
        self.observation_space = spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(1,),  # Simplified observation
            dtype=np.float32,
        )

    def reset(self, *, seed=None, options=None, new_data=None):
        # Gymnasium-compliant signature; new_data is my extra hook for batch data
        super().reset(seed=seed)
        self.current_step = 0
        self.done = False
        if new_data is not None:
            self.data = new_data
        return self.data, {}

    def step(self, action):
        reward = 0
        if not self.done:
            reward = 1 if action == 1 else 0  # Dummy reward logic
            self.current_step += 1
            if self.current_step >= self.max_steps:
                self.done = True
        next_state = self.data if not self.done else np.zeros((1,), dtype=np.float32)
        return next_state, reward, self.done, False, {}


# Function to create RLlib environment
def make_rllib_env(env_config):
    return CustomENv(device="cpu")

register_env("CustomENv", make_rllib_env)


# Dummy DataLoader for testing
class DummyDataset(torch.utils.data.Dataset):
    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return np.array([idx], dtype=np.float32)

train_loader = DataLoader(DummyDataset(), batch_size=1, shuffle=True)

# Initialize Ray
ray.init(ignore_reinit_error=True)

# Define PPO configuration
ppo_config = (
    PPOConfig()
    .training(gamma=0.9, lr=0.01)
    .environment(env="CustomENv")
    .framework("torch")
    .resources(num_gpus=0, num_cpus_per_worker=1)
)

# Create PPO trainer
trainer = ppo_config.build()

# Training loop
for batch_idx, batch in enumerate(train_loader):
    print(f"Batch {batch_idx + 1}: {batch.numpy()}")
    # Prepare batch-specific env_config
    new_env_config = {"data": batch.numpy()}
    # Update and reset environments for all workers
    trainer.workers.foreach_worker(
        lambda worker: worker.foreach_env(
            lambda env: env.reset(new_data=new_env_config["data"])
        )
    )
    result = trainer.train()

# Shutdown Ray
ray.shutdown()
```
This is the error output:
```
line 88, in <module>
    trainer.workers.foreach_worker(
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'function' object has no attribute 'foreach_worker'
```
The complete console output can be seen here, though I doubt it will be of much use: Google Doc