Thank you! To run distributed TensorFlow training with Ray Train, use MultiWorkerMirroredStrategy and make sure the model is built and compiled inside the strategy's scope. With recent TensorFlow/Keras versions (especially Keras 3.x), you may hit compatibility issues: the official Ray example is known to work with TensorFlow 2.x and Keras 2.x, but may require setting os.environ["TF_USE_LEGACY_KERAS"] = "1" before importing TensorFlow if Keras 3.x is installed (see discussion).
Here is a minimal working example for distributed training on CIFAR-10 using Ray Train and TensorFlow:
```python
import os

os.environ["TF_USE_LEGACY_KERAS"] = "1"  # Only needed if Keras 3.x is installed

import ray
import tensorflow as tf
from ray.train import ScalingConfig
from ray.train.tensorflow import TensorflowTrainer


def build_model():
    # Small CNN for CIFAR-10: 32x32 RGB inputs, 10 classes.
    model = tf.keras.Sequential([
        tf.keras.layers.InputLayer(input_shape=(32, 32, 3)),
        tf.keras.layers.Conv2D(32, 3, activation="relu"),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10, activation="softmax"),
    ])
    model.compile(
        optimizer="adam",
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model


def train_func(config):
    # Ray Train sets TF_CONFIG on every worker, so the strategy
    # discovers its peers automatically.
    strategy = tf.distribute.MultiWorkerMirroredStrategy()

    (x_train, y_train), _ = tf.keras.datasets.cifar10.load_data()
    x_train = x_train.astype("float32") / 255.0
    y_train = y_train.astype("int64")

    batch_size = config.get("batch_size", 64)  # global batch size, split across workers
    epochs = config.get("epochs", 2)

    # Variable creation (model build + compile) must happen inside the scope;
    # fit() itself can run outside it.
    with strategy.scope():
        model = build_model()

    model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs)


trainer = TensorflowTrainer(
    train_loop_per_worker=train_func,
    train_loop_config={"batch_size": 64, "epochs": 2},
    scaling_config=ScalingConfig(num_workers=2, use_gpu=True),  # use_gpu=False on CPU-only clusters
)

if __name__ == "__main__":
    ray.init(address="auto")  # or ray.init() for a single-machine run
    trainer.fit()
```
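A note on data handling: with plain NumPy inputs, Keras falls back on tf.data auto-sharding under the hood, so each worker trains on a slice of every batch. If you want that behavior to be explicit, here is a minimal sketch of a tf.data pipeline you could build inside train_func (the AutoShardPolicy option is standard TensorFlow API; wiring it in this way is just one reasonable choice, not part of the Ray example above):

```python
def make_dataset(x, y, batch_size):
    # Explicitly request DATA sharding: every worker reads the full
    # in-memory dataset but trains on a disjoint slice of each batch.
    ds = tf.data.Dataset.from_tensor_slices((x, y))
    ds = ds.shuffle(10_000).batch(batch_size)
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = (
        tf.data.experimental.AutoShardPolicy.DATA
    )
    return ds.with_options(options)

# Inside train_func, replace the model.fit(...) call with:
# model.fit(make_dataset(x_train, y_train, batch_size), epochs=epochs)
```

DATA sharding is the usual choice for in-memory datasets like CIFAR-10; FILE sharding only pays off when the input is spread across many files.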
Make sure all nodes run the same TensorFlow and Ray versions, and start Ray on every node before running the script (typically ray start --head on the head node and ray start --address=<head_ip>:6379 on each worker). If you still encounter the PerReplica error, try pinning TensorFlow 2.x with Keras 2.x, as this is a known compatibility issue (see Ray docs).
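To confirm every node has actually joined before launching the job, here is a quick check against Ray's public API (the printout format is just for illustration):

```python
import ray

ray.init(address="auto")
# List every node Ray knows about, whether it is alive, and its GPU count.
for node in ray.nodes():
    gpus = node["Resources"].get("GPU", 0)
    print(f'{node["NodeManagerAddress"]}  alive={node["Alive"]}  GPUs={gpus}')
```

If a node is missing here, fix the ray start step on that machine before calling trainer.fit().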
Would you like more detail on debugging multi-node setup or dataset sharding?
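On the debugging side, a quick first step is to have each worker print its rank and the TF_CONFIG that Ray Train injects (this sketch assumes Ray 2.x, where ray.train.get_context() is available):

```python
import json
import os

import ray.train


def print_worker_info():
    # Call at the top of train_func to see how Ray Train wired the workers.
    ctx = ray.train.get_context()
    tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
    print(
        f"world_rank={ctx.get_world_rank()} "
        f"world_size={ctx.get_world_size()} "
        f"cluster={tf_config.get('cluster')}"
    )
```

If the cluster entry does not list all of your workers, the multi-worker strategy will not see them either.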
Hint: Mention @RunLLM in the post for followups.