How severe does this issue affect your experience of using Ray?
- High: It blocks me to complete my task.
Hi folks,
I am using Rllib to train an agent for my custom env whose observation_space needs to be a graph type data, similar to gData type from torch_geometric library.
I implemented a dummy env and a custom model, but it does not work. I wonder if anyone can have a look and let me how can I change it to make it work. Here is a runnable code:
# %%
import os
import numpy as np
import gymnasium as gym
from collections import OrderedDict
import torch
from torch import nn
import ray
from ray import tune, air
from ray.rllib.utils import check_env
from ray.tune.logger import pretty_print
from ray.rllib.models import ModelCatalog
from ray.tune.registry import get_trainable_cls
from ray.rllib.utils.spaces.repeated import Repeated
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool
from torch_geometric.data import Data as gData
from torch_geometric.loader import DataLoader as gDataLoader
from callbacks import MyCallbacks
#%%
class SimpleGnnEnv(gym.Env):
def __init__(self, env_config={'env_name': 'SimpleGnnEnv'}):
self.env_name = env_config['env_name']
self.n_actions = 3
self.n_states = 5
self.action_space = gym.spaces.Discrete(self.n_actions)
self.num_nodes = 5
self.num_edges = 3
self.max_edges = 9
self.node_feature_dim = 7
edge = gym.spaces.Box(low=0,
high=self.num_nodes-1,
shape=(2,),
dtype=np.int16)
self.observation_space = gym.spaces.Dict({
'gnn_nodes': gym.spaces.Box(low=-5,
high=5,
shape=(self.num_nodes, self.node_feature_dim),
dtype=np.float16),
'gnn_edges': Repeated(edge, self.max_edges),
})
def reset(self, *, seed=None, options=None):
observation = OrderedDict({
'gnn_nodes': np.random.rand(self.num_nodes,self.node_feature_dim),
'gnn_edges': np.random.choice(self.num_nodes, (self.num_edges, 2))
})
observation = self.observation_space.sample()
self.timestep = 0
return observation, {}
def _update_obs(self, action):
observation = OrderedDict({
'gnn_nodes': np.random.rand(self.num_nodes,self.node_feature_dim),
'gnn_edges': np.random.choice(self.num_nodes, (self.num_edges, 2))
})
observation = self.observation_space.sample()
return observation
def _execute_action(self, action):
next_observation = self._update_obs(action)
done = False if self.timestep <=3 else True
reward = 1 if done else 0
return next_observation, reward, done
def _get_info(self, done):
random_info_dict = {'random_info': done} #np.random.randn()
info = {'agent_11': random_info_dict, 'timestep': self.timestep, 'done': done}
return info
def step(self, action):
self.timestep += 1
self.observation, self.reward, self.done = self._execute_action(action)
self.truncated = self.done
self.info = self._get_info(self.done) if self.done else {}
return self.observation, self.reward, self.done, self.truncated, self.info
def seed(self, seed: int = None):
self.np_random, seed = gym.utils.seeding.np_random(seed)
return [seed]
#%%
class TinyGnnNet(TorchModelV2, nn.Module):
def __init__(self, obs_space, action_space, num_outputs, model_config, name):
TorchModelV2.__init__(
self, obs_space, action_space, num_outputs, model_config, name
)
nn.Module.__init__(self)
self.obs_space = obs_space
self.action_space = action_space
self.num_outputs = num_outputs
self.model_config = model_config
self.name = name
self.orig_space = getattr(obs_space, "original_space", obs_space)
self.p = 0.0
self.hidden_dim = 256
self.actor_hidden_dim = 256
self.critic_hidden_dim = 256
self.node_feature_dim = 7
self.feature_net_inp_layer = GCNConv(self.node_feature_dim, 2*self.hidden_dim)
self.feature_net_1 = nn.Sequential(
nn.ReLU(),
GCNConv(2*self.hidden_dim, self.hidden_dim),
nn.ReLU(),
nn.Linear(self.hidden_dim, self.hidden_dim),
)
self.feature_net_out_layer = nn.Sequential(
nn.Linear(self.hidden_dim, self.actor_hidden_dim),
nn.ReLU()
)
self._actor_head = nn.Sequential(
nn.Linear(self.actor_hidden_dim, self.num_outputs))
self._critic_head = nn.Sequential(
nn.Linear(self.critic_hidden_dim, 1))
def forward(self, input_dict, state, seq_lens):
obs = input_dict["obs"]
gdata = self._convert_graph_data_dict_to_gdata_tensor(obs)
x, edge_index = gdata.x, gdata.edge_index
x = self.feature_net_inp_layer(x, edge_index)
x = self.feature_net_1(x, edge_index)
x = global_mean_pool(x, gdata.batch) # to get graph embeddings
x = self.feature_net_out_layer(x)
actor_latent = self._actor_head(x)
value = self._critic_head(x)
self._value = value.reshape(-1)
return actor_latent, state
def value_function(self):
return self._value#.flatten
def _convert_graph_data_dict_to_gdata_tensor(self, real_obs):
gnn_nodes = real_obs['gnn_nodes']
gnn_edges = real_obs['gnn_edges']
print(f"gnn_nodes: {gnn_nodes}")
print(f"gnn_edges: {gnn_edges}")
edge_index = torch.tensor(gnn_edges, dtype=torch.long)
edge_index.t().contiguous()
x = torch.tensor(gnn_nodes, dtype=torch.float)
y = torch.tensor([], dtype=torch.float)
gdata = gData(edge_index=edge_index, x=x, y=y)
return gdata
# %% Main
if __name__ == "__main__":
env_name = 'SimpleGnnEnv'
agent_name = 'PPO' # alpha_zero
learner_name = 'trainer' # trainer tunner random
num_iters = 1
num_rollout_workers = 1
env_config = {'env_name': env_name}
store_falg = False
save_env_data_flag = False
save_agent_flag = False
load_agent_flag = False
num_gpus = 0
local_mode_flag = True
current_dir = os.getcwd()
storage = os.path.join(current_dir, 'storage')
env_data_dir = os.path.join(storage, 'env_dict')
save_chkpt_to_dir = os.path.join(storage, 'models')
load_chkpt_from_dir = os.path.join(storage, agent_name)
if learner_name == 'random':
env = SimpleGnnEnv(env_config=env_config)
s, _ = env.reset()
a = env.action_space.sample()
s, _, _, _, _ = env.step(a)
print('Checking the env ...')
check_env(env)
while True:
action = env.action_space.sample()
obs, rew, done, truncated, info = env.step(action)
if done:
print('Done!')
print(f'info: {info}')
break
else:
print("not done yet")
print(f'info: {info}')
else:
ModelCatalog.register_custom_model('TinyGnnNet', TinyGnnNet)
ray.init(
ignore_reinit_error=True,
log_to_driver=False,
local_mode=local_mode_flag,
object_store_memory=10**8,
)
algo_cls = get_trainable_cls(agent_name)
param_space = (
algo_cls
.get_default_config()
.environment(SimpleGnnEnv, env_config=env_config)
.framework('torch')
.rollouts(num_rollout_workers=num_rollout_workers)
.resources(num_gpus=num_gpus)
.training(model={"custom_model": 'TinyGnnNet',
"vf_share_layers": True},
# train_batch_size=4,
# sgd_minibatch_size=2,
)
)
if save_env_data_flag:
param_space.callbacks_class = MyCallbacks
param_space.output = env_data_dir
param_space.output_max_file_size = 5000000
param_space.output_config = {
"format": "json", # json or parquet
# Directory to write data files.
"path": env_data_dir,
# Break samples into multiple files, each containing about this many records.
"max_num_samples_per_file": 100000,
}
if learner_name == 'trainer':
algo = param_space.build()
if save_env_data_flag:
algo.output = env_data_dir
if load_agent_flag:
algo.restore(load_chkpt_from_dir)
print("In trainer: The model loaded!")
for n in range(num_iters):
print(f"---------- in trainer: episode: {n}")
result = algo.train()
print(pretty_print(result))
if save_agent_flag:
checkpoint_dir = algo.save(save_chkpt_to_dir)
algo.stop()
elif learner_name == 'tunner':
stop = {"training_iteration": num_iters}
run_config = air.RunConfig(
stop=stop,
local_dir=storage,
checkpoint_config=air.CheckpointConfig(checkpoint_at_end=True,
checkpoint_frequency=1),
)
tuner = tune.Tuner(
agent_name,
run_config=run_config,
param_space=param_space,
)
if load_agent_flag:
tuner.restore(load_chkpt_from_dir, trainable="your_trainable")
results = tuner.fit()
ray.shutdown()