Hi mannyv,
So the thing is, I am using a custom SAC-for-RNN implementation as well: I'm working with TensorFlow, but the RNNSAC example is only available in Torch, so I converted the code. I am aware that the attention_net cannot be run with the stock "SAC" provided by Ray, since that one is aimed at feed-forward architectures (correct me if I'm wrong).
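To double-check that assumption on my side I use a small snippet like the one below. This is just a hypothetical sketch (assuming Ray ~1.x): it builds whatever trainer is registered under the given name and prints the model class the policy ended up with, so I can see whether my custom model or a default feed-forward net was picked up.

from ray.rllib.agents.registry import get_trainer_class

# Build the registered trainer once and inspect the policy's model class.
trainer_cls = get_trainer_class("RNNSAC")
trainer = trainer_cls(config=config)  # `config` as in the run script further below
print(type(trainer.get_policy().model))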
Below is the custom attention_net.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Oct 18 16:15:10 2021
@author: gael
"""
from gym.spaces import Box, Discrete, MultiDiscrete
import numpy as np
import gym
from typing import Any, Dict, Optional, Type, Union
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.layers import GRUGate, RelativeMultiHeadAttention, SkipConnection
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.tf.recurrent_net import RecurrentNetwork
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf
import tensorflow as tf
from tensorflow import keras
from ray.rllib.utils.tf_ops import one_hot
from ray.rllib.utils.typing import ModelConfigDict, TensorType, List
class PositionwiseFeedforward(keras.layers.Layer if tf else object):
"""A 2x linear layer with ReLU activation in between described in [1].
Each timestep coming from the attention head will be passed through this
layer separately.
"""
def __init__(self,
out_dim: int,
hidden_dim: int,
output_activation: Optional[Any] = None,
**kwargs):
super().__init__(**kwargs)
self._hidden_layer = tf.keras.layers.Dense(
hidden_dim,
activation=tf.nn.relu,
)
self._output_layer = tf.keras.layers.Dense(
out_dim, activation=output_activation)
def call(self, inputs: TensorType, **kwargs) -> TensorType:
del kwargs
output = self._hidden_layer(inputs)
return self._output_layer(output)
class TransformerModel(RecurrentNetwork):
"""this is an implementation of GTrXL net by Deepmind"""
def __init__(self, observation_space: gym.spaces.Space,
action_space: gym.spaces.Space, num_outputs: int,
model_config: ModelConfigDict, name,
num_transformer_units: int=4, attention_dim:int=4,
num_heads: int=4, head_dim: int=32, position_wise_mlp_dim: int=32,
memory_inference:int =50, memory_training:int=50,
init_gru_gate_bias: float =2.0):
"""Initializes object.
Args:
num_transformer_units (int): The number of Transformer repeats to
use (denoted L in [2]).
attention_dim (int): The input and output dimensions of one
Transformer unit.
num_heads (int): The number of attention heads to use in parallel.
            head_dim (int): The dimension of a single(!) attention head within
                a multi-head attention unit.
position_wise_mlp_dim (int): The dimension of the hidden layer
within the position-wise MLP (after the multi-head attention
block within one Transformer unit). This is the size of the
first of the two layers within the PositionwiseFeedforward. The
second layer always has size=`attention_dim`.
"""
super().__init__(observation_space, action_space, num_outputs,
model_config, name)
        custom_model_config = model_config["custom_model_config"]
        # Read the architecture hyperparameters from custom_model_config so the
        # values passed in via the trainer config actually take effect.
        self.num_transformer_units = custom_model_config["num_transformer_units"]
        self.attention_dim = custom_model_config["attention_dim"]
        self.num_heads = custom_model_config["num_heads"]
        self.head_dim = custom_model_config["head_dim"]
        self.position_wise_mlp_dim = custom_model_config["position_wise_mlp_dim"]
        self.memory_inference = custom_model_config["memory_inference"]
        self.memory_training = custom_model_config["memory_training"]
        self.max_seq_len = model_config["max_seq_len"]
self.obs_dim = observation_space.shape[0]
input_layer = tf.keras.layers.Input(
shape=(None, self.obs_dim), name="inputs")
memory_ins = [
tf.keras.layers.Input(
shape=(None, self.attention_dim),
dtype=tf.float32,
name="memory_in_{}".format(i))
for i in range(self.num_transformer_units)
]
# Map observation dim to input/output transformer (attention) dim.
E_out = tf.keras.layers.Dense(self.attention_dim)(input_layer)
# Output, collected and concat'd to build the internal, tau-len
# Memory units used for additional contextual information.
memory_outs = [E_out]
# 2) Create L Transformer blocks according to [2].
for i in range(self.num_transformer_units):
# RelativeMultiHeadAttention part.
MHA_out = SkipConnection(
RelativeMultiHeadAttention(
out_dim=self.attention_dim,
                    num_heads=self.num_heads,
                    head_dim=self.head_dim,
input_layernorm=True,
output_activation=tf.nn.relu),
fan_in_layer=GRUGate(init_gru_gate_bias),
name="mha_{}".format(i + 1))(
E_out, memory=memory_ins[i])
# Position-wise MLP part.
E_out = SkipConnection(
tf.keras.Sequential(
(tf.keras.layers.LayerNormalization(axis=-1),
PositionwiseFeedforward(
out_dim=self.attention_dim,
                         hidden_dim=self.position_wise_mlp_dim,
output_activation=tf.nn.relu))),
fan_in_layer=GRUGate(init_gru_gate_bias),
name="pos_wise_mlp_{}".format(i + 1))(MHA_out)
# Output of position-wise MLP == E(l-1), which is concat'd
# to the current Mem block (M(l-1)) to yield E~(l-1), which is then
# used by the next transformer block.
memory_outs.append(E_out)
self._logits = None
self._value_out = None
self.trxl_model = tf.keras.Model(
inputs=[input_layer] + memory_ins,
outputs=[E_out] + memory_outs[:-1])
self.view_requirements = {
SampleBatch.OBS: ViewRequirement(space=observation_space),
}
# Setup trajectory views (`memory-inference` x past memory outs).
for i in range(self.num_transformer_units):
space = Box(-1.0, 1.0, shape=(self.attention_dim, ))
self.view_requirements["state_in_{}".format(i)] = \
ViewRequirement(
"state_out_{}".format(i),
shift="-{}:-1".format(self.memory_inference),
# Repeat the incoming state every max-seq-len times.
batch_repeat_value=self.max_seq_len,
space=space)
self.view_requirements["state_out_{}".format(i)] = \
ViewRequirement(
space=space,
used_for_training=False)
@override(ModelV2)
def forward(self, input_dict, state: List[TensorType],
seq_lens: TensorType) -> (TensorType, List[TensorType]):
assert seq_lens is not None
# Add the time dim to observations.
B = tf.shape(seq_lens)[0]
observations = input_dict[SampleBatch.OBS]
shape = tf.shape(observations)
T = shape[0] // B
observations = tf.reshape(observations,
tf.concat([[-1, T], shape[1:]], axis=0))
all_out = self.trxl_model([observations] + state)
if self._logits is not None:
out = tf.reshape(all_out[0], [-1, self.num_outputs])
self._value_out = all_out[1]
memory_outs = all_out[2:]
else:
out = tf.reshape(all_out[0], [-1, self.attention_dim])
memory_outs = all_out[1:]
return out, [
tf.reshape(m, [-1, self.attention_dim]) for m in memory_outs
]
# TODO: (sven) Deprecate this once trajectory view API has fully matured.
@override(RecurrentNetwork)
def get_initial_state(self) -> List[np.ndarray]:
return []
@override(ModelV2)
def value_function(self) -> TensorType:
return tf.reshape(self._value_out, [-1])
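Before handing the model to the trainer, I also run a minimal standalone smoke test on it. This is just a sketch on my side (assuming TF2 eager mode and Ray ~1.x, with tiny placeholder dimensions rather than my real config) that pushes one zero batch through forward() to make sure the output and memory shapes line up:

import gym
import tensorflow as tf

from ray.rllib.policy.sample_batch import SampleBatch
from TransformerModel_Ray import TransformerModel  # the model defined above

# Dummy spaces and a small config; all values here are placeholders.
obs_space = gym.spaces.Box(-1.0, 1.0, shape=(4, ))
act_space = gym.spaces.Discrete(2)
model_config = {
    "max_seq_len": 10,
    "custom_model_config": {
        "num_transformer_units": 2,
        "attention_dim": 32,
        "num_heads": 2,
        "head_dim": 16,
        "position_wise_mlp_dim": 32,
        "memory_inference": 10,
        "memory_training": 10,
    },
}
model = TransformerModel(obs_space, act_space, num_outputs=2,
                         model_config=model_config, name="smoke_test")

B, T = 2, 10  # two sequences of max_seq_len timesteps each
obs = tf.zeros([B * T, 4], tf.float32)
# One memory tensor of shape (B, memory_inference, attention_dim) per unit.
state = [tf.zeros([B, 10, 32], tf.float32) for _ in range(2)]
seq_lens = tf.fill([B], T)
out, state_out = model.forward({SampleBatch.OBS: obs}, state, seq_lens)
print(out.shape, [s.shape for s in state_out])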
And below is the script I used to run the test:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Oct 18 09:22:15 2021
@author: gael
"""
import argparse
import os
import numpy as np
import ray
from ray import tune
from ray.rllib.agents import ppo
from ray.tune.registry import register_env
from ray.rllib.examples.env.repeat_after_me_env import RepeatAfterMeEnv
from ray.rllib.examples.env.repeat_initial_obs_env import RepeatInitialObsEnv
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
from TransformerModel_Ray import TransformerModel
from Vanilla import TrXLNet
from ray.rllib.models import ModelCatalog
from ray.rllib.utils.test_utils import check_learning_achieved
from ray.tune import CLIReporter
from ray.tune.logger import pretty_print
from ray.rllib.agents.sac.sac_tf_policy import SACTFPolicy
from ray.rllib.policy.policy import Policy
ModelCatalog.register_custom_model("TransformerModel", TransformerModel)
#tune.register_env("stateless_cartpole", lambda c: StatelessCartPole())
tune.register_env("RepeatInitialObsEnv", lambda c: RepeatInitialObsEnv())
SUPPORTED_ENVS = [
"RepeatAfterMeEnv", "RepeatInitialObsEnv", "LookAndPush",
"StatelessCartPole"
]
@ray.remote
class Counter:
def __init__(self):
self.count = 0
def inc(self, n):
self.count += n
def get(self):
return self.count
parser = argparse.ArgumentParser()
parser.add_argument(
"--run",
type=str,
default="RNNSAC",
help="The RLlib-registered algorithm to use.")
parser.add_argument("--env", type=str, default="sRepeatInitialObsEnv")
parser.add_argument("--num-cpus", type=int, default=2)
parser.add_argument(
"--framework",
choices=["tf", "tf2", "tfe", "torch"],
default="tf",
help="The DL framework specifier.")
parser.add_argument(
"--as-test",
action="store_true",
help="Whether this script should be run as a test: --stop-reward must "
"be achieved within --stop-timesteps AND --stop-iters.")
parser.add_argument(
"--stop-iters",
type=int,
default=100,
help="Number of iterations to train.")
parser.add_argument(
"--stop-timesteps",
type=int,
default=100000,
help="Number of timesteps to train.")
parser.add_argument(
"--stop-reward",
type=float,
default=90.0,
help="Reward at which we stop training.")
parser.add_argument(
"--no-tune",
action="store_true",
help="Run without Tune using a manual train loop instead. Here,"
"there is no TensorBoard support.")
parser.add_argument(
"--local-mode",
action="store_true",
help="Init Ray in local mode for easier debugging.")
if __name__ == "__main__":
args = parser.parse_args()
ray.init()
#c = Counter.remote()
#tune.register_env("stateless_cartpole", lambda c: StatelessCartPole())
#register_env("RepeatInitialObsEnv", lambda _: RepeatInitialObsEnv())
config = {
"env": "RepeatInitialObsEnv",
# This env_config is only used for the RepeatAfterMeEnv env.
"gamma": 0.99,
"twin_q": True,
"clip_actions": False,
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", 0)),
#"num_envs_per_worker": 20,
#"entropy_coeff": 0.001,
#"num_sgd_iter": 10,
#"vf_loss_coeff": 1e-5,
"render_env": True,
#"horizon": 1000,
#"#batch_mode": "complete_episodes",
#"#prioritized_replay": False,
#"#buffer_size": 100000,
#"#learning_starts": 1000,
#"train_batch_size": 480,
#"target_network_update_freq": 480,
"tau": 0.3,
"model": {
"max_seq_len": 10},
"Q_model": {
"custom_model": "TransformerModel",
"custom_model_config":{
"max_seq_len": 10,
"num_transformer_units": 4,
"attention_dim": 256,
"num_heads": 4,
"memory_inference": 100,
"memory_training":50,
"head_dim": 32,
"position_wise_mlp_dim": 32,
},
},
# Model options for the policy function (see `Q_model` above for details).
# The difference to `Q_model` above is that no action concat'ing is
# performed before the post_fcnet stack.
"policy_model": {
"custom_model": "TransformerModel",
"custom_model_config":{
"max_seq_len": 10,
"num_transformer_units": 4,
"attention_dim": 256,
"num_heads": 4,
"memory_inference": 100,
"memory_training":50,
"head_dim": 32,
"position_wise_mlp_dim": 32,
},
}
}
stop = {
"training_iteration": args.stop_iters,
"timesteps_total": args.stop_timesteps,
"episode_reward_mean": args.stop_reward
}
if args.no_tune:
# manual training loop using PPO and manually keeping track of state
if args.run != "PPO":
raise ValueError("Only support --run PPO with --no-tune.")
ppo_config = ppo.DEFAULT_CONFIG.copy()
ppo_config.update(config)
trainer = ppo.PPOTrainer(config=ppo_config, env=args.env)
# run manual training loop and print results after each iteration
for _ in range(args.stop_iters):
result = trainer.train()
print(pretty_print(result))
# stop training if the target train steps or reward are reached
if result["timesteps_total"] >= args.stop_timesteps or \
result["episode_reward_mean"] >= args.stop_reward:
break
# Run manual test loop (only for RepeatAfterMe env).
if args.env == "RepeatAfterMeEnv":
print("Finished training. Running manual test/inference loop.")
# prepare env
env = RepeatAfterMeEnv(config["env_config"])
obs = env.reset()
done = False
total_reward = 0
# start with all zeros as state
            num_transformers = config["Q_model"]["custom_model_config"][
                "num_transformer_units"]
            attention_dim = config["Q_model"]["custom_model_config"][
                "attention_dim"]
            memory = config["Q_model"]["custom_model_config"][
                "memory_inference"]
            init_state = state = [
                np.zeros([memory, attention_dim], np.float32)
                for _ in range(num_transformers)
            ]
# run one iteration until done
print(f"RepeatAfterMeEnv with {config['env_config']}")
while not done:
action, state_out, _ = trainer.compute_single_action(
obs, state)
next_obs, reward, done, _ = env.step(action)
print(f"Obs: {obs}, Action: {action}, Reward: {reward}")
obs = next_obs
total_reward += reward
state = [
np.concatenate([state[i], [state_out[i]]], axis=0)[1:]
for i in range(num_transformers)
]
print(f"Total reward in test episode: {total_reward}")
# Run with Tune for auto env and trainer creation and TensorBoard.
else:
        results = tune.run(args.run, config=config, stop=stop, verbose=2,
                           progress_reporter=CLIReporter())
if args.as_test:
print("Checking if learning goals were achieved")
check_learning_achieved(results, args.stop_reward)
ray.shutdown()
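In case it helps to reproduce, this is how I launch it (the file name is just what I call the script locally):

python rnnsac_attention_test.py --run RNNSAC --framework tf --stop-iters 100 --stop-timesteps 100000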