I have been playing around with Ray for many-model training on Databricks. One of the nice-to-haves is that it integrates fully with the Managed MLflow setup in Databricks. I've successfully trained models on a per-customer basis using the Ray Dataset API, something along the lines of ray_dataset.groupby('customer').map_groups(fn=model.train). However, I can't seem to get the logging to MLflow to work; I've tried different approaches and in the end decided to break the problem down as far as possible.
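For context, this is roughly the per-customer pattern I mean, with toy data and a placeholder train function standing in for my real model code (a sketch, not my actual pipeline; I believe map_groups accepts batch_format in Ray 2.3, per the docs):

import pandas as pd
import ray

# Toy dataset: one row per observation, keyed by customer
ds = ray.data.from_items([
    {"customer": "a", "value": 1.0},
    {"customer": "a", "value": 2.0},
    {"customer": "b", "value": 3.0},
])

def train(group: pd.DataFrame) -> pd.DataFrame:
    # Placeholder for model.train: receives all rows of one customer
    return pd.DataFrame({"customer": [group["customer"].iloc[0]],
                         "rows_seen": [len(group)]})

results = ds.groupby("customer").map_groups(train, batch_format="pandas")
print(results.take())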
In short, whether I use the Ray Dataset API or Ray Core (i.e. @ray.remote), I get the following error:
RaySystemError: System error: Failed to unpickle serialized exception
traceback: Traceback (most recent call last):
File "/local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/ray/exceptions.py", line 46, in from_ray_exception
return pickle.loads(ray_exception.serialized_exception)
File "/databricks/python/lib/python3.10/site-packages/mlflow/exceptions.py", line 117, in __init__
error_code = json.get("error_code", ErrorCode.Name(INTERNAL_ERROR))
AttributeError: 'str' object has no attribute 'get'
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/ray/_private/serialization.py", line 369, in deserialize_objects
obj = self._deserialize_object(data, metadata, object_ref)
File "/local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/ray/_private/serialization.py", line 275, in _deserialize_object
return RayError.from_bytes(obj)
File "/local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/ray/exceptions.py", line 40, in from_bytes
return RayError.from_ray_exception(ray_exception)
File "/local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/ray/exceptions.py", line 49, in from_ray_exception
raise RuntimeError(msg) from e
RuntimeError: Failed to unpickle serialized exception
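If I'm reading the inner traceback correctly, the original exception raised on the worker is an MLflow RestException, and it is the exception object itself that fails to round-trip through pickle: RestException.__init__ expects a JSON dict, but pickle's default reconstruction calls it with the message string. Assuming that reading is right, this tiny snippet (no Ray or Databricks involved) reproduces the same AttributeError:

import pickle
from mlflow.exceptions import RestException

exc = RestException({"error_code": "INTERNAL_ERROR", "message": "boom"})
payload = pickle.dumps(exc)  # serializing succeeds
pickle.loads(payload)        # AttributeError: 'str' object has no attribute 'get'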
The error is raised when running these lines of code:
import ray
from ray.util.spark import setup_ray_cluster, MAX_NUM_WORKER_NODES, shutdown_ray_cluster

num_worker_nodes = None  # None -> use all available worker nodes
setup_ray_cluster(
    num_worker_nodes=(num_worker_nodes if num_worker_nodes is not None else MAX_NUM_WORKER_NODES),
    num_cpus_per_node=8,
    collect_log_to_path='/dbfs/ray/logs')

# databricks_host and databricks_token are defined earlier in the notebook
ray.init(runtime_env={'env_vars': {'DATABRICKS_HOST': databricks_host, 'DATABRICKS_TOKEN': databricks_token}})

import mlflow

@ray.remote
class FooModel:
    def __init__(self, mlflow_experiment='/mlflow_experiment', mlflow_uri='mlflow_uri'):
        self.mlflow_uri = mlflow_uri
        self.mlflow_experiment = mlflow_experiment

    def train(self, number):
        # Point the worker-side MLflow client at the tracking server
        mlflow.set_tracking_uri(self.mlflow_uri)
        mlflow.set_experiment(self.mlflow_experiment)
        with mlflow.start_run(run_name="foo"):
            mlflow.log_metric("metric", number)

actor_handle = FooModel.remote()
object_ref = actor_handle.train.remote(42)
result = ray.get(object_ref)
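As part of breaking it down, one way to at least surface the underlying error is to catch the exception inside the actor and return it as a string, so the MlflowException never has to cross Ray's serialization boundary. A debugging sketch, not a fix (FooModelDebug is a hypothetical name; it reuses the imports and placeholders from above):

@ray.remote
class FooModelDebug:
    def __init__(self, mlflow_experiment='/mlflow_experiment', mlflow_uri='mlflow_uri'):
        self.mlflow_uri = mlflow_uri
        self.mlflow_experiment = mlflow_experiment

    def train(self, number):
        try:
            mlflow.set_tracking_uri(self.mlflow_uri)
            mlflow.set_experiment(self.mlflow_experiment)
            with mlflow.start_run(run_name="foo"):
                mlflow.log_metric("metric", number)
            return "ok"
        except Exception as e:
            # Return the real error as plain text so Ray never pickles the exception
            return f"{type(e).__name__}: {e}"

debug_handle = FooModelDebug.remote()
print(ray.get(debug_handle.train.remote(42)))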
I am using DBR 13.1 ML (the exact environment is specified here) with Ray 2.3.1 [default,data,tune]. I've tried the setup_mlflow() integration from Ray AIR, but it doesn't solve the issue (as far as I know it is also intended only for Tune trainables/Trainers).
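For reference, this is roughly how I tried it, based on my reading of the Ray 2.3 docs for ray.air.integrations.mlflow.setup_mlflow (the URI and experiment values are the same placeholders as above):

from ray import tune
from ray.air import session
from ray.air.integrations.mlflow import setup_mlflow

def trainable(config):
    # setup_mlflow configures the worker-side MLflow client inside the trainable
    mlflow = setup_mlflow(
        config,
        tracking_uri=config.get("tracking_uri"),
        experiment_name=config.get("experiment_name"),
    )
    mlflow.log_metric("metric", 42)
    session.report({"metric": 42})

tuner = tune.Tuner(
    trainable,
    param_space={
        "tracking_uri": "mlflow_uri",
        "experiment_name": "/mlflow_experiment",
    },
)
tuner.fit()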
I've looked at several posts discussing similar serialization issues, but none of them solved this one (some hinted at a fundamental limitation of Python's built-in pickle).
How severely does this issue affect your experience of using Ray?
- Medium: It causes significant difficulty in completing my task, but I can work around it.