Custom model training bug

How severely does this issue affect your experience of using Ray?

  • High: It blocks me from completing my task.

I use TFModelV2 to customize my model. It works fine with PPO training, but it fails with APPO training. With the same config and no custom model, APPO also works fine.

Here is my custom model:

from typing import Dict, List, Tuple
from ray.rllib.utils.framework import TensorType
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

import gymnasium as gym

from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils.typing import ModelConfigDict


class AtariCNN(TFModelV2):
  """Atari CNN model."""

  def __init__(
    self,
    obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    num_outputs: int,
    model_config: ModelConfigDict,
    name: str,
  ):
    """Init."""
    super().__init__(obs_space, action_space, num_outputs, model_config, name)
    self._image_encoder = keras.Sequential(
      layers=[
        layers.Conv2D(32, 8, 4, "valid", "channels_last"),
        layers.ReLU(),
        layers.Conv2D(64, 4, 2, "valid", "channels_last"),
        layers.ReLU(),
        layers.Conv2D(64, 3, 1, "valid", "channels_last"),
        layers.ReLU(),
        layers.Flatten(data_format="channels_last"),
      ],
      name="obs",
    )
    self._torso_net = keras.Sequential(
      layers=[
        layers.Dense(512),
        layers.ReLU(),
        layers.Dense(256),
        layers.ReLU(),
      ]
    )
    self._policy_net = keras.Sequential(
      layers=[
        layers.Dense(128),
        layers.ReLU(),
        layers.Dense(64),
        layers.ReLU(),
        layers.Dense(num_outputs),
      ]
    )
    self._value_net = keras.Sequential(
      layers=[
        layers.Dense(128),
        layers.ReLU(),
        layers.Dense(64),
        layers.ReLU(),
        layers.Dense(1),
      ]
    )

  def forward(self, input_dict: Dict[str, TensorType], state: List[TensorType],
              seq_lens: TensorType) -> Tuple[TensorType, List[TensorType]]:
    """Forward."""
    del seq_lens
    # obs = tf.cast(input_dict["obs"], dtype=tf.float32) / 255.
    obs = tf.cast(input_dict["obs"], dtype=tf.float32)
    embedding = self._image_encoder(obs)
    torso_out = self._torso_net(embedding)
    output = self._policy_net(torso_out)
    self._value = self._value_net(torso_out)
    return output, state

  def value_function(self) -> TensorType:
    """Value function."""
    return tf.reshape(self._value, [-1])
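
The model is registered with RLlib before the config is built. The registration call isn't in the snippet above, but it is essentially this (the name has to match "custom_model" in the config below):

from ray.rllib.models import ModelCatalog

# Make "atari_cnn" resolve to the AtariCNN class above.
ModelCatalog.register_custom_model("atari_cnn", AtariCNN)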

And here is my APPOConfig:

def get_cnn_model_config(num_rollout_workers: int = 10) -> APPOConfig:
    """Get cnn model config."""
    config = (
        APPOConfig()
        .environment(f"Atari/{ENV_NAME}", clip_rewards=True)
        .framework("tf")
        .resources(
            num_gpus=1,
            num_learner_workers=1,
            num_gpus_per_learner_worker=1,
        )
        .rollouts(
            num_rollout_workers=num_rollout_workers,
            num_envs_per_worker=5,
            batch_mode="truncate_episodes",
            observation_filter="NoFilter",
        )
        .training(
            lambda_=0.95,
            use_kl_loss=True,
            kl_coeff=0.5,
            entropy_coeff=0.01,
            train_batch_size=5000,
            num_sgd_iter=10,
            clip_param=0.1,
            model={
                "custom_model": "atari_cnn",
            },            
            _enable_learner_api=False,
        )
        .rl_module(_enable_rl_module_api=False)
    )
    return config
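
For reference, the algorithm is then built in the usual way; this is roughly what line 200 of atari_appo_test.py does (the loop here is only a sketch of my script):

config = get_cnn_model_config()
algo = config.build()  # the ValueError below is raised here

for _ in range(100):
    algo.train()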

And this is the error I encountered:

2023-12-25 17:39:12,601 WARNING __init__.py:10 -- PG has/have been moved to `rllib_contrib` and will no longer be maintained by the RLlib team. You can still use it/them normally inside RLlib util Ray 2.8, but from Ray 2.9 on, all `rllib_contrib` algorithms will no longer be part of the core repo, and will therefore have to be installed separately with pinned dependencies for e.g. ray[rllib] and other packages! See https://github.com/ray-project/ray/tree/master/rllib_contrib#rllib-contrib for more information on the RLlib contrib effort.
A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]
Action space: Discrete(6)
Observation space: Box(0, 255, (210, 160, 3), uint8)
Use cnn model.
2023-12-25 17:39:17,061 INFO worker.py:1664 -- Started a local Ray instance. View the dashboard at 127.0.0.1:8267
/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/algorithms/algorithm.py:484: RayDeprecationWarning: This API is deprecated and may be removed in future Ray releases. You could suppress this warning by setting env variable PYTHONWARNINGS="ignore::DeprecationWarning"
`UnifiedLogger` will be removed in Ray 2.7.
  return UnifiedLogger(config, logdir, loggers=None)
/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/tune/logger/unified.py:53: RayDeprecationWarning: This API is deprecated and may be removed in future Ray releases. You could suppress this warning by setting env variable PYTHONWARNINGS="ignore::DeprecationWarning"
The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/tune/logger/unified.py:53: RayDeprecationWarning: This API is deprecated and may be removed in future Ray releases. You could suppress this warning by setting env variable PYTHONWARNINGS="ignore::DeprecationWarning"
The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/tune/logger/unified.py:53: RayDeprecationWarning: This API is deprecated and may be removed in future Ray releases. You could suppress this warning by setting env variable PYTHONWARNINGS="ignore::DeprecationWarning"
The `TBXLogger interface is deprecated in favor of the `ray.tune.tensorboardx.TBXLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
Traceback (most recent call last):
  File "/home/depot/ray-rl/rllib/atari_appo_test.py", line 200, in <module>
    algo = config.build()
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/algorithms/algorithm_config.py", line 1100, in build
    return algo_class(
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/algorithms/appo/appo.py", line 276, in __init__
    super().__init__(config, *args, **kwargs)
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/algorithms/algorithm.py", line 517, in __init__
    super().__init__(
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/tune/trainable/trainable.py", line 161, in __init__
    self.setup(copy.deepcopy(self.config))
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/algorithms/appo/appo.py", line 289, in setup
    super().setup(config)
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/algorithms/impala/impala.py", line 615, in setup
    super().setup(config)
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/algorithms/algorithm.py", line 639, in setup
    self.workers = WorkerSet(
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/evaluation/worker_set.py", line 157, in __init__
    self._setup(
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/evaluation/worker_set.py", line 247, in _setup
    self._local_worker = self._make_worker(
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/evaluation/worker_set.py", line 925, in _make_worker
    worker = cls(
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/evaluation/rollout_worker.py", line 525, in __init__
    self._update_policy_map(policy_dict=self.policy_dict)
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1727, in _update_policy_map
    self._build_policy_map(
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1838, in _build_policy_map
    new_policy = create_policy_for_framework(
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/utils/policy.py", line 132, in create_policy_for_framework
    return policy_class(
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/algorithms/appo/appo_tf_policy.py", line 96, in __init__
    base.__init__(
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/policy/dynamic_tf_policy_v2.py", line 83, in __init__
    self.model = self.make_model()
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/algorithms/appo/appo_tf_policy.py", line 126, in make_model
    return make_appo_models(self)
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/algorithms/appo/utils.py", line 32, in make_appo_models
    policy.model_variables = policy.model.variables()
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/models/tf/tf_modelv2.py", line 93, in variables
    return list(self.variables(as_dict=True).values())
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/models/tf/tf_modelv2.py", line 86, in variables
    return self._find_sub_modules("", self.__dict__)
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/models/tf/tf_modelv2.py", line 139, in _find_sub_modules
    sub_vars = TFModelV2._find_sub_modules(current_key + str(key), value)
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/models/tf/tf_modelv2.py", line 110, in _find_sub_modules
    for var in struct.variables:
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/keras/engine/base_layer_v1.py", line 1702, in variables
    return self.weights
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/keras/engine/training.py", line 2542, in weights
    return self._dedup_weights(self._undeduplicated_weights)
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/keras/engine/training.py", line 2547, in _undeduplicated_weights
    self._assert_weights_created()
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/keras/engine/sequential.py", line 471, in _assert_weights_created
    super(functional.Functional, self)._assert_weights_created()  # pylint: disable=bad-super-call
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/keras/engine/training.py", line 2736, in _assert_weights_created
    raise ValueError(f'Weights for model {self.name} have not yet been '
ValueError: Weights for model obs have not yet been created. Weights are created when the Model is first called on inputs or `build()` is called with an `input_shape`.

What is the problem? I have tried my best to find the mistake but got nothing. Can anybody help me?
Thanks very much.

Additional info: Ray version == 2.8.0
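
Based on the error text, my only guess at a workaround is to force weight creation at the end of __init__ by running a dummy observation through the sub-models, roughly like this (untested), but I don't understand why plain PPO works without it:

    # Untested workaround idea: call each Sequential once so Keras creates
    # the weights before APPO reads model.variables().
    dummy_obs = tf.zeros((1,) + obs_space.shape, dtype=tf.float32)
    dummy_embedding = self._image_encoder(dummy_obs)
    dummy_torso = self._torso_net(dummy_embedding)
    self._policy_net(dummy_torso)
    self._value_net(dummy_torso)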

The algorithm has been moved outside of RLlib.
Clone https://github.com/ray-project/ray locally.
Switch to the appropriate branch and copy the algorithm into your project.

2023-12-25 17:39:12,601 WARNING __init__.py:10 -- PG has/have been moved to `rllib_contrib` and will no longer be maintained by the RLlib team. You can still use it/them normally inside RLlib util Ray 2.8, but from Ray 2.9 on, all `rllib_contrib` algorithms will no longer be part of the core repo, and will therefore have to be installed separately with pinned dependencies for e.g. ray[rllib] and other packages! See https://github.com/ray-project/ray/tree/master/rllib_contrib#rllib-contrib for more information on the RLlib contrib effort.