MLflow with Ray on Databricks

I have been playing around with Ray for many-model training on Databricks. One of the nice-to-haves would be that it integrates fully with the managed MLflow setup in Databricks. I’ve successfully trained models on a per-customer basis using the Ray Dataset API, something along the lines of ray_dataset.groupby('customer').map_groups(fn=model.train). However, I can’t seem to get the logging to MLflow to work; I’ve tried different approaches and in the end decided to break the problem down as far as possible.
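
For context, a minimal sketch of that per-customer pattern (the training function, column names, and data here are illustrative stand-ins, not my actual code):

import pandas as pd
import ray

# Illustrative stand-in for the real per-customer training function.
def train_per_customer(group: pd.DataFrame) -> pd.DataFrame:
    customer = group["customer"].iloc[0]
    # ... fit a model on this customer's rows ...
    return pd.DataFrame({"customer": [customer], "n_rows": [len(group)]})

ds = ray.data.from_pandas(pd.DataFrame(
    {"customer": ["a", "a", "b"], "value": [1.0, 2.0, 3.0]}
))

# One group per customer; each group arrives in the function as a DataFrame.
results = ds.groupby("customer").map_groups(train_per_customer)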

In short, if I use the Ray Dataset API or Ray Core (i.e. @ray.remote), I get the following error:

RaySystemError: System error: Failed to unpickle serialized exception
traceback: Traceback (most recent call last):
  File "/local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/ray/exceptions.py", line 46, in from_ray_exception
    return pickle.loads(ray_exception.serialized_exception)
  File "/databricks/python/lib/python3.10/site-packages/mlflow/exceptions.py", line 117, in __init__
    error_code = json.get("error_code", ErrorCode.Name(INTERNAL_ERROR))
AttributeError: 'str' object has no attribute 'get'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/ray/_private/serialization.py", line 369, in deserialize_objects
    obj = self._deserialize_object(data, metadata, object_ref)
  File "/local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/ray/_private/serialization.py", line 275, in _deserialize_object
    return RayError.from_bytes(obj)
  File "/local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/ray/exceptions.py", line 40, in from_bytes
    return RayError.from_ray_exception(ray_exception)
  File "/local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/ray/exceptions.py", line 49, in from_ray_exception
    raise RuntimeError(msg) from e
RuntimeError: Failed to unpickle serialized exception

When running these lines of code:

import ray
import mlflow
from ray.util.spark import setup_ray_cluster, MAX_NUM_WORKER_NODES, shutdown_ray_cluster

num_worker_nodes = None  # None -> use all available worker nodes

setup_ray_cluster(
    num_worker_nodes=(num_worker_nodes if num_worker_nodes is not None else MAX_NUM_WORKER_NODES),
    num_cpus_per_node=8,
    collect_log_to_path='/dbfs/ray/logs')

# databricks_host and databricks_token are defined elsewhere in the notebook;
# they are forwarded to the workers so the MLflow client can authenticate.
ray.init(runtime_env={'env_vars': {'DATABRICKS_HOST': databricks_host, 'DATABRICKS_TOKEN': databricks_token}})

@ray.remote
class FooModel:

  def __init__(self, mlflow_experiment='/mlflow_experiment', mlflow_uri='mlflow_uri'):
    self.mlflow_uri = mlflow_uri
    self.mlflow_experiment = mlflow_experiment

  def train(self, number):
    mlflow.set_tracking_uri(self.mlflow_uri)
    mlflow.set_experiment(self.mlflow_experiment)

    with mlflow.start_run(run_name="foo"):
      mlflow.log_metric("metric", number)

actor_handle = FooModel.remote()
object_ref = actor_handle.train.remote(42)
result = ray.get(object_ref)
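
Digging into the last frame of the traceback, my best guess at the root cause is that mlflow's RestException.__init__ expects a parsed JSON dict, while Python's default exception pickling passes back only the message string, so the exception raised inside the worker cannot be reconstructed on the driver. A minimal, Ray-free reproduction of that guess (my own illustration):

import pickle
from mlflow.exceptions import RestException

exc = RestException({"error_code": "INTERNAL_ERROR", "message": "boom"})
payload = pickle.dumps(exc)  # serializing succeeds...
pickle.loads(payload)        # ...but reconstruction calls __init__ with the message
                             # string -> AttributeError: 'str' object has no attribute 'get'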

I am using DBR 13.1 ML (the exact environment is specified here) with Ray 2.3.1 [default,data,tune]. I’ve tried the setup_mlflow() integration from Ray AIR, but it doesn’t solve the issue (as far as I know it is also reserved for Tune/Trainers anyway).

I’ve looked at several posts discussing similar serialization issues, but none of them solved the problem (some hinted at a fundamental limitation of Python’s built-in pickle).
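
In the meantime, re-raising errors inside the remote method as plain, picklable exceptions at least surfaces the underlying failure on the driver (a sketch of a replacement for the train method above, not something I have fully validated):

  def train(self, number):
    try:
      mlflow.set_tracking_uri(self.mlflow_uri)
      mlflow.set_experiment(self.mlflow_experiment)
      with mlflow.start_run(run_name="foo"):
        mlflow.log_metric("metric", number)
    except Exception as e:
      # RuntimeError pickles cleanly across the Ray worker/driver boundary.
      raise RuntimeError(f"{type(e).__name__}: {e}") from None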

How severely does this issue affect your experience of using Ray?

  • Medium: It causes significant difficulty in completing my task, but I can work around it.

@cavriends thanks for posting here.
You might be better off doing the MLflow tracking and logging within the training module. We have a callback integration.

This integration assumes you are training models on a Ray cluster and logging to a central MLflow server. Ray on Databricks is still experimental and early. Try the MLflow callback in your training code.

Hi @Jules_Damji,

Thanks for the answer. I am not aware of any implementation of that integration for the Ray Dataset API (i.e. ray_dataset.groupby('customers').map_groups(...)). However, should the setup I’ve described in the post work without relying on Databricks (i.e. a remote MLflow server, Ray on Anyscale, etcetera)?

@cavriends I would not use ray_dataset.groupby(...) to train models. If you want models trained on customer-specific data, it’s better to shard the data by customer and provide each shard to Ray Train, training a model specific to that dataset.

Here are examples of training many models, each specific to, say, a geo location:

  1. Training 1 Million ML Models in Record Time | Anyscale
  2. How to Conduct Many-Model Batch Training at Scale with Ray

Within Ray Train trainer code, you can use MLflowCallback to log to the MLflow server.
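
For reference, a sketch of what that could look like on Ray 2.3, with Tune driving one trial per customer (in 2.3 the callback lives at ray.air.integrations.mlflow.MLflowLoggerCallback; the customer list, experiment name, and training function below are illustrative assumptions):

from ray import tune
from ray.air import RunConfig, session
from ray.air.integrations.mlflow import MLflowLoggerCallback

def train_fn(config):
    # ... load the shard for config["customer"] and fit a model ...
    session.report({"metric": 42.0})  # reported metrics are forwarded to MLflow

tuner = tune.Tuner(
    train_fn,
    param_space={"customer": tune.grid_search(["a", "b", "c"])},
    run_config=RunConfig(
        callbacks=[MLflowLoggerCallback(
            tracking_uri="databricks",             # assumes Databricks auth is configured
            experiment_name="/mlflow_experiment",  # illustrative experiment path
        )],
    ),
)
tuner.fit()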

Hi @Jules_Damji ,

These posts were the inspiration for making this work on Ray (the first one offers ray_dataset.groupby(...) as an option, stating that Ray Train isn’t suitable for models numbering in the multiple thousands; the second still refers to a Ray Core-only implementation). I’ve tried to make it work with Ray Train/Tune; this does log to MLflow with the MLflowCallback, but it incurs some overhead compared to Ray Core/Data. In addition, the logging it provides is very minimal (a limited number of parameters, no model or artifacts logged) and not very usable compared to what default MLflow logging on Databricks has to offer. Nonetheless, I’ll try it out a bit more and see if I can work around it. I understand that it is still early days for Ray on Databricks, but it looks promising! Thanks for helping me out!

@cavriends Right, the MLflowCallback tracking is limited to parameters and metrics. Feel free to file an issue asking to extend the MLflowCallback functionality to artifacts and models too.

Yes, for hundreds to thousands of simple statistical models, those two posts were demonstrations of Ray Core/Data capabilities.

You’re welcome.

This is fixed in Make RestException pickleable by WeichenXu123 · Pull Request #10936 · mlflow/mlflow · GitHub
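
Once an mlflow release containing that PR is available, picking up the fix on a Databricks cluster should be a matter of upgrading the library and restarting Python (a sketch, assuming notebook-scoped libraries):

%pip install --upgrade mlflow
dbutils.library.restartPython()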