TF2 error with LSTM but not with torch framework

How severely does this issue affect your experience of using Ray?

  • Medium: It contributes significant difficulty to completing my task, but I can work around it.

Hi, I’m working with Ray 2.20.0 on Python 3.10.14 on Windows 11.
I configured my custom environment and an LSTM-PPO algorithm to train the policy with the "torch" framework, and everything worked well (until I tried to call compute_single_action(), but that is a separate problem, posted in Compute Action with LSTM - #5 by hermmanhender). When I changed the framework to "tf2", the following error made the actors die (a rough sketch of my configuration is included after the traceback):

(RolloutWorker pid=15420) Exception raised in creation task: The actor died because of an error raised in its creation task, ray::RolloutWorker.__init__() (pid=15420, ip=127.0.0.1, actor_id=44e5dadb4de141335f7da27d01000000, repr=<ray.rllib.evaluation.rollout_worker.RolloutWorker object at 0x000001B0151799F0>)
(RolloutWorker pid=15420)   File "python\ray\_raylet.pyx", line 1887, in ray._raylet.execute_task
(RolloutWorker pid=15420)   File "python\ray\_raylet.pyx", line 1828, in ray._raylet.execute_task.function_executor
(RolloutWorker pid=15420)   File "c:\Users\grhen\anaconda3\envs\eprllib1-1-1\lib\site-packages\ray\_private\function_manager.py", line 691, in actor_method_executor
(RolloutWorker pid=15420)     return method(__ray_actor, *args, **kwargs)
(RolloutWorker pid=15420)   File "c:\Users\grhen\anaconda3\envs\eprllib1-1-1\lib\site-packages\ray\util\tracing\tracing_helper.py", line 467, in _resume_span     
(RolloutWorker pid=15420)     return method(self, *_args, **_kwargs)
(RolloutWorker pid=15420)   File "c:\Users\grhen\anaconda3\envs\eprllib1-1-1\lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 532, in __init__     
(RolloutWorker pid=15420)     self._update_policy_map(policy_dict=self.policy_dict)
(RolloutWorker pid=15420)   File "c:\Users\grhen\anaconda3\envs\eprllib1-1-1\lib\site-packages\ray\util\tracing\tracing_helper.py", line 467, in _resume_span     
(RolloutWorker pid=15420)     return method(self, *_args, **_kwargs)
(RolloutWorker pid=15420)   File "c:\Users\grhen\anaconda3\envs\eprllib1-1-1\lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 1737, in _update_policy_map
(RolloutWorker pid=15420)     self._build_policy_map(
(RolloutWorker pid=15420)   File "c:\Users\grhen\anaconda3\envs\eprllib1-1-1\lib\site-packages\ray\util\tracing\tracing_helper.py", line 467, in _resume_span     
(RolloutWorker pid=15420)     return method(self, *_args, **_kwargs)
(RolloutWorker pid=15420)   File "c:\Users\grhen\anaconda3\envs\eprllib1-1-1\lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 1848, in _build_policy_map
(RolloutWorker pid=15420)     new_policy = create_policy_for_framework(
(RolloutWorker pid=15420)   File "c:\Users\grhen\anaconda3\envs\eprllib1-1-1\lib\site-packages\ray\rllib\utils\policy.py", line 138, in create_policy_for_framework
(RolloutWorker pid=15420)     return policy_class(observation_space, action_space, merged_config)
(RolloutWorker pid=15420)   File "c:\Users\grhen\anaconda3\envs\eprllib1-1-1\lib\site-packages\ray\rllib\policy\eager_tf_policy.py", line 167, in __init__        
(RolloutWorker pid=15420)     super(TracedEagerPolicy, self).__init__(*args, **kwargs)
(RolloutWorker pid=15420)   File "c:\Users\grhen\anaconda3\envs\eprllib1-1-1\lib\site-packages\ray\rllib\algorithms\ppo\ppo_tf_policy.py", line 81, in __init__   
(RolloutWorker pid=15420)     base.__init__(
(RolloutWorker pid=15420)   File "c:\Users\grhen\anaconda3\envs\eprllib1-1-1\lib\site-packages\ray\rllib\policy\eager_tf_policy_v2.py", line 120, in __init__     
(RolloutWorker pid=15420)     self.model = self.make_model()
(RolloutWorker pid=15420)   File "c:\Users\grhen\anaconda3\envs\eprllib1-1-1\lib\site-packages\ray\rllib\policy\eager_tf_policy_v2.py", line 268, in make_model
(RolloutWorker pid=15420)     return ModelCatalog.get_model_v2(
(RolloutWorker pid=15420)   File "c:\Users\grhen\anaconda3\envs\eprllib1-1-1\lib\site-packages\ray\rllib\models\catalog.py", line 799, in get_model_v2
(RolloutWorker pid=15420)     return wrapper(
(RolloutWorker pid=15420)   File "c:\Users\grhen\anaconda3\envs\eprllib1-1-1\lib\site-packages\ray\rllib\models\tf\recurrent_net.py", line 195, in __init__       
(RolloutWorker pid=15420)     mask=tf.sequence_mask(seq_in),
(RolloutWorker pid=15420)   File "c:\Users\grhen\anaconda3\envs\eprllib1-1-1\lib\site-packages\tensorflow\python\util\traceback_utils.py", line 153, in error_handler
(RolloutWorker pid=15420)     raise e.with_traceback(filtered_tb) from None
(RolloutWorker pid=15420)   File "c:\Users\grhen\anaconda3\envs\eprllib1-1-1\lib\site-packages\keras\src\backend\common\keras_tensor.py", line 91, in __tf_tensor__
(RolloutWorker pid=15420)     raise ValueError(
(RolloutWorker pid=15420) ValueError: A KerasTensor cannot be used as input to a TensorFlow function. A KerasTensor is a symbolic placeholder for a shape and dtype, used when constructing Keras Functional models or Keras Functions. You can only use it as input to a Keras layer or a Keras operation (from the namespaces `keras.layers` and `keras.operations`). You are likely doing something like:
(RolloutWorker pid=15420)
(RolloutWorker pid=15420) ```
(RolloutWorker pid=15420) x = Input(...)
(RolloutWorker pid=15420) ...
(RolloutWorker pid=15420) tf_fn(x)  # Invalid.
(RolloutWorker pid=15420) ```
(RolloutWorker pid=15420)
(RolloutWorker pid=15420) What you should do instead is wrap `tf_fn` in a layer:
(RolloutWorker pid=15420)
(RolloutWorker pid=15420) ```
(RolloutWorker pid=15420) class MyLayer(Layer):
(RolloutWorker pid=15420)     def call(self, x):
(RolloutWorker pid=15420)         return tf_fn(x)
(RolloutWorker pid=15420)
(RolloutWorker pid=15420) x = MyLayer()(x)
(RolloutWorker pid=15420) ```
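
For context, this is roughly how the algorithm is configured. The environment name and model sizes below are placeholders, not my actual values; the only change between the working and failing runs is the framework string:

```
from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    # "MyCustomEnv-v0" is a placeholder for my registered custom environment.
    .environment("MyCustomEnv-v0")
    # Training with .framework("torch") works; switching to "tf2" triggers the error above.
    .framework("tf2")
    .training(
        model={
            "use_lstm": True,        # wraps the model with RLlib's LSTM wrapper (recurrent_net.py)
            "lstm_cell_size": 256,   # illustrative value
            "max_seq_len": 20,       # illustrative value
        }
    )
)

algo = config.build()
result = algo.train()
```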

I don’t know whether this is a bug, because with the torch framework there is no problem at all.
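
If it helps with debugging, I believe the Keras side of the failure can be reproduced outside RLlib. A minimal sketch, assuming TensorFlow 2.16+ where tf.keras resolves to Keras 3 (the Input shape and name here are only illustrative, mirroring what rllib/models/tf/recurrent_net.py builds):

```
import tensorflow as tf
from tensorflow import keras  # with TF >= 2.16 this is Keras 3 by default

# Symbolic placeholder, similar to the seq_in input that RLlib's recurrent model builds.
seq_in = keras.layers.Input(shape=(), name="seq_in", dtype=tf.int32)

try:
    # Passing the KerasTensor straight to a plain TF op is what fails under Keras 3.
    tf.sequence_mask(seq_in)
except ValueError as err:
    print(err)  # "A KerasTensor cannot be used as input to a TensorFlow function. ..."
```

As far as I can tell, the older Keras 2 behavior allowed this pattern, which would explain why only the tf2 code path trips over it.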