Hi, I am new to RLlib. While it is a fantastic library, I am struggling to find the right way to force the learner to use float32 as its dtype. Alternatively, could someone give me a clue about how to check the actual torch dtype the learner is using?
I am using PyTorch, and the training process looks like this:
from ray.rllib.algorithms import ppo
from ray.tune.registry import register_env
from ray.tune.logger import pretty_print
# NOTE(review): each algo.train() call is ONE training iteration, and each
# iteration collects train_batch_size (=1000) env timesteps — so iterating
# range(total_timesteps) runs ~1000x more env steps than the variable name
# implies. Kept as-is to preserve behavior; consider instead looping until
# res["timesteps_total"] >= total_timesteps.
total_timesteps = 100000

# Build a PPO config: 15 CPU-only rollout workers, one GPU learner, torch
# framework. vf_share_layers=True makes the policy and value heads share a
# trunk (hence the small vf_loss_coeff to balance the shared gradient).
config = (
    ppo.PPOConfig()
    .environment("myenv")
    .resources(num_gpus=1)
    .env_runners(
        num_env_runners=15,
        num_envs_per_env_runner=1,
        num_cpus_per_env_runner=1,
        num_gpus_per_env_runner=0,
        observation_filter="NoFilter",
        rollout_fragment_length="auto",
    )
    .learners(num_gpus_per_learner=1)
    .framework("torch")
    .training(
        entropy_coeff=0.01,
        vf_loss_coeff=0.1,
        clip_param=0.1,
        vf_clip_param=10.0,
        num_sgd_iter=10,
        kl_coeff=0.5,
        lr=0.0001,
        grad_clip=100,
        sgd_minibatch_size=100,
        train_batch_size=1000,
        model={"vf_share_layers": True},
    )
)

algo = config.build()
for i in range(total_timesteps):
    res = algo.train()
    # print(pretty_print(res))
    # NOTE(review): res["timesteps_total"] is the legacy top-level key;
    # on the new API stack it may live under a different path — verify
    # against the result dict of your Ray version.
    print(f"training_iteration: {i}\ntimesteps_total: {res['timesteps_total']}\n\
episode_len_mean: {res['env_runners']['episode_len_mean']:.2f}\n\
episode_reward_mean: {res['env_runners']['episode_reward_mean']:.2f}\n")
    # Checkpoint 10 times over the run, at every 10%-of-loop boundary.
    if (i + 1) % (total_timesteps // 10) == 0:
        checkpoint_dir = algo.save().checkpoint.path
        print(f"Checkpoint saved in directory {checkpoint_dir}")