PPOTrainer is too SLOW to initialize

I want to serve my model with PPOTrainer, but PPOTrainer is very slow to initialize with my custom ResNet model: it takes about 20 s. That is unacceptable for my online service; in my K8s deployment, a pod crash leaves the service unavailable until the trainer has re-initialized.

So, is there any way to shorten the trainer initialization time?

Here is the script:

import random
import time
import logging
import gym
import numpy as np
import ray
import ray.rllib.agents.ppo as ppo
from ray.rllib.models import ModelCatalog
from ray.rllib.models.tf.misc import normc_initializer
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils.framework import get_activation_fn, try_import_tf

from MahjongZzMjEnv import ZzMjConfig

tf1, tf, tfv = try_import_tf()


class MaskedResidualNetwork(TFModelV2):
    """Custom resnet model for ppo algorithm"""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        s_time = time.time()
        super(MaskedResidualNetwork, self).__init__(obs_space, action_space,
                                                    num_outputs, model_config, name)
        custom_model_config = model_config.get('custom_model_config')
        activation = get_activation_fn(
            custom_model_config.get("conv_activation"))
        filters = custom_model_config.get("conv_filters", None)
        if filters is None:
            raise ValueError(
                'you must set "conv_filters" in the model config!')
        no_final_linear = custom_model_config.get('no_final_linear')
        vf_share_layers = custom_model_config.get('vf_share_layers')

        inputs = tf.keras.layers.Input(
            shape=(4, len(ZzMjConfig.card_index) - 1, 9), name='observations'
        )
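        # only the 'state' part of the Dict observation feeds the conv tower;
        # its shape is 4 x (number of card types) x 9, channels_last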
        last_layer = inputs

        # whether the last layer is the output of a Flatten (rather than a 1x1 Conv2D)
        self.last_layer_is_flattened = False

        # build the action layers
        for i, (out_size, kernel, stride) in enumerate(filters[:-1], 1):
            last_layer = self.residual_block(
                last_layer,
                out_size,
                kernel,
                strides=(stride, stride),
                activation=activation,
                padding='same',
                data_format='channels_last',
                name=f'resnet{i}'
            )

        out_size, kernel, stride = filters[-1]

        # no final linear: Last Layer is a Conv2D and uses num_outputs.
        if no_final_linear and num_outputs:
            last_layer = tf.keras.layers.Conv2D(
                num_outputs,
                kernel,
                strides=(stride, stride),
                activation=activation,
                padding='valid',
                data_format='channels_last',
                name='conv_out'
            )(last_layer)
            conv_out = last_layer
        # finish network normally, then add another linear one of size `num_outputs`.
        else:
            last_layer = tf.keras.layers.Conv2D(
                out_size,
                kernel,
                strides=(stride, stride),
                activation=activation,
                padding='valid',
                data_format='channels_last',
                name=f'conv{i + 1}'
            )(last_layer)

            # num_outputs defined. Use that to create an exact `num_output`-sized (1,1)-Conv2D.
            if num_outputs:
                conv_out = tf.keras.layers.Conv2D(
                    num_outputs,
                    [1, 1],
                    activation=None,
                    padding='same',
                    data_format='channels_last',
                    name='conv_out'
                )(last_layer)
            # num_outputs not known -> Flatten, then set self.num_outputs to the resulting number of nodes.
            else:
                self.last_layer_is_flattened = True
                conv_out = tf.keras.layers.Flatten(
                    data_format='channels_last'
                )(last_layer)
                self.num_outputs = conv_out.shape[1]

        # build the value layers
        if vf_share_layers:
            last_layer = tf.keras.layers.Lambda(
                lambda x: tf.squeeze(x, axis=[1, 2])
            )(last_layer)
            value_out = tf.keras.layers.Dense(
                1,
                name='value_out',
                activation=None,
                kernel_initializer=normc_initializer(0.01)
            )(last_layer)
        else:
            # build a parallel set of hidden layers for the value net
            last_layer = inputs
            for i, (out_size, kernel, stride) in enumerate(filters[:-1], 1):
                last_layer = tf.keras.layers.Conv2D(
                    out_size,
                    kernel,
                    strides=(stride, stride),
                    activation=activation,
                    padding='same',
                    data_format='channels_last',
                    name=f'conv_value_{i}'
                )(last_layer)

            out_size, kernel, stride = filters[-1]
            last_layer = tf.keras.layers.Conv2D(
                out_size,
                kernel,
                strides=(stride, stride),
                activation=activation,
                padding='valid',
                data_format='channels_last',
                name=f'conv_value_{i + 1}'
            )(last_layer)
            last_layer = tf.keras.layers.Conv2D(
                1,
                (1, 1),
                activation=None,
                padding='same',
                data_format='channels_last',
                name='conv_value_out'
            )(last_layer)
            value_out = tf.keras.layers.Lambda(
                lambda x: tf.squeeze(x, axis=[1, 2])
            )(last_layer)

        self.base_model = tf.keras.Model(inputs, [conv_out, value_out])
        self.register_variables(self.base_model.variables)
        print(f'resnet init cost: {time.time() - s_time}')

    def forward(self, input_dict, state, seq_lens):
        s_time = time.time()
        # explicit cast to float32 needed in eager.
        model_out, self._value_out = self.base_model(
            tf.cast(input_dict['obs']['state'], tf.float32)
        )

        # our last layer is already flat (not expected with this filter config)
        if self.last_layer_is_flattened:
            raise ValueError('flattened conv output is not supported here')
        # last layer is an n x [1, 1] Conv2D -> squeeze to [B, num_outputs]
        else:
            model_out = tf.squeeze(model_out, axis=[1, 2])
            avail_actions = input_dict['obs']['avail_actions']
            action_mask = input_dict['obs']['action_mask']
            intent_vector = tf.expand_dims(model_out, 1)
            action_logits = tf.math.reduce_sum(intent_vector, axis=1)
            # mask out invalid actions with a large negative logit
            inf_mask = tf.math.maximum(
                tf.math.log(action_mask), tf.float32.min)
            print(f'forward cost: {time.time() - s_time}')
            return action_logits + inf_mask, state

    def value_function(self):
        return tf.reshape(self._value_out, [-1])

    def residual_block(self,
                       x,
                       filters=192,
                       kernel=3,
                       strides=(1, 1),
                       conv_shortcut=True,
                       activation='relu',
                       padding='same',
                       data_format='channels_last',
                       name='residual_block'):
        """A residual block"""

        bn_axis = 3 if data_format == 'channels_last' else 1
        # projection shortcut (1x1 conv + BN) so shapes match for the add
        if conv_shortcut:
            shortcut = tf.keras.layers.Conv2D(
                filters,
                1,
                strides=strides,
                name=name + '_0_conv'
            )(x)
            shortcut = tf.keras.layers.BatchNormalization(
                axis=bn_axis, epsilon=1.001e-5, name=name + '_0_bn'
            )(shortcut)
        else:
            shortcut = x

        x = tf.keras.layers.Conv2D(
            filters,
            1,
            strides=strides,
            name=name + '_1_conv'
        )(x)
        x = tf.keras.layers.BatchNormalization(
            axis=bn_axis, epsilon=1.001e-5, name=name + '_1_bn'
        )(x)
        x = tf.keras.layers.Activation('relu', name=name + '_1_relu')(x)

        # kernel-sized conv, SAME padding
        x = tf.keras.layers.Conv2D(
            filters,
            kernel,
            padding='SAME',
            name=name + '_2_conv'
        )(x)
        x = tf.keras.layers.BatchNormalization(
            axis=bn_axis, epsilon=1.001e-5, name=name + '_2_bn'
        )(x)
        x = tf.keras.layers.Activation('relu', name=name + '_2_relu')(x)

        # 1x1 conv + BN (no activation before the add)
        x = tf.keras.layers.Conv2D(filters, 1, name=name + '_3_conv')(x)
        x = tf.keras.layers.BatchNormalization(
            axis=bn_axis, epsilon=1.001e-5, name=name + '_3_bn'
        )(x)

        # residual add, then ReLU
        x = tf.keras.layers.Add(name=name + '_add')([shortcut, x])
        x = tf.keras.layers.Activation('relu', name=name + '_out')(x)

        return x




class RlRayEnv(gym.Env):
    def __init__(self, agent_list: list):
        s_time = time.time()
        self.cop_agent = [] 
        self.info = {} 
        self.observation_space = gym.spaces.Dict({
            'state': gym.spaces.Box(low=0, high=1, shape=(4, 34, 9), dtype=np.uint8),
            'avail_actions': gym.spaces.Box(low=0, high=1, shape=(71,)),
            'action_mask': gym.spaces.Box(low=0, high=1, shape=(71,)),
        })
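        # 'state' is the 4 x 34 x 9 board encoding; 'avail_actions' and
        # 'action_mask' are the 71-dim vectors consumed by MaskedResidualNetwork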

        self.action_space = gym.spaces.Discrete(
            len(ZzMjConfig.action_set)) 

        print(f'rl env init cost: {time.time() - s_time}')
   

    def step(self, action: int):
        pass

    def close(self):
        gym.Env.close(self)


if __name__ == '__main__':

    ray.init(local_mode=True, include_dashboard=False)
    ModelCatalog.register_custom_model(
        "MaskedResNet", MaskedResidualNetwork)
    resnet_trainer_config = {
        'num_workers': 0,
        'model': {
            'custom_model': 'MaskedResNet',
            'custom_model_config': {
                'conv_activation': 'relu',
                'conv_filters': [(128, 3, 1)] * 14 + [(128, (4, 34), 1)],
                'no_final_linear': False,
                'vf_share_layers': False,
            },
        },
    }

    s_time = time.time()
    trainer = ppo.PPOTrainer(
            config=resnet_trainer_config,
            env=RlRayEnv)
    print(f'PPOTrainer init cost: {time.time() - s_time}s')
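
For reference, here is a rough sketch of how I could split the cost between building the Keras model itself and the rest of RLlib's trainer setup. It simply reuses the RlRayEnv, MaskedResidualNetwork and resnet_trainer_config defined above, and passing env.action_space.n as num_outputs is just my assumption of what RLlib would use for a Discrete action space:

# Rough timing check (not part of the service): build only the custom Keras
# model, outside of PPOTrainer, to separate graph construction from RLlib's
# policy/worker setup. Assumes the classes and config above are in scope.
env = RlRayEnv([])
s_time = time.time()
bare_model = MaskedResidualNetwork(
    env.observation_space,
    env.action_space,
    env.action_space.n,  # assumed num_outputs for the Discrete action space
    resnet_trainer_config['model'],
    'timing_check',
)
print(f'bare model build cost: {time.time() - s_time}s')

If most of the 20 s shows up here, the bottleneck would be the Keras graph build of the 14 residual blocks rather than RLlib itself.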

@sven1977 Thanks for any help! :grinning: :grinning: