How severe does this issue affect your experience of using Ray?
- Medium: It contributes to significant difficulty to complete my task, but I can work around it.
System:
Python 3.9.16
Ray 2.3.0 and 3.0.0.dev0
Tensorflow 2.11.1
Issue may be related to this PR
Unfortunately, I'm unable to replicate this with one of the lightweight and fast standard toy environments, although I tried. So instead I provide some sampler_perf output obtained with different settings/algorithms:
Impala (train_batch_size=500, num_rollout_workers=2, rollout_fragment_length=50):
sampler_perf:
  mean_action_processing_ms: 0.1270183308157854
  mean_env_render_ms: 0.0
  mean_env_wait_ms: 517.0239390965231
  mean_inference_ms: 3.660824960339331
  mean_raw_obs_processing_ms: 1.627967505159968
sampler_results:
  connector_metrics:
    ObsPreprocessorConnector_ms: 0.007319450378417969
    StateBufferConnector_ms: 0.004112720489501953
    ViewRequirementAgentConnector_ms: 0.21414756774902344
Impala (train_batch_size=800, num_rollout_workers=4, rollout_fragment_length=100):
- only empty lists, dicts, and NaNs are returned
PPO (reference config given further below):
sampler_perf:
  mean_action_processing_ms: 0.13562084506648384
  mean_env_render_ms: 0.0
  mean_env_wait_ms: 333.2677821337623
  mean_inference_ms: 3.2553278776970607
  mean_raw_obs_processing_ms: 0.832363541727467
sampler_results:
  connector_metrics:
    StateBufferConnector_ms: 0.010395050048828125
    ViewRequirementAgentConnector_ms: 0.19804835319519043
I run a rather slow and computationally heavy custom environment. I have been running it since Ray 2.1.0 (and 2.0.0 before that, as a Gym version) with the IMPALA algorithm and no issues whatsoever. I have since upgraded to Gymnasium, registered the environment, and got it working in its own right. The environment is vision based, and I use the custom model given below with framestack = 4.
The observation space is:
Box(low=0, high=255, shape=(72, 128, 3), dtype=np.uint8)
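For context, the environment is registered with RLlib roughly as follows (the env class here is only a lightweight placeholder for the real, computationally heavy one, and the Discrete action space is an assumption):

import gymnasium as gym
import numpy as np
from ray.tune.registry import register_env

class VisionEnv(gym.Env):
    # Placeholder standing in for the real vision-based environment.
    def __init__(self, config=None):
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(72, 128, 3), dtype=np.uint8
        )
        self.action_space = gym.spaces.Discrete(4)  # assumption, the real space differs

    def reset(self, *, seed=None, options=None):
        return self.observation_space.sample(), {}

    def step(self, action):
        obs = self.observation_space.sample()
        return obs, 0.0, False, False, {}

register_env("xxxenv", lambda cfg: VisionEnv(cfg))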
import tensorflow as tf
from ray.rllib.models.tf.tf_modelv2 import TFModelV2


def conv_layer(depth, name):
    return tf.keras.layers.Conv2D(
        filters=depth, kernel_size=3, strides=1, padding="same", name=name
    )


def residual_block(x, depth, prefix):
    inputs = x
    assert inputs.get_shape()[-1] == depth
    x = tf.keras.layers.ReLU()(x)
    x = conv_layer(depth, name=prefix + "_conv0")(x)
    x = tf.keras.layers.ReLU()(x)
    x = conv_layer(depth, name=prefix + "_conv1")(x)
    return x + inputs


def conv_sequence(x, depth, prefix):
    x = conv_layer(depth, prefix + "_conv")(x)
    x = tf.keras.layers.MaxPool3D(pool_size=3, strides=2, padding="same")(x)  # 3D for multiframe
    x = residual_block(x, depth, prefix=prefix + "_block0")
    x = residual_block(x, depth, prefix=prefix + "_block1")
    return x
class CustomModel(TFModelV2):
    """Deep residual network producing policy logits and a value output;
    based on the architecture used in the IMPALA paper:
    https://arxiv.org/abs/1802.01561"""

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        super().__init__(obs_space, action_space, num_outputs, model_config, name)
        depths = [16, 32, 32]
        inputs = tf.keras.layers.Input(shape=obs_space.shape, name="observations")
        scaled_inputs = tf.cast(inputs, tf.float32) / 255.0
        x = scaled_inputs
        for i, depth in enumerate(depths):
            x = conv_sequence(x, depth, prefix=f"seq{i}")
        x = tf.keras.layers.Flatten()(x)
        x = tf.keras.layers.ReLU()(x)
        x = tf.keras.layers.Dense(units=256, activation="relu", name="hidden")(x)
        logits = tf.keras.layers.Dense(units=num_outputs, name="pi")(x)
        value = tf.keras.layers.Dense(units=1, name="vf")(x)
        self.base_model = tf.keras.Model(inputs, [logits, value])

    def forward(self, input_dict, state, seq_lens):
        # Explicit cast to float32 is needed in eager mode.
        obs = tf.cast(input_dict["obs"], tf.float32)
        logits, self._value = self.base_model(obs)
        return logits, state

    def value_function(self):
        return tf.reshape(self._value, [-1])

    def import_from_h5(self, h5_file):
        self.base_model.load_weights(h5_file)
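The model is registered with the ModelCatalog under the name used in the configs below, i.e. roughly:

from ray.rllib.models import ModelCatalog

ModelCatalog.register_custom_model("CustomCNN", CustomModel)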
When running RLlib IMPALA on the latest versions, I experience the following (strange…) issues:
1. Training and rollouts:
.rollouts(num_rollout_workers=4, rollout_fragment_length=100) together with .training(train_batch_size=800, model={"custom_model": "CustomCNN"}) results in no metric output (episode_reward_mean etc.), although it does appear to be training, since it reports policy loss etc.
Setting it to 2 workers, or setting train_batch_size/rollout_fragment_length to 500/50 respectively, makes it work and metrics are shown. Increasing num_rollout_workers makes it fail again, as does increasing num_envs_per_worker beyond 1; a sketch of the failing config follows below.
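A sketch of the failing IMPALA setup (only the values mentioned above matter; the remaining settings are assumptions mirroring the working PPO config further below):

from ray.rllib.algorithms import impala

algo = (
    impala.ImpalaConfig()
    .training(train_batch_size=800, model={"custom_model": "CustomCNN"})
    .environment(env="xxxenv", env_config={"dummy_param": "foo"})
    .rollouts(num_rollout_workers=4, rollout_fragment_length=100)
    .resources(num_gpus=1)  # assumption
    .build()
)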
Obviously, I suspected the custom environment at first, but as you can see above, it does run under certain conditions. Additionally, when I switch to PPO I can push it pretty much to the system limit with no such problems, like this:
from ray.rllib.algorithms import ppo

algo = (
    ppo.PPOConfig()
    .training(train_batch_size=2400, model={"custom_model": "CustomCNN"})
    .environment(env="xxxenv", env_config={"dummy_param": "foo"})
    .rollouts(
        num_rollout_workers=6,
        num_envs_per_worker=4,
        rollout_fragment_length=100,
        remote_worker_envs=True,
        remote_env_batch_wait_ms=10,
        preprocessor_pref=None,
        sampler_perf_stats_ema_coef=2 / (200 + 1),  # ~200-sample EMA
    )
    .resources(num_gpus=1, num_cpus_per_worker=5)
    .fault_tolerance(recreate_failed_workers=True, restart_failed_sub_environments=True)
    .build()
)
.rollouts(…, sampler_perf_stats_ema_coef=2/(200+1)) appears to have no effect; the result still appears to be just the plain mean over the list of episode rewards.
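For reference, this is roughly how I read the metrics each iteration (key names taken from the result dumps above; the hist_stats comparison is only my sanity check and that key path is an assumption):

import numpy as np

for _ in range(10):
    result = algo.train()
    print(result["episode_reward_mean"])
    print(result["sampler_perf"])
    # With the EMA coefficient set I would expect this plain mean over the
    # episode history to differ from the reported value, but it does not.
    print(np.mean(result["hist_stats"]["episode_reward"]))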
2. Framework
.framework(framework="tf2", eager_tracing=True) results in the following error (also without specifying eager_tracing):
Exception in thread Thread-18:
Traceback (most recent call last):
  File "/home/novelty/miniconda3/envs/lupus_gymnasium/lib/python3.9/threading.py", line 980, in _bootstrap_inner
    self.run()
  File "/home/novelty/miniconda3/envs/lupus_gymnasium/lib/python3.9/site-packages/ray/rllib/execution/learner_thread.py", line 74, in run
    self.step()
  File "/home/novelty/miniconda3/envs/lupus_gymnasium/lib/python3.9/site-packages/ray/rllib/execution/learner_thread.py", line 91, in step
    multi_agent_results = self.local_worker.learn_on_batch(batch)
  File "/home/novelty/miniconda3/envs/lupus_gymnasium/lib/python3.9/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1036, in learn_on_batch
    info_out[pid] = policy.learn_on_batch(batch)
  File "/home/novelty/miniconda3/envs/lupus_gymnasium/lib/python3.9/site-packages/ray/rllib/policy/eager_tf_policy.py", line 139, in func
    return obj(self, *args, **kwargs)
  File "/home/novelty/miniconda3/envs/lupus_gymnasium/lib/python3.9/site-packages/ray/rllib/policy/eager_tf_policy.py", line 224, in learn_on_batch
    return super(TracedEagerPolicy, self).learn_on_batch(samples)
  File "/home/novelty/miniconda3/envs/lupus_gymnasium/lib/python3.9/site-packages/ray/rllib/utils/threading.py", line 24, in wrapper
    return func(self, *a, **k)
  File "/home/novelty/miniconda3/envs/lupus_gymnasium/lib/python3.9/site-packages/ray/rllib/policy/eager_tf_policy_v2.py", line 628, in learn_on_batch
    stats = self._learn_on_batch_helper(postprocessed_batch)
  File "/home/novelty/miniconda3/envs/lupus_gymnasium/lib/python3.9/site-packages/ray/rllib/policy/eager_tf_policy.py", line 97, in _func
    return func(*eager_args, **eager_kwargs)
  File "/home/novelty/miniconda3/envs/lupus_gymnasium/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/novelty/miniconda3/envs/lupus_gymnasium/lib/python3.9/site-packages/tensorflow/python/eager/execute.py", line 52, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.InternalError: Graph execution error:
Detected at node 'StatefulPartitionedCall_32' defined at (most recent call last):
  File "/home/novelty/miniconda3/envs/lupus_gymnasium/lib/python3.9/threading.py", line 937, in _bootstrap
    self._bootstrap_inner()
  File "/home/novelty/miniconda3/envs/lupus_gymnasium/lib/python3.9/threading.py", line 980, in _bootstrap_inner
    self.run()
  File "/home/novelty/miniconda3/envs/lupus_gymnasium/lib/python3.9/site-packages/ray/rllib/execution/learner_thread.py", line 74, in run
    self.step()
  File "/home/novelty/miniconda3/envs/lupus_gymnasium/lib/python3.9/site-packages/ray/rllib/execution/learner_thread.py", line 91, in step
    multi_agent_results = self.local_worker.learn_on_batch(batch)
  File "/home/novelty/miniconda3/envs/lupus_gymnasium/lib/python3.9/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1036, in learn_on_batch
    info_out[pid] = policy.learn_on_batch(batch)
  File "/home/novelty/miniconda3/envs/lupus_gymnasium/lib/python3.9/site-packages/ray/rllib/policy/eager_tf_policy.py", line 139, in func
    return obj(self, *args, **kwargs)
  File "/home/novelty/miniconda3/envs/lupus_gymnasium/lib/python3.9/site-packages/ray/rllib/policy/eager_tf_policy.py", line 224, in learn_on_batch
    return super(TracedEagerPolicy, self).learn_on_batch(samples)
  File "/home/novelty/miniconda3/envs/lupus_gymnasium/lib/python3.9/site-packages/ray/rllib/utils/threading.py", line 24, in wrapper
    return func(self, *a, **k)
  File "/home/novelty/miniconda3/envs/lupus_gymnasium/lib/python3.9/site-packages/ray/rllib/policy/eager_tf_policy_v2.py", line 628, in learn_on_batch
    stats = self._learn_on_batch_helper(postprocessed_batch)
  File "/home/novelty/miniconda3/envs/lupus_gymnasium/lib/python3.9/site-packages/ray/rllib/policy/eager_tf_policy.py", line 97, in _func
    return func(*eager_args, **eager_kwargs)
  File "/home/novelty/miniconda3/envs/lupus_gymnasium/lib/python3.9/site-packages/ray/rllib/policy/eager_tf_policy_v2.py", line 924, in _learn_on_batch_helper
    self._apply_gradients_helper(grads_and_vars)
  File "/home/novelty/miniconda3/envs/lupus_gymnasium/lib/python3.9/site-packages/ray/rllib/policy/eager_tf_policy_v2.py", line 1007, in _apply_gradients_helper
    o.apply_gradients(
  File "/home/novelty/miniconda3/envs/lupus_gymnasium/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/optimizer.py", line 1140, in apply_gradients
    return super().apply_gradients(grads_and_vars, name=name)
  File "/home/novelty/miniconda3/envs/lupus_gymnasium/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/optimizer.py", line 634, in apply_gradients
    iteration = self._internal_apply_gradients(grads_and_vars)
  File "/home/novelty/miniconda3/envs/lupus_gymnasium/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/optimizer.py", line 1166, in _internal_apply_gradients
    return tf.__internal__.distribute.interim.maybe_merge_call(
  File "/home/novelty/miniconda3/envs/lupus_gymnasium/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/optimizer.py", line 1216, in _distributed_apply_gradients_fn
    distribution.extended.update(
  File "/home/novelty/miniconda3/envs/lupus_gymnasium/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/optimizer.py", line 1211, in apply_grad_to_update_var
    return self._update_step_xla(grad, var, id(self._var_key(var)))
Node: 'StatefulPartitionedCall_32'
libdevice not found at ./libdevice.10.bc
[[{{node StatefulPartitionedCall_32}}]] [Op:__inference__learn_on_batch_helper_11039]
Traceback (most recent call last):
  File "/home/novelty/lupus_gymnasium_dev/rllib_basic_test/impala_test_5.py", line 141, in <module>
    result = algo.train()
  File "/home/novelty/miniconda3/envs/lupus_gymnasium/lib/python3.9/site-packages/ray/tune/trainable/trainable.py", line 384, in train
    raise skipped from exception_cause(skipped)
  File "/home/novelty/miniconda3/envs/lupus_gymnasium/lib/python3.9/site-packages/ray/tune/trainable/trainable.py", line 381, in train
    result = self.step()
  File "/home/novelty/miniconda3/envs/lupus_gymnasium/lib/python3.9/site-packages/ray/rllib/algorithms/algorithm.py", line 769, in step
    results, train_iter_ctx = self._run_one_training_iteration()
  File "/home/novelty/miniconda3/envs/lupus_gymnasium/lib/python3.9/site-packages/ray/rllib/algorithms/algorithm.py", line 2754, in _run_one_training_iteration
    results = self.training_step()
  File "/home/novelty/miniconda3/envs/lupus_gymnasium/lib/python3.9/site-packages/ray/rllib/algorithms/impala/impala.py", line 619, in training_step
    raise RuntimeError("The learner thread died while training!")
RuntimeError: The learner thread died while training!
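As a side note, the libdevice failure at the bottom of the graph-execution error looks like the known TF 2.11 behaviour where the new Keras optimizer jit-compiles its update step with XLA; a sketch of the usual workaround (the CUDA path here is an assumption for my machine) is to point XLA at the CUDA install before TensorFlow is imported:

import os

# Assumed CUDA install location; adjust to the actual path.
os.environ["XLA_FLAGS"] = "--xla_gpu_cuda_data_dir=/usr/local/cuda"

import tensorflow as tf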
BR
Jorgen