How to improve the performance of Ray actors and task functions?

1. Severity of the issue: (select one)
High: Completely blocks me from using Ray.

2. Environment:

  • Ray version: 2.48.0
  • Python version: 3.11.11
  • OS: Ubuntu 24.04.3 LTS with 127 cores
  • GPU Infrastructure:
    • 1 NVIDIA GPU with 44 GB of memory
    • Driver Version: 535.230.02
    • CUDA Version: 12.2

Hi there!

I have a time series pipeline that runs 15 different models for each target using Parallel from joblib. 2 out of the 15 models are pre-trained foundation models (1 GB each). My pipeline can only run 26 targets in parallel, because loading 26 copies of the foundation models would consume all of the GPU memory (44 GB); I am also accounting for the memory used during prediction. My bottleneck is that the other 101 CPU cores remain idle.

So, I decided to use Ray actors. This way, I could load each model only once and use all of my 127 cores. However, I did not get the time gain that I expected: increasing the number of cores by 4.9× only reduced the running time by 2.5× (from 13,343 seconds to 5,287 seconds). I expected a speedup of at least 4.5×, for a total time close to 3,000 seconds.

I tried to increase the number of actors loaded simultaneously and to adjust num_cpus and num_gpus, using either a round-robin strategy or Ray's ActorPool (sketch below). Increasing the number of actors decreased performance. Changing num_cpus and num_gpus did not have much impact as long as it respected the overall capacity of my machine.
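Roughly, the ActorPool variant looked like this (a simplified sketch; the replica count and the context_tensors list are illustrative, and ChronosModelActor is the actor class shown further below):

from ray.util import ActorPool

# Several replicas of the same actor; ActorPool hands each call to whichever
# replica is currently free and yields results in submission order.
replicas = [
    ChronosModelActor.remote(path_chronos, device, torch.bfloat16)
    for _ in range(4)
]
pool = ActorPool(replicas)
means = list(pool.map(
    lambda actor, ctx: actor.predict.remote(ctx, forecast_horizon),
    context_tensors,
))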

I also noticed that Ray reduced the runtime of the other 13 models, which rely on the CPU (e.g., KNN) or are lightweight to fit on the GPU. Besides that, when running the experiment with 127 cores, all of them were 100% occupied while the GPU peaked at 10% utilization and 5% memory usage.

This is how I load the foundation models and create the actors:

num_cpus_actor = 0.006
num_gpus_actor = 0.006

@ray.remote(num_gpus=num_gpus_actor, num_cpus=num_cpus_actor)
class ChronosModelActor:
    def __init__(self, model_path: str, device: str, torch_dtype):
        self.device = device if torch.cuda.is_available() else "cpu"
        self.model = BaseChronosPipeline.from_pretrained(model_path, device_map=self.device, torch_dtype=torch_dtype)

    def predict(self, context_tensor, forecast_horizon: int):
        context_tensor = context_tensor.to(self.device)
        _, mean = self.model.predict_quantiles(context=context_tensor, prediction_length=forecast_horizon)
        return mean.detach().cpu()

@ray.remote(num_gpus=num_gpus_actor, num_cpus=num_cpus_actor)
class TimeMoEModelActor:
    def __init__(self, model_path: str, device: str, torch_dtype):
        self.device = device if torch.cuda.is_available() else "cpu"
        self.model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path=model_path,
            device_map=self.device, 
            trust_remote_code=True,
            torch_dtype=torch_dtype,
        )
        self.model = self.model.to(self.device)

    def predict(self, context_tensor, forecast_horizon: int):
        context_tensor = context_tensor.to(self.device)
        output = self.model.generate(context_tensor, max_new_tokens=forecast_horizon)
        forecast = output[:, -forecast_horizon:]
        return forecast.cpu()

# Overall setup of the process
N_CORES = 100

ray.init(num_cpus=N_CORES, num_gpus=1, include_dashboard=True)
path_chronos = os.path.join(BASE_PATH_TRAINED_MODELS, model_params.get("BaseChronosPipeline")["model_name"][0])
path_timemoe = os.path.join(BASE_PATH_TRAINED_MODELS, model_params.get("TimeMoE")["model_name"][0])
chronos_actors = ChronosModelActor.remote(path_chronos, device, torch.bfloat16)
timemoe_actors = TimeMoEModelActor.remote(path_timemoe, device, torch.bfloat16)

And then, I pass them to a task function that runs in parallel for each target:


@ray.remote(num_cpus=0.9, num_gpus=0)
@ray_debug
def process_series(
    ray_actors,
    d,
    ...
):
    # Similar process for TimeMoE (the other foundation model)
    model = ray_actors.get("chronos")
    # Make the predictions for all ids at once
    mean = ray.get(model.predict.remote(
        context_tensor=context_tensor,
        forecast_horizon=forecast_horizon,
    ))

    # This function does not return anything because I save the intermediate results to disk.


futures = []
for target in targets:
    futures.append(
        process_series.remote(
            ray_actors={"chronos": chronos_actors, 
                        "timemoe": timemoe_actors},
            target=target,
            ...
        )
    )
results = ray.get(futures)

Questions

  • Will each target be allocated to a specific core?
  • Can passing actors into a task function cause overhead?
  • How can I optimize it?

I saw that there is a project to decrease the per-actor overhead, which may be related to my case.

My studies:
I got a good overview of how the math of fractional GPU allocation works from my previous question. I went through "Running methods with actors is slower than running normal methods", the Ray Actors crash course, "Actors — Ray 2.49.1", and a related GitHub page, among others, but I could not find a use case similar to mine.

Hi @Guilherme_Parreira_d, could you describe your workflow a bit more? For example, where did the 0.006 come from for the number of CPUs/GPUs? Are you trying to run these actors on the GPU or the CPU?

I think this was previously mentioned, but there are issues with using fractional resources for GPUs: [Core] Ray fractional GPU unable to be scheduled · Issue #52133 · ray-project/ray · GitHub
Also, please take a look at AsyncIO / Concurrency for Actors — Ray 2.49.2 for concurrency with actors.
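For example, a minimal async-actor sketch in the spirit of that doc (the class and names are illustrative):

import asyncio
import ray

@ray.remote
class AsyncPredictor:
    # Defining async methods makes this an async actor, so many calls
    # can be in flight inside the same actor at once.
    async def predict(self, x):
        await asyncio.sleep(0.01)  # stand-in for awaitable work
        return x

actor = AsyncPredictor.options(max_concurrency=16).remote()
# Up to 16 predict calls run concurrently inside the single actor process.
results = ray.get([actor.predict.remote(i) for i in range(100)])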

  • The 0.006 came from trial and error, to be honest. Empirically it gave the best results for me.
  • I am running these actors on the GPU.

Since only these two models run on the GPU, I could have set num_gpus_actor = .45 (leaving .1 of the GPU free). I did this, but it did not speed up these models, nor did it increase GPU compute or memory usage.

The problem here is that I have short time series (100 rows each), but many of them: 2,700. Based on this, I realized that increasing the number of actors should help (because actor tasks are pulled from a queue). So I tried 18 actors for each model with num_gpus_actor = .023, but it did not make the process faster (GPU utilization stayed around 5%).

The math behind .023:
.023 of my 44 GB GPU is about 1 GB. So, with 36 actors I would still have about 8 GB of GPU memory free.
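To make the sizing concrete, this is roughly how I created the replicas (a sketch; N_ACTORS_PER_MODEL and the list comprehensions are illustrative, while the actor classes, paths, and num_cpus_actor are the ones shown above):

# Rough sizing: 0.023 * 44 GB ≈ 1 GB of GPU memory per actor,
# so 2 * 18 = 36 actors ≈ 36 GB, leaving ≈ 8 GB of the 44 GB free.
N_ACTORS_PER_MODEL = 18
num_gpus_actor = 0.023

chronos_actors = [
    ChronosModelActor.options(num_gpus=num_gpus_actor, num_cpus=num_cpus_actor).remote(
        path_chronos, device, torch.bfloat16
    )
    for _ in range(N_ACTORS_PER_MODEL)
]
timemoe_actors = [
    TimeMoEModelActor.options(num_gpus=num_gpus_actor, num_cpus=num_cpus_actor).remote(
        path_timemoe, device, torch.bfloat16
    )
    for _ in range(N_ACTORS_PER_MODEL)
]
# Target i is then routed to replica i % N_ACTORS_PER_MODEL (round robin).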

Thanks for replying!

Will each target be allocated to a specific core?

No. num_cpus is a logical scheduling resource; Ray doesn’t pin tasks to a specific physical core by default.
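For example, with ray.init(num_cpus=100) and @ray.remote(num_cpus=0.9) on process_series, Ray will admit up to floor(100 / 0.9) = 111 tasks at the same time; which physical core each worker process actually runs on is left to the OS scheduler.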

Can passing actors into a task function cause overhead?

Passing the actor handle is cheap. The real overhead is many small remote calls and ray.get sync points.
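Schematically, the difference looks like this (not your exact code; context_tensors stands in for however you iterate your targets):

# Anti-pattern: every caller blocks on a single small prediction, so the actor
# only ever sees one request per caller at a time.
mean = ray.get(model.predict.remote(context_tensor, forecast_horizon))

# Better: submit many requests first, then collect them, so the actor's queue
# stays full and the GPU work can be overlapped or batched.
refs = [model.predict.remote(ctx, forecast_horizon) for ctx in context_tensors]
means = ray.get(refs)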

How can I optimize it?

First, a few quick high-level recommendations:

  • Avoid fractional num_gpus for model actors; give each model actor a whole GPU (num_gpus=1) and push multiple requests to it.

  • Batch multiple targets per predict call so the GPU has enough work; for generation models, pad or pack sequences so they can be generated together.

  • Make actor methods async and enable concurrency (e.g., max_concurrency) so actors can overlap I/O/CPU prep with GPU work.

  • Submit method calls directly to the GPU actors from a lightweight coordinator; avoid heavy CPU tasks that immediately ray.get after calling the actor.

  • Use ray.put for shared tensors and pass object refs to avoid repeated serialization.

  • Pipeline with ray.wait to keep the GPU queue non-empty; don’t block on all results at once.

The last two points will require some refactoring of your code. I will share a sketch below to give an idea of what I mean (please replace the model imports and paths with yours). The key changes are:

  • Give each model actor a whole GPU (num_gpus=1) and allow high logical concurrency.

  • Add predict_batch to run many targets per forward pass (higher GPU utilization).

  • Use ray.put to share tensors by reference, and pipeline with ray.wait.

  • Call actors directly from the driver (no heavy per-target CPU tasks that immediately block on ray.get).

import os
from typing import List

import ray
import torch
from torch import Tensor
from chronos import BaseChronosPipeline  # adjust to your Chronos package/import path
from transformers import AutoModelForCausalLM

# ---- GPU actors (whole-GPU, batched inference) ----

@ray.remote(num_gpus=1, num_cpus=0.1, max_concurrency=64)
class ChronosModelActor:
    def __init__(self, model_path: str, torch_dtype: torch.dtype):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        torch.set_num_threads(1)  # keep CPU side overhead low
        self.device = device
        self.model = BaseChronosPipeline.from_pretrained(
            model_path, device_map=self.device, torch_dtype=torch_dtype
        )

    def predict_batch(self, context_refs: List[ray.ObjectRef], forecast_horizon: int) -> List[Tensor]:
        # Fetch batch from object store once
        contexts: List[Tensor] = ray.get(context_refs)
        if not contexts:
            return []

        # Assumes all contexts have the same shape. If not, pad/pack as needed.
        batch_ctx = torch.stack([c.to(self.device, non_blocking=True) for c in contexts], dim=0)

        with torch.inference_mode():
            _, mean = self.model.predict_quantiles(
                context=batch_ctx, prediction_length=forecast_horizon
            )
            # mean shape: [B, horizon]; split to per-item tensors on CPU
            out = mean.detach().to("cpu")
            return [out[i] for i in range(out.shape[0])]

    # Optional single-item API for compatibility
    def predict(self, context_ref: ray.ObjectRef, forecast_horizon: int) -> Tensor:
        return self.predict_batch([context_ref], forecast_horizon)[0]


@ray.remote(num_gpus=1, num_cpus=0.1, max_concurrency=64)
class TimeMoEModelActor:
    def __init__(self, model_path: str, torch_dtype: torch.dtype):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        torch.set_num_threads(1)
        self.device = device
        self.model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path=model_path,
            device_map=self.device,
            trust_remote_code=True,
            torch_dtype=torch_dtype,
        ).to(self.device)

    def predict_batch(self, context_refs: List[ray.ObjectRef], forecast_horizon: int) -> List[Tensor]:
        contexts: List[Tensor] = ray.get(context_refs)
        if not contexts:
            return []

        # Assumes equal seq lengths; if not, pad to the max length in batch.
        batch_ctx = torch.stack([c.to(self.device, non_blocking=True) for c in contexts], dim=0)

        with torch.inference_mode():
            output = self.model.generate(batch_ctx, max_new_tokens=forecast_horizon)
            # Take last forecast_horizon tokens
            forecast = output[:, -forecast_horizon:].to("cpu")
            return [forecast[i] for i in range(forecast.shape[0])]

    def predict(self, context_ref: ray.ObjectRef, forecast_horizon: int) -> Tensor:
        return self.predict_batch([context_ref], forecast_horizon)[0]


# ---- Driver orchestration: put once, batch, and pipeline ----

def run_inference(
    chronos_model_path: str,
    timemoe_model_path: str,
    forecast_horizon: int,
    contexts: List[Tensor],
    batch_size: int = 64,
):
    # Initialize Ray; let Ray detect CPUs/GPUs. Adjust if you need to cap.
    ray.init(include_dashboard=True)

    chronos = ChronosModelActor.remote(chronos_model_path, torch.bfloat16)
    timemoe = TimeMoEModelActor.remote(timemoe_model_path, torch.bfloat16)

    # Put all contexts once to avoid repeated serialization
    context_refs = [ray.put(ctx) for ctx in contexts]

    # Create batches
    def batches(lst, n):
        for i in range(0, len(lst), n):
            yield lst[i : i + n]

    # Submit batched inference to each actor up front; keep both queues non-empty.
    # Track the futures of each actor separately so results can be attributed correctly.
    chronos_futures = []
    timemoe_futures = []
    for batch_refs in batches(context_refs, batch_size):
        chronos_futures.append(chronos.predict_batch.remote(batch_refs, forecast_horizon))
        timemoe_futures.append(timemoe.predict_batch.remote(batch_refs, forecast_horizon))

    # Pipeline completion with ray.wait: collect batches as they finish.
    # Results arrive in completion order; if you need to map outputs back to
    # specific targets, track the batch indices when submitting.
    chronos_results: List[Tensor] = []
    timemoe_results: List[Tensor] = []

    while chronos_futures:
        done, chronos_futures = ray.wait(chronos_futures, num_returns=1)
        chronos_results.extend(ray.get(done[0]))

    while timemoe_futures:
        done, timemoe_futures = ray.wait(timemoe_futures, num_returns=1)
        timemoe_results.extend(ray.get(done[0]))

    return chronos_results, timemoe_results


if __name__ == "__main__":
    # Example usage: replace with your real data
    N = 2700
    CONTEXT_LEN = 100
    forecast_horizon = 24

    # Dummy CPU tensors; in practice, prepare your real tensors here.
    contexts = [torch.randint(0, 100, (CONTEXT_LEN,), dtype=torch.long) for _ in range(N)]

    base = os.environ.get("BASE_PATH_TRAINED_MODELS", "/models")
    path_chronos = os.path.join(base, "chronos_model_dir")
    path_timemoe = os.path.join(base, "timemoe_model_dir")

    chronos_out, timemoe_out = run_inference(
        path_chronos, path_timemoe, forecast_horizon, contexts, batch_size=64
    )
    print(len(chronos_out), len(timemoe_out))

Thank you so much for your valuable insights. They are definitely going to help me.

From your answer, I have the following questions:

  • I have only 1 GPU available. If I set num_gpus=1 for Chronos and 1 for TimeMoE, wouldn't that require 2 GPUs in total? And wouldn't my code then crash?
  • Or would the max_concurrency parameter allow this scenario? (I could not find the description of this argument for ray.remote, but it does accept it. Could you share any reference?)

Thank you

Meanwhile, I had to use a workaround, and it worked. But it is not going to be a definitive solution, because it is not as powerful as Ray.