Hi @arturn, would you please look into this? 
I also have a minimal runnable script in case it is needed:
# %% Imports
import os
import gymnasium as gym
import numpy as np
from abc import ABC
import torch
from torch import nn
from ray import tune, air
from ray.rllib.algorithms.algorithm import Algorithm
from ray.tune.registry import get_trainable_cls
from ray.tune.logger import pretty_print
from ray.rllib.utils import check_env
# %% Env class
class SimpleEnv(gym.Env):
    """Minimal random-observation env following the gymnasium API."""

    def __init__(self, env_config={'env_name': 'simple_env'}):
        self.env_name = env_config['env_name']
        self.n_actions = 3
        self.n_states = 5
        self.action_space = gym.spaces.Discrete(self.n_actions)
        self.observation_space = gym.spaces.Box(0.0, 1.0, shape=(self.n_states,), dtype=float)

    def reset(self, *, seed=None, options=None):
        observation = np.random.rand(1, self.n_states)[0]
        self.timestep = 0
        return observation, {}

    def _update_obs(self, action):
        observation = np.random.rand(1, self.n_states)[0]
        return observation

    def _execute_action(self, action):
        next_observation = self._update_obs(action)
        done = self.timestep > 3  # episode terminates on the 4th step
        reward = 1 if done else 0
        return next_observation, reward, done

    def _get_info(self):
        random_info_dict = {'random_info': 1}  # np.random.randn()
        info = {'agent_1': random_info_dict}
        return info

    def step(self, action):
        self.timestep += 1
        observation, reward, done = self._execute_action(action)
        truncated = done
        # info is only populated on the final step of an episode
        info = self._get_info() if done else {}
        return observation, reward, done, truncated, info

    def seed(self, seed: int = None):
        self.np_random, seed = gym.utils.seeding.np_random(seed)
        return [seed]
# %% Main
if __name__ == "__main__":
    env_name = 'simple_env'
    agent_name = 'DQN'  # alpha_zero
    learner_name = 'trainer'  # trainer | tuner | random
    num_iters = 1
    num_rollout_workers = 1
    env_config = {'env_name': env_name}
    save_env_data_flag = True
    save_agent_flag = True
    load_agent_flag = False

    tmp_current_dir = os.getcwd()
    tmp_storage = os.path.join(tmp_current_dir, 'storage')
    tmp_env_data_dir = os.path.join(tmp_storage, 'env_dict')
    tmp_model_dir = os.path.join(tmp_storage, 'model')
    if learner_name == 'random':
        env = SimpleEnv(env_config=env_config)
        check_env(env)
        obs, _ = env.reset()
        while True:
            action = env.action_space.sample()
            obs, rew, done, truncated, info = env.step(action)
            if done:
                print('Done!')
                print(f'info: {info}')
                break
    else:
        algo_cls = get_trainable_cls(agent_name)
        param_space = (
            algo_cls
            .get_default_config()
            .environment(SimpleEnv, env_config=env_config)
            .framework('torch')
            .rollouts(num_rollout_workers=num_rollout_workers)
            .resources(num_gpus=1)
            .training(model={"fcnet_hiddens": [64, 64]})
        )
        if save_env_data_flag:
            # write sampled experiences to disk via the offline output options
            param_space.output = tmp_env_data_dir
            param_space.output_max_file_size = 5000000
        if learner_name == 'trainer':
            algo = param_space.build()
            algo.output = tmp_env_data_dir
            # if load_agent_flag:
            #     self.algo.restore(self.agent_config['chkpt_path_'])
            #     print("In trainer: The model loaded!")
            # checkpoint_dir = ''
            for n in range(num_iters):
                print(f"---------- in trainer: episode: {n}")
                result = algo.train()
                print(pretty_print(result))
                # checkpoint_dir = algo.save(tmp_scenario_dir)
                # print("In trainer: The checkpoints saved!")
            algo.stop()
        elif learner_name == 'tuner':
            stop = {"training_iteration": num_iters}
            run_config = air.RunConfig(
                stop=stop,
                local_dir=tmp_storage,
                checkpoint_config=air.CheckpointConfig(
                    checkpoint_at_end=True,
                    checkpoint_frequency=1,
                ),
            )
            tuner = tune.Tuner(
                agent_name,
                run_config=run_config,
                param_space=param_space,
            )
            # if load_agent_flag:
            #     tuner.restore(self.agent_config['chkpt_path'])
            results = tuner.fit()
            checkpoint_dir = results.get_best_result(
                metric="episode_reward_mean",
                mode="max",
            ).checkpoint._local_path
            print(f"checkpoint_dir: {checkpoint_dir}")
Thanks!