Sharing big ML models using only Ray Core

How severe does this issue affect your experience of using Ray?

  • None: Sharing current approach in the hopes of constructive feedback

I am distributing a large language model using Ray Core. I am not using Ray Serve because I need the machine learning scientists I'm collaborating with to iterate as quickly as possible, which means editing and calling functions directly rather than going through a REST API abstraction.

My current methodology is:

  1. Download the model from S3 and cache it using a Ray Actor
  2. Pass model references from the Ray Actor to various tasks

Here’s the code for this approach:

from pathlib import Path
import tempfile

import boto3
import ray
import torch
from my_library import load_model_and_alphabet_local


@ray.remote
class Model:
    """
    Ray Actor that downloads the model and keeps it in the Ray Plasma object store.

    Easiest way to only download the model once and share it between Ray Tasks.
    """

    def __init__(self, model):
        s3_client = boto3.client("s3")

        with tempfile.TemporaryDirectory() as tmp_dir:
            downloaded_model = f"{tmp_dir}/{model}.pt"

            s3_client.download_file(
                "pq-ml-models",
                f"models/fair-esm/{model}.pt",
                downloaded_model,
            )

            self.model, self.alphabet = load_model_and_alphabet_local(Path(downloaded_model))

        # disables dropout for deterministic results
        self.model.eval()

    def get_model(self):
        return self.model

    def get_batch_conv(self):
        return self.alphabet.get_batch_converter()


@ray.remote
def model_inference(model, batch_converter, sequence):
    """
    Use model and converter reference for forward pass.

    Sequences could be batched and output could be further manipulated if desired.
    """
    _, _, batch_tokens = batch_converter([("tmp", sequence)])
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[33], return_contacts=False)

    return results["logits"][0].cpu().numpy()


# connect to the KubeRay cluster via Ray Client (10001 is the default client server port)
ray.util.connect("<head_node_host>:10001")


sequences = ["MKTVRQERLKSI", "VRILERSKEPVSGA", "QLAEELSVSRQVIV", "QDIAYLRSLGYN", "IVATPRGYVLAGG"]
ray_model = Model.remote("my_fancy_model")
ref_model = ray_model.get_model.remote()
ref_batch = ray_model.get_batch_conv.remote()

all_res = [model_inference.remote(ref_model, ref_batch, seq) for seq in sequences]
print([ray.get(res) for res in all_res])

Ignoring that I'm using ray.get directly, which is not desirable, how can I improve my current approach to this task? Is there some feature of Ray outside of Ray Core I should be using? In the next iteration of this project I plan to use Ray Workflows to average the predictions of five different models.
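
For context, I expect the ensemble step to look roughly like this (only a sketch using plain tasks rather than Workflows; the model names are placeholders, and I'm assuming all five models emit logits of the same shape):

import numpy as np

model_names = ["model_a", "model_b", "model_c", "model_d", "model_e"]  # placeholder names

# one actor per model, reusing the Model actor defined above
models = [Model.remote(name) for name in model_names]
refs = [(m.get_model.remote(), m.get_batch_conv.remote()) for m in models]

def ensemble_logits(sequence):
    # run the same sequence through every model and average the logits
    per_model = [model_inference.remote(mdl, conv, sequence) for mdl, conv in refs]
    return np.mean(ray.get(per_model), axis=0)

averaged = [ensemble_logits(seq) for seq in sequences]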

At first glance, this approach looks okay to me! A couple suggestions:

  1. You can call ray.get(all_res) instead of getting each result one at a time. This lets Ray fetch the results to the driver node in parallel rather than sequentially (see the combined example after the code below). Regarding the linked Discuss post, you should only switch to a chunked ray.get if the total result size is larger than your driver node's memory capacity and if you can modify your script to process the results incrementally instead of all at once (i.e., you should not print all of the results at once).
  2. Note that since you're using an actor, the model will be cached both in the actor's heap memory and in Ray's object store. Every time you call ray_model.get_model.remote() on the actor, a new copy is created in Ray's object store. As written, your script only creates one copy of the model, which gets shared among the tasks, but calling ref_model = ray_model.get_model.remote() again would create a second copy. You can avoid this by reusing the previous ref_model or by using a task instead (see below).
  3. In this case, since you do not need to modify the actor’s state after the constructor, you can actually use a task to cache the model in Ray’s object store instead of an actor. This is a bit more direct and it will help to improve resource utilization and fault tolerance (Ray will automatically re-download the model if the original copies are lost). Here’s an example:
@ray.remote(num_returns=2)
def load_model(model):
    s3_client = boto3.client("s3")
    with tempfile.TemporaryDirectory() as tmp_dir:
        downloaded_model = f"{tmp_dir}/{model}.pt"
        s3_client.download_file(
                "pq-ml-models",
                f"models/fair-esm/{model}.pt",
                downloaded_model,
            )

        model, alphabet = load_model_and_alphabet_local(Path(downloaded_model))
    # disables dropout for deterministic results
    model.eval()
    return model, alphabet.get_batch_converter()
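
Putting suggestions 1 and 3 together, the driver side could look roughly like this (just a sketch: it calls load_model once and reuses the two refs, so only one copy of the model lands in the object store, and it fetches all results with a single ray.get):

# cache the model once in the object store; both refs can be reused by any number of tasks
ref_model, ref_batch = load_model.remote("my_fancy_model")

all_res = [model_inference.remote(ref_model, ref_batch, seq) for seq in sequences]

# one ray.get fetches all results to the driver in parallel
results = ray.get(all_res)

# only if the combined results outgrow the driver's memory, fall back to a chunked
# pattern that processes each result as it arrives (process() is a placeholder):
# for res in all_res:
#     process(ray.get(res))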