How severely does this issue affect your experience of using Ray?
- None: Sharing current approach in the hopes of constructive feedback
I am distributing a large language model using Ray Core. I am not using Ray Serve because I need the machine learning scientists I'm collaborating with to iterate as quickly as possible, which means directly editing and calling functions rather than going through a REST API abstraction.
My current methodology is:
- Download the model from S3 and cache it using a Ray Actor
- Pass model references from the Ray Actor to various tasks
Here’s the code for this approach:
```python
from pathlib import Path
import tempfile

import boto3
import ray
import torch

from my_library import load_model_and_alphabet_local


@ray.remote
class Model:
    """
    Ray Actor that downloads the model once; returning it from an actor
    method places it in the Ray (Plasma) object store, which is the easiest
    way to download the model a single time and share it between Ray Tasks.
    """

    def __init__(self, model):
        s3_client = boto3.client("s3")
        with tempfile.TemporaryDirectory() as tmp_dir:
            downloaded_model = f"{tmp_dir}/{model}.pt"
            s3_client.download_file(
                "pq-ml-models",
                f"models/fair-esm/{model}.pt",
                downloaded_model,
            )
            self.model, self.alphabet = load_model_and_alphabet_local(Path(downloaded_model))
        # Disable dropout for deterministic results.
        self.model.eval()

    def get_model(self):
        return self.model

    def get_batch_conv(self):
        return self.alphabet.get_batch_converter()


@ray.remote
def model_inference(model, batch_converter, sequence):
    """
    Use the model and batch-converter references for a forward pass.
    Sequences could be batched, and the output could be further manipulated
    if desired.
    """
    _, _, batch_tokens = batch_converter([("tmp", sequence)])
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[33], return_contacts=False)
    return results["logits"][0].cpu().numpy()


# Connect to the KubeRay cluster ("<head-node-host>" is a placeholder).
ray.util.connect("<head-node-host>:10001")

sequences = ["MKTVRQERLKSI", "VRILERSKEPVSGA", "QLAEELSVSRQVIV", "QDIAYLRSLGYN", "IVATPRGYVLAGG"]

ray_model = Model.remote("my_fancy_model")
ref_model = ray_model.get_model.remote()
ref_batch = ray_model.get_batch_conv.remote()

all_res = [model_inference.remote(ref_model, ref_batch, seq) for seq in sequences]
print([ray.get(res) for res in all_res])
```
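As an aside, I know the trailing `ray.get` loop blocks until every result is ready. A minimal sketch of consuming results as they finish instead, using `ray.wait`:

```python
# Sketch only: process each result as soon as it completes rather than
# blocking on the whole list at once.
remaining = all_res
while remaining:
    done, remaining = ray.wait(remaining, num_returns=1)
    print(ray.get(done[0]))
```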
Setting aside the direct `ray.get` usage (which I know is not ideal), how can I improve my current approach to this task? Is there some feature of Ray outside of Ray Core that I should be using? In the next iteration of this project I plan to use Ray Workflows to average the predictions of 5 different models; the sketch after this question is roughly what I have in mind.
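For context, here is the ensemble step written as plain Ray tasks for now. This is a sketch only: the checkpoint names are hypothetical, it reuses the `Model` actor, `model_inference` task, and `sequences` from above, and it assumes all five models produce logits of the same shape.

```python
import numpy as np

# Hypothetical checkpoint names; one Model actor is created per checkpoint.
MODEL_NAMES = ["model_a", "model_b", "model_c", "model_d", "model_e"]

@ray.remote
def average_predictions(*logits):
    # Element-wise mean of the per-model logits for a single sequence.
    return np.mean(logits, axis=0)

actors = [Model.remote(name) for name in MODEL_NAMES]
refs = [(a.get_model.remote(), a.get_batch_conv.remote()) for a in actors]

averaged = []
for seq in sequences:
    per_model = [model_inference.remote(m, b, seq) for m, b in refs]
    # Ray resolves the per-model ObjectRefs before average_predictions runs.
    averaged.append(average_predictions.remote(*per_model))

print(ray.get(averaged))
```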