I want to serve my model online using PPOTrainer, but PPOTrainer is too slow to initialize with my custom ResNet model — initialization takes about 20 seconds. This is unacceptable for my online service: in my Kubernetes deployment, a pod crash followed by a slow restart makes the service unavailable for that whole time.
So, is there any way to shorten the trainer initialization time?
This is the code script…
import random
import time
import logging
import gym
import numpy as np
import ray
import ray.rllib.agents.ppo as ppo
from ray.rllib.models import ModelCatalog
from ray.rllib.models.tf.misc import normc_initializer
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils.framework import get_activation_fn, try_import_tf
from MahjongZzMjEnv import ZzMjConfig
tf1, tf, tfv = try_import_tf()
class MaskedResidualNetwork(TFModelV2):
    """Custom ResNet-style model for PPO with invalid-action masking.

    The policy tower is a stack of residual blocks followed by a conv head
    producing ``num_outputs`` logits; ``forward()`` adds ``log(action_mask)``
    so that masked-out actions receive very large negative logits.  The value
    tower either shares the policy tower (``vf_share_layers``) or is a
    parallel stack of plain Conv2D layers.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        """Build the Keras policy/value graphs from ``custom_model_config``.

        Required custom config keys: ``conv_filters`` (list of
        ``(out_size, kernel, stride)`` triples; the last entry is the head),
        ``conv_activation``, ``no_final_linear``, ``vf_share_layers``.

        Raises:
            NotImplementedError: if ``conv_filters`` is missing.
        """
        s_time = time.time()
        super(MaskedResidualNetwork, self).__init__(
            obs_space, action_space, num_outputs, model_config, name)
        custom_model_config = model_config.get('custom_model_config')
        activation = get_activation_fn(
            custom_model_config.get("conv_activation"))
        filters = custom_model_config.get("conv_filters", None)
        if filters is None:
            # Fail fast; the message now travels with the exception instead
            # of being print()-ed separately before a bare raise.
            raise NotImplementedError(
                'you must set the "conv_filters" in model config!')
        no_final_linear = custom_model_config.get('no_final_linear')
        vf_share_layers = custom_model_config.get('vf_share_layers')

        # NOTE(review): the input shape is hard-coded from ZzMjConfig rather
        # than derived from obs_space['state'] -- confirm they stay in sync.
        inputs = tf.keras.layers.Input(
            shape=(4, len(ZzMjConfig.card_index) - 1, 9), name='observations')
        last_layer = inputs
        # Whether the policy head ends in a Flatten instead of a 1x1 conv.
        self.last_layer_is_flattened = False

        # Policy tower: one residual block per filter spec except the last.
        for i, (out_size, kernel, stride) in enumerate(filters[:-1], 1):
            last_layer = self.residual_block(
                last_layer,
                out_size,
                kernel,
                strides=(stride, stride),
                activation=activation,
                padding='same',
                data_format='channels_last',
                name=f'resnet{i}')
        out_size, kernel, stride = filters[-1]
        if no_final_linear and num_outputs:
            # No final linear: the last layer is a Conv2D emitting
            # `num_outputs` channels directly.
            # BUG FIX: data_format was misspelled 'chennels_last', which made
            # Keras raise as soon as this branch was taken.
            last_layer = tf.keras.layers.Conv2D(
                num_outputs,
                kernel,
                strides=(stride, stride),
                activation=activation,
                padding='valid',
                data_format='channels_last',
                name='conv_out')(last_layer)
            conv_out = last_layer
        else:
            # Finish the tower normally, then bolt on the output head.
            last_layer = tf.keras.layers.Conv2D(
                out_size,
                kernel,
                strides=(stride, stride),
                activation=activation,
                padding='valid',
                data_format='channels_last',
                name=f'conv{i + 1}')(last_layer)
            if num_outputs:
                # num_outputs known: exact `num_outputs`-sized (1,1)-Conv2D.
                conv_out = tf.keras.layers.Conv2D(
                    num_outputs,
                    [1, 1],
                    activation=None,
                    padding='same',
                    data_format='channels_last',
                    name='conv_out')(last_layer)
            else:
                # num_outputs unknown: Flatten, then report the node count.
                self.last_layer_is_flattened = True
                conv_out = tf.keras.layers.Flatten(
                    data_format='channels_last')(last_layer)
                self.num_outputs = conv_out.shape[1]

        # Value tower.
        if vf_share_layers:
            # Reuse the policy tower: squeeze the 1x1 spatial dims and add a
            # single linear value head.
            last_layer = tf.keras.layers.Lambda(
                lambda x: tf.squeeze(x, axis=[1, 2]))(last_layer)
            value_out = tf.keras.layers.Dense(
                1,
                name='value_out',
                activation=None,
                kernel_initializer=normc_initializer(0.01))(last_layer)
        else:
            # Parallel stack of plain (non-residual) conv layers.
            last_layer = inputs
            for i, (out_size, kernel, stride) in enumerate(filters[:-1], 1):
                last_layer = tf.keras.layers.Conv2D(
                    out_size,
                    kernel,
                    strides=(stride, stride),
                    activation=activation,
                    padding='same',
                    data_format='channels_last',
                    name=f'conv_value_{i}')(last_layer)
            out_size, kernel, stride = filters[-1]
            last_layer = tf.keras.layers.Conv2D(
                out_size,
                kernel,
                strides=(stride, stride),
                activation=activation,
                padding='valid',
                data_format='channels_last',
                name=f'conv_value_{i + 1}')(last_layer)
            last_layer = tf.keras.layers.Conv2D(
                1,
                (1, 1),
                activation=None,
                padding='same',
                data_format='channels_last',
                name='conv_value_out')(last_layer)
            value_out = tf.keras.layers.Lambda(
                lambda x: tf.squeeze(x, axis=[1, 2]))(last_layer)

        self.base_model = tf.keras.Model(inputs, [conv_out, value_out])
        self.register_variables(self.base_model.variables)
        print(f'resnet init cost: {time.time() - s_time}')

    def forward(self, input_dict, state, seq_lens):
        """Run the model and mask out unavailable actions.

        Returns ``(masked_logits, state)``.  Masking adds ``log(action_mask)``
        clamped to ``float32.min`` so invalid actions get huge negative logits
        without producing -inf values.
        """
        s_time = time.time()
        # Explicit cast to float32 needed in eager mode.
        model_out, self._value_out = self.base_model(
            tf.cast(input_dict['obs']['state'], tf.float32))
        if self.last_layer_is_flattened:
            # The flattened-head configuration is unsupported by this masked
            # forward pass.  (The dead code that previously followed this
            # raise used the TF1-only tf.log and has been removed.)
            raise ValueError
        # Policy head is an n x [1, 1] Conv2D: squeeze the spatial dims.
        model_out = tf.squeeze(model_out, axis=[1, 2])
        # NOTE(review): the original expand_dims(axis=1) followed by
        # reduce_sum(axis=1) was an identity transform and never actually
        # combined `avail_actions` into the result, so the unmasked logits
        # are simply `model_out`.
        action_logits = model_out
        action_mask = input_dict['obs']['action_mask']
        inf_mask = tf.math.maximum(tf.math.log(action_mask), tf.float32.min)
        # BUG FIX: the timing print sat after the return statements and never
        # ran (also fixes the 'forwrad' typo).
        print(f'forward cost: {time.time() - s_time}')
        return action_logits + inf_mask, state

    def value_function(self):
        """Return the value estimate from the most recent forward() call."""
        return tf.reshape(self._value_out, [-1])

    def residual_block(self,
                       x,
                       filters=192,
                       kernel=3,
                       strides=(1, 1),
                       conv_shortcut=True,
                       activation='relu',
                       padding='same',
                       data_format='channels_last',
                       name='residual_block'):
        """Bottleneck residual block: 1x1 -> kxk -> 1x1 convs plus shortcut.

        With ``conv_shortcut=True`` the input is projected through a strided
        1x1 conv + BN so the residual add matches shapes; otherwise the raw
        input is used as the shortcut.

        NOTE(review): the ``activation`` and ``padding`` parameters are
        accepted but unused -- inner activations are hard-wired to 'relu' and
        the kxk conv always pads 'SAME'; ``data_format`` only selects the
        BatchNormalization axis.
        """
        bn_axis = 3 if data_format == 'channels_last' else 1
        if conv_shortcut:
            shortcut = tf.keras.layers.Conv2D(
                filters,
                1,
                strides=strides,
                name=name + '_0_conv')(x)
            shortcut = tf.keras.layers.BatchNormalization(
                axis=bn_axis, epsilon=1.001e-5, name=name + '_0_bn')(shortcut)
        else:
            shortcut = x
        # 1x1 reduce.
        x = tf.keras.layers.Conv2D(
            filters,
            1,
            strides=strides,
            name=name + '_1_conv')(x)
        x = tf.keras.layers.BatchNormalization(
            axis=bn_axis, epsilon=1.001e-5, name=name + '_1_bn')(x)
        x = tf.keras.layers.Activation('relu', name=name + '_1_relu')(x)
        # kxk conv.
        x = tf.keras.layers.Conv2D(
            filters,
            kernel,
            padding='SAME',
            name=name + '_2_conv')(x)
        x = tf.keras.layers.BatchNormalization(
            axis=bn_axis, epsilon=1.001e-5, name=name + '_2_bn')(x)
        x = tf.keras.layers.Activation('relu', name=name + '_2_relu')(x)
        # 1x1 expand.
        x = tf.keras.layers.Conv2D(filters, 1, name=name + '_3_conv')(x)
        x = tf.keras.layers.BatchNormalization(
            axis=bn_axis, epsilon=1.001e-5, name=name + '_3_bn')(x)
        # Residual add + output activation.
        x = tf.keras.layers.Add(name=name + '_add')([shortcut, x])
        x = tf.keras.layers.Activation('relu', name=name + '_out')(x)
        return x
class RlRayEnv(gym.Env):
    """Minimal serving-side env exposing only the observation/action spaces.

    `step()` is a stub: this env is used to give the trainer the space
    definitions it needs for policy construction, not to run episodes.
    """

    def __init__(self, agent_list: list):
        """Define the Dict observation space and Discrete action space.

        Args:
            agent_list: accepted for interface compatibility; currently
                unused by this stub.
        """
        # BUG FIX: removed the redundant function-scope `import time` --
        # the module is already imported at the top of the file.
        s_time = time.time()
        self.cop_agent = []
        self.info = {}
        self.observation_space = gym.spaces.Dict({
            'state': gym.spaces.Box(
                low=0, high=1, shape=(4, 34, 9), dtype=np.uint8),
            'avail_actions': gym.spaces.Box(low=0, high=1, shape=(71,)),
            'action_mask': gym.spaces.Box(low=0, high=1, shape=(71,)),
        })
        self.action_space = gym.spaces.Discrete(
            len(ZzMjConfig.action_set))
        print(f'rl env init cost: {time.time() - s_time}')

    def step(self, action: int):
        """Stub: episodes are not driven through this env."""
        pass

    def close(self):
        """Delegate cleanup to the gym.Env base class."""
        super().close()
if __name__ == '__main__':
    # Local-mode Ray with the dashboard disabled.
    ray.init(local_mode=True, include_dashboard=False)
    ModelCatalog.register_custom_model(
        "MaskedResNet", MaskedResidualNetwork)

    # 14 residual stages of 128 filters with 3x3 kernels, then a (4, 34)
    # conv head that collapses the spatial dimensions.
    conv_filters = [(128, 3, 1)] * 14 + [(128, (4, 34), 1)]
    resnet_trainer_config = {
        'num_workers': 0,
        'model': {
            'custom_model': 'MaskedResNet',
            'custom_model_config': {
                'conv_activation': 'relu',
                'conv_filters': conv_filters,
                'no_final_linear': False,
                'vf_share_layers': False,
            },
        },
    }

    # Time the trainer construction (the step the author reports as slow).
    s_time = time.time()
    trainer = ppo.PPOTrainer(
        config=resnet_trainer_config,
        env=RlRayEnv)
    print(f'PPOTrainer init cost: {time.time() - s_time}s')