System freezes during "minibatch learning loop"

How severely does this issue affect your experience of using Ray?

  • High: It blocks me from completing my task.

Hi folks,

I’ve changed my custom NN model and am now using this one:

from typing import Optional, Sequence

from gym.spaces import Space, Box, Tuple

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.models.modelv2 import ModelV2, restore_original_dimensions
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.tf.misc import normc_initializer
from ray.rllib.models.utils import get_activation_fn
from ray.rllib.utils.annotations import override

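# Import TF through RLlib's helper and force eager execution.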
from ray.rllib.utils import try_import_tf
tf1, tf, tfv = try_import_tf(error=True)
tf1.enable_eager_execution()


class MinimalStateModel(TFModelV2):
    """ NOTE / TODO 
    - Use batch normalization layers
    - Try other pooling layers (e.g. max pooling)
    - Try different layer architecture and layer sizes
    """
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, **kwargs):
        self.original_space = obs_space.original_space if \
            hasattr(obs_space, "original_space") else obs_space
        self.processed_obs_space = self.original_space if \
            model_config.get("_disable_preprocessor_api") else obs_space
        super(MinimalStateModel, self).__init__(
            self.original_space, action_space, num_outputs, model_config, name)

        model_config.update(kwargs)

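        # Build one Keras sub-model per top-level observation component and
        # track the combined embedding size for the shared post-FC stack.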
        concat_size = 0
        for key, space in self.original_space.items():
            if key == "1D" and isinstance(space, Tuple):
                self.mlp = MLP(
                    space,
                    hiddens=model_config.get("fcnet_hiddens", ()),
                    activation=model_config.get("fcnet_activation"),
                    name="mlp"
                )
                concat_size += self.mlp.num_outputs
            elif key == "2D" and isinstance(space, Tuple):
                self.cnn = CNN(
                    space,
                    model_config["conv_filters"],
                    conv_activation=model_config.get("conv_activation"),
                    conv_padding=model_config.get("conv_padding", "valid"),
                    pool_layers=model_config.get("pool_layers", ()),
                    pool_padding=model_config.get("pool_padding", "valid"),
                    post_fcnet_hiddens=model_config.get(
                        "post_cnn_fcnet_hiddens", ()),
                    post_fcnet_activation=model_config.get(
                        "post_cnn_fcnet_activation"),
                    name="cnn"
                )
                concat_size += self.cnn.num_outputs

        self.post_mlp = MLP(
            Box(float("-inf"), float("inf"), shape=(concat_size, )),
            hiddens=model_config.get("post_fcnet_hiddens", ()),
            activation=model_config.get("post_fcnet_activation"),
            name="post_mlp"
        )

        # Actions and value heads
        concatenated_input = tf.keras.layers.Input(
            shape=(self.post_mlp.num_outputs, )
        )
        assert num_outputs
        # action distribution head (logits)
        logits = tf.keras.layers.Dense(
            num_outputs,
            activation=None,
            kernel_initializer=normc_initializer(0.01),
            name="logits"
        )(concatenated_input)
        # value head
        value = tf.keras.layers.Dense(
            1,
            activation=None,
            kernel_initializer=normc_initializer(0.01),
            name="value_out"
        )(concatenated_input)
        self.logits_and_value = tf.keras.models.Model(
            inputs=concatenated_input, outputs=[logits, value],
            name="logits_and_value")
        self._value_out = None

        self.logits_and_value.summary()

    @override(TFModelV2)
    def forward(self, input_dict, state, seq_lens):
        """
        Forward pass through the NN model
        """
        if SampleBatch.OBS in input_dict and "obs_flat" in input_dict:
            orig_obs = input_dict[SampleBatch.OBS]
        else:
            orig_obs = restore_original_dimensions(
                obs=input_dict[SampleBatch.OBS],
                obs_space=self.processed_obs_space,
                tensorlib="tf"
            )
        
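        # Route each observation component through its matching sub-model.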
        outs = []
        for key, obs in orig_obs.items():
            if key == "1D" and isinstance(obs, list):
                mlp_out = self.mlp(obs)
                outs.append(mlp_out)
            elif key == "2D" and isinstance(obs, list):
                cnn_out = self.cnn(obs)
                outs.append(cnn_out)
        
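        # Fuse the branch embeddings and run the shared post-FC stack.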
        out = tf.concat(outs, axis=-1)

        out = self.post_mlp(out)

        action_logits, value_out = self.logits_and_value(out)
        self._value_out = tf.reshape(value_out, [-1])

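        # Standard action-masking trick: tf.math.log maps valid (1.0) mask
        # entries to 0 and invalid (0.0) entries to -inf; clipping at
        # tf.float32.min keeps the logits finite while still driving the
        # masked actions' probabilities to ~0 after the softmax.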
        action_mask = orig_obs["action_mask"]
        inf_mask = tf.maximum(tf.math.log(action_mask), tf.float32.min)

        return action_logits + inf_mask, state

    # @override(ModelV2)
    # def get_initial_state(self):
    #     raise NotImplementedError

    @override(ModelV2)
    def value_function(self):
        return self._value_out


class CNN(tf.keras.Model):
    """
    Custom CNN
    """
    def __init__(
        self,
        input_space: Space,
        conv_filters: Sequence[Sequence[int]],
        conv_activation: Optional[str] = None,
        conv_padding: str = "valid",
        pool_layers: Sequence[Sequence[int]] = (),
        pool_padding: str = "valid",
        post_fcnet_hiddens: Sequence[int] = (),
        post_fcnet_activation: Optional[str] = None,
        data_format: str = "channels_last",
        name: str = "",
        **kwargs
    ):
        super().__init__(name=name)

        conv_activation = get_activation_fn(conv_activation)
        post_fcnet_activation = get_activation_fn(post_fcnet_activation)

        self.data_format = data_format

        inputs_2d = []
        assert isinstance(input_space, Tuple)
        for box in input_space:
            assert isinstance(box, Box) and len(box.shape) == 3
            inputs_2d.append(tf.keras.layers.Input(shape=box.shape))
        # 2D input data may have multiple channels (i.e. c x n x n with c >= 1)
        input_2d = tf.keras.layers.concatenate(inputs_2d, axis=-1)
        last_layer = input_2d

        # build the conv and pooling layers
        for i, (out_channels, kernel, stride) in enumerate(conv_filters):
            last_layer = tf.keras.layers.Conv2D(
                out_channels,
                kernel,
                strides=stride,
                activation=conv_activation,
                padding=conv_padding,
                data_format="channels_last",
                name="conv_{}".format(i+1),
            )(last_layer)
            if pool_layers and i < len(pool_layers):
                (pool_size, stride) = pool_layers[i]
                last_layer = tf.keras.layers.MaxPooling2D(
                    pool_size=pool_size,
                    strides=stride,
                    padding=pool_padding,
                    data_format="channels_last",
                    name="max_pool_{}".format(i+1)
                )(last_layer)

        # flatten conv output into a vector
        last_layer = tf.keras.layers.Flatten(
            data_format="channels_last"
        )(last_layer)

        # add optional post-FC-stack after flattened conv output
        for i, out_units in enumerate(post_fcnet_hiddens):
            last_layer = tf.keras.layers.Dense(
                out_units,
                activation=post_fcnet_activation,
                kernel_initializer=normc_initializer(1.0),
                name="post_cnn_fcnet_{}".format(i+1)
            )(last_layer)

        cnn_out = last_layer
        self.num_outputs = cnn_out.shape[-1]

        self.base_model = tf.keras.Model(inputs_2d, cnn_out, name=name)
        self.base_model.summary()

    def call(self, inputs):
        inputs = list(inputs)
        # The conv stack was built channels_last; transpose any
        # channels_first inputs to n x n x c before the forward pass.
        if self.data_format == "channels_first":
            inputs = [tf.transpose(inp, [0, 2, 3, 1]) for inp in inputs]
        return self.base_model(inputs)


class MLP(tf.keras.Model):
    """
    Custom MLP
    """
    def __init__(
        self,
        input_space: Space,
        hiddens: Sequence[int] = (),
        activation: Optional[str] = None,
        name: str = ""
    ):
        super().__init__(name=name)

        activation = get_activation_fn(activation)

        # Build the input layer(s); a Tuple of 1D Boxes is concatenated
        # into a single flat vector.
        if isinstance(input_space, Tuple):
            inputs = []
            for box in input_space:
                assert isinstance(box, Box) and len(box.shape) == 1
                inputs.append(tf.keras.layers.Input(shape=box.shape))
            last_layer = tf.keras.layers.concatenate(inputs, axis=-1)
        elif isinstance(input_space, Box) and len(input_space.shape) == 1:
            inputs = tf.keras.layers.Input(shape=input_space.shape)
            last_layer = inputs
        else:
            raise ValueError(
                "MLP expects a Tuple of 1D Boxes or a single 1D Box, "
                "got {}".format(input_space))

        # build the stack of FC layers
        for i, out_units in enumerate(hiddens):
            last_layer = tf.keras.layers.Dense(
                out_units,
                activation=activation,
                kernel_initializer=normc_initializer(1.0),
                name="fc_{}".format(i+1)
            )(last_layer)

        mlp_out = last_layer
        self.num_outputs = mlp_out.shape[-1]

        self.base_model = tf.keras.Model(inputs, mlp_out, name=name)
        self.base_model.summary()

    def call(self, inputs):
        mlp_out = self.base_model(inputs)
        return mlp_out

I would say it’s a "nested TFModelV2": I use native Keras Models for the CNN and MLP parts and attach the action and value heads on top.
I collect rollouts of size 256 to build a train batch of size 2048, and then PPO trains on minibatches of size 128; a sketch of this setup follows below. But after one or more SGD iterations (“learn on minibatch” calls), my computer freezes and I can only hit the reset button. Right before the freeze, some Python process/thread seems to consume (almost) all of the CPU.
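
For the full picture, this is roughly how I register the model and set the batch sizes. The custom_model_config values below are illustrative placeholders, not my exact settings (they arrive in **kwargs and are merged into model_config by the constructor above):

from ray.rllib.models import ModelCatalog

ModelCatalog.register_custom_model("minimal_state_model", MinimalStateModel)

config = {
    "framework": "tf2",
    "model": {
        "custom_model": "minimal_state_model",
        "custom_model_config": {
            "conv_filters": [[16, 4, 2], [32, 4, 2]],  # [out_channels, kernel, stride]
            "pool_layers": [[2, 2]],                   # [pool_size, stride]
            "post_cnn_fcnet_hiddens": [64],
            "post_fcnet_hiddens": [256],
        },
    },
    # the batching described above
    "rollout_fragment_length": 256,
    "train_batch_size": 2048,
    "sgd_minibatch_size": 128,
}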

Is there any issue with my new custom NN model? Any ideas what could cause this freezing?
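
For reference, the observation space my env hands to the model looks structurally like this (a sketch; the shapes are placeholders, not my real ones):

from gym.spaces import Dict, Tuple, Box
import numpy as np

obs_space = Dict({
    # one or more 1-D feature vectors -> MLP branch
    "1D": Tuple([Box(-np.inf, np.inf, shape=(8,))]),
    # one or more stacked n x n x c maps -> CNN branch
    "2D": Tuple([Box(0.0, 1.0, shape=(11, 11, 1))]),
    # 0/1 mask over the discrete actions
    "action_mask": Box(0.0, 1.0, shape=(5,)),
})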

PS: Before this, I used a custom TFModelV2 (dense layers + LSTM) without running into this issue, but that one didn’t nest native Keras Models.