How severely does this issue affect your experience of using Ray?
- High: It blocks me from completing my task.
I use TFModelV2 to customize my model. It works fine for PPO training, but it fails for APPO training with the exact same config; if I don't use the custom model, APPO also works fine.
Here is my custom model:
from typing import Dict, List, Tuple

import gymnasium as gym
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils.framework import TensorType
from ray.rllib.utils.typing import ModelConfigDict


class AtariCNN(TFModelV2):
    """Atari CNN model."""

    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str,
    ):
        """Init."""
        super().__init__(obs_space, action_space, num_outputs, model_config, name)
        self._image_encoder = keras.Sequential(
            layers=[
                layers.Conv2D(32, 8, 4, "valid", "channels_last"),
                layers.ReLU(),
                layers.Conv2D(64, 4, 2, "valid", "channels_last"),
                layers.ReLU(),
                layers.Conv2D(64, 3, 1, "valid", "channels_last"),
                layers.ReLU(),
                layers.Flatten(data_format="channels_last"),
            ],
            name="obs",
        )
        self._torso_net = keras.Sequential(
            layers=[
                layers.Dense(512),
                layers.ReLU(),
                layers.Dense(256),
                layers.ReLU(),
            ]
        )
        self._policy_net = keras.Sequential(
            layers=[
                layers.Dense(128),
                layers.ReLU(),
                layers.Dense(64),
                layers.ReLU(),
                layers.Dense(num_outputs),
            ]
        )
        self._value_net = keras.Sequential(
            layers=[
                layers.Dense(128),
                layers.ReLU(),
                layers.Dense(64),
                layers.ReLU(),
                layers.Dense(1),
            ]
        )

    def forward(
        self,
        input_dict: Dict[str, TensorType],
        state: List[TensorType],
        seq_lens: TensorType,
    ) -> Tuple[TensorType, List[TensorType]]:
        """Forward."""
        del seq_lens
        # obs = tf.cast(input_dict["obs"], dtype=tf.float32) / 255.
        obs = tf.cast(input_dict["obs"], dtype=tf.float32)
        embedding = self._image_encoder(obs)
        torso_out = self._torso_net(embedding)
        output = self._policy_net(torso_out)
        self._value = self._value_net(torso_out)
        return output, state

    def value_function(self) -> TensorType:
        """Value function."""
        return tf.reshape(self._value, [-1])
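The config below refers to this model by the name "atari_cnn". The registration call isn't included above; it is the usual ModelCatalog registration, roughly like this (sketch):

from ray.rllib.models import ModelCatalog

# Make the custom model available to RLlib under the name used in the config.
ModelCatalog.register_custom_model("atari_cnn", AtariCNN)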
And this is my APPOConfig:
def get_cnn_model_config(num_rollout_workers: int = 10) -> APPOConfig:
    """Get cnn model config."""
    config = (
        APPOConfig()
        .environment(f"Atari/{ENV_NAME}", clip_rewards=True)
        .framework("tf")
        .resources(
            num_gpus=1,
            num_learner_workers=1,
            num_gpus_per_learner_worker=1,
        )
        .rollouts(
            num_rollout_workers=num_rollout_workers,
            num_envs_per_worker=5,
            batch_mode="truncate_episodes",
            observation_filter="NoFilter",
        )
        .training(
            lambda_=0.95,
            use_kl_loss=True,
            kl_coeff=0.5,
            entropy_coeff=0.01,
            train_batch_size=5000,
            num_sgd_iter=10,
            clip_param=0.1,
            model={
                "custom_model": "atari_cnn",
            },
            _enable_learner_api=False,
        )
        .rl_module(_enable_rl_module_api=False)
    )
    return config
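The config is then built into an algorithm, which is where the error below is raised; a simplified sketch of how I use it (the loop length is arbitrary):

config = get_cnn_model_config(num_rollout_workers=10)
algo = config.build()  # fails here with the traceback below
for _ in range(10):
    result = algo.train()
    print(result["episode_reward_mean"])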
And this is the problem I encountered:
2023-12-25 17:39:12,601 WARNING __init__.py:10 -- PG has/have been moved to `rllib_contrib` and will no longer be maintained by the RLlib team. You can still use it/them normally inside RLlib util Ray 2.8, but from Ray 2.9 on, all `rllib_contrib` algorithms will no longer be part of the core repo, and will therefore have to be installed separately with pinned dependencies for e.g. ray[rllib] and other packages! See https://github.com/ray-project/ray/tree/master/rllib_contrib#rllib-contrib for more information on the RLlib contrib effort.
A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]
Action space: Discrete(6)
Observation space: Box(0, 255, (210, 160, 3), uint8)
Use cnn model.
2023-12-25 17:39:17,061 INFO worker.py:1664 -- Started a local Ray instance. View the dashboard at 127.0.0.1:8267
/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/algorithms/algorithm.py:484: RayDeprecationWarning: This API is deprecated and may be removed in future Ray releases. You could suppress this warning by setting env variable PYTHONWARNINGS="ignore::DeprecationWarning"
`UnifiedLogger` will be removed in Ray 2.7.
return UnifiedLogger(config, logdir, loggers=None)
/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/tune/logger/unified.py:53: RayDeprecationWarning: This API is deprecated and may be removed in future Ray releases. You could suppress this warning by setting env variable PYTHONWARNINGS="ignore::DeprecationWarning"
The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
self._loggers.append(cls(self.config, self.logdir, self.trial))
/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/tune/logger/unified.py:53: RayDeprecationWarning: This API is deprecated and may be removed in future Ray releases. You could suppress this warning by setting env variable PYTHONWARNINGS="ignore::DeprecationWarning"
The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
self._loggers.append(cls(self.config, self.logdir, self.trial))
/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/tune/logger/unified.py:53: RayDeprecationWarning: This API is deprecated and may be removed in future Ray releases. You could suppress this warning by setting env variable PYTHONWARNINGS="ignore::DeprecationWarning"
The `TBXLogger interface is deprecated in favor of the `ray.tune.tensorboardx.TBXLoggerCallback` interface and will be removed in Ray 2.7.
self._loggers.append(cls(self.config, self.logdir, self.trial))
Traceback (most recent call last):
  File "/home/depot/ray-rl/rllib/atari_appo_test.py", line 200, in <module>
    algo = config.build()
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/algorithms/algorithm_config.py", line 1100, in build
    return algo_class(
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/algorithms/appo/appo.py", line 276, in __init__
    super().__init__(config, *args, **kwargs)
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/algorithms/algorithm.py", line 517, in __init__
    super().__init__(
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/tune/trainable/trainable.py", line 161, in __init__
    self.setup(copy.deepcopy(self.config))
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/algorithms/appo/appo.py", line 289, in setup
    super().setup(config)
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/algorithms/impala/impala.py", line 615, in setup
    super().setup(config)
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/algorithms/algorithm.py", line 639, in setup
    self.workers = WorkerSet(
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/evaluation/worker_set.py", line 157, in __init__
    self._setup(
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/evaluation/worker_set.py", line 247, in _setup
    self._local_worker = self._make_worker(
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/evaluation/worker_set.py", line 925, in _make_worker
    worker = cls(
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/evaluation/rollout_worker.py", line 525, in __init__
    self._update_policy_map(policy_dict=self.policy_dict)
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1727, in _update_policy_map
    self._build_policy_map(
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1838, in _build_policy_map
    new_policy = create_policy_for_framework(
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/utils/policy.py", line 132, in create_policy_for_framework
    return policy_class(
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/algorithms/appo/appo_tf_policy.py", line 96, in __init__
    base.__init__(
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/policy/dynamic_tf_policy_v2.py", line 83, in __init__
    self.model = self.make_model()
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/algorithms/appo/appo_tf_policy.py", line 126, in make_model
    return make_appo_models(self)
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/algorithms/appo/utils.py", line 32, in make_appo_models
    policy.model_variables = policy.model.variables()
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/models/tf/tf_modelv2.py", line 93, in variables
    return list(self.variables(as_dict=True).values())
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/models/tf/tf_modelv2.py", line 86, in variables
    return self._find_sub_modules("", self.__dict__)
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/models/tf/tf_modelv2.py", line 139, in _find_sub_modules
    sub_vars = TFModelV2._find_sub_modules(current_key + str(key), value)
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/ray/rllib/models/tf/tf_modelv2.py", line 110, in _find_sub_modules
    for var in struct.variables:
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/keras/engine/base_layer_v1.py", line 1702, in variables
    return self.weights
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/keras/engine/training.py", line 2542, in weights
    return self._dedup_weights(self._undeduplicated_weights)
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/keras/engine/training.py", line 2547, in _undeduplicated_weights
    self._assert_weights_created()
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/keras/engine/sequential.py", line 471, in _assert_weights_created
    super(functional.Functional, self)._assert_weights_created() # pylint: disable=bad-super-call
  File "/home/miniconda3/envs/rllib/lib/python3.9/site-packages/keras/engine/training.py", line 2736, in _assert_weights_created
    raise ValueError(f'Weights for model {self.name} have not yet been '
ValueError: Weights for model obs have not yet been created. Weights are created when the Model is first called on inputs or `build()` is called with an `input_shape`.
What is the problem? I have tried my best to find the mistake but got nothing. Can anybody help me?
Thanks very much.
Additional: ray version==2.8.0
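From the error message, the weights of a keras Sequential are only created once it is called on inputs or `build()` is given an `input_shape`, and the traceback shows APPO reading `policy.model.variables()` in `make_appo_models` right after the model is constructed, i.e. before any forward pass. A possible workaround I am considering (untested sketch; the dummy batch size of 1 and the float32 cast are my own assumptions) is to trace a dummy batch through the sub-models at the end of `__init__`, so their weights already exist when `variables()` is called:

    # At the end of AtariCNN.__init__, after the Sequential sub-models are created
    # (untested sketch): run a dummy batch through them so Keras builds their
    # weights before APPO queries model.variables().
    dummy_obs = tf.zeros((1,) + obs_space.shape, dtype=tf.float32)
    dummy_embedding = self._image_encoder(dummy_obs)
    dummy_torso_out = self._torso_net(dummy_embedding)
    self._policy_net(dummy_torso_out)
    self._value_net(dummy_torso_out)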