Retry Task w/Different Resources

How severely does this issue affect your experience of using Ray?

  • Medium: It contributes to significant difficulty to complete my task, but I can work around it.


I’m currently trying to implement an embarrassingly parallel feature extraction job using Ray. I have a large dataset (~10M items) where each item is stored as a raw data file in S3. Feature extraction is performed by a simple Python function that takes a file as its input and outputs a Python object (in this case a list of dictionaries). I’ve implemented feature extraction on a single file as a Ray task and can naively run my job as follows:

import ray

@ray.remote
def feat_extraction(s3_key):
    # step 1: read file from S3
    # step 2: extract list of feature dictionaries
    # step 3: return list of feature dictionaries
    return features

output = ray.get([feat_extraction.remote(s3_key) for s3_key in s3_keys])

This works fine when testing on a subset of the files, but the problem is that our dataset has a long tail in terms of the amount of memory that an individual feature extraction can require. Over 99% of feature extraction jobs require < 1 GB of memory, however a small fraction can require as much as ~10 GB of memory.

In order to prevent my job from becoming bottlenecked by the network – and to be efficient with my resource usage – I wanted to implement a solution like the following:

@ray.remote
class Supervisor:
    def __init__(self, s3_keys):
        self.s3_keys = s3_keys

    def do_task(self, s3_key):
        try:
            return feat_extraction.options(resources={"small-node": 1}, max_retries=0).remote(s3_key)
        except (ray.exceptions.WorkerCrashedError, ray.exceptions.RayTaskError) as e:
            return feat_extraction.options(resources={"large-node": 1}).remote(s3_key)

    def work(self):
        return ray.get([self.do_task(s3_key) for s3_key in self.s3_keys])

sup = Supervisor.options(resources={"actor-node": 1}).remote(s3_keys)
out = ray.get(sup.work.remote())

Here I’m using the resources field to schedule the tasks / actor on specific nodegroups in my Kubernetes cluster. "small-node" corresponds to a nodegroup of machines with enough memory to process the 99% of files that take < 1GB of memory, and "large-node" corresponds to a nodegroup of machines with much more memory that can handle the corner case(s) where up to 10 GB of memory may be required.
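For anyone following along: custom resources like these are declared when each Ray node starts. On Kubernetes this usually lives in the nodegroup’s cluster config, but the equivalent CLI form (with `<head-ip>` as a placeholder) is something like:

```shell
# Machines in the small-memory nodegroup advertise the "small-node" resource:
ray start --address=<head-ip>:6379 --resources='{"small-node": 1}'

# Machines in the large-memory nodegroup advertise "large-node":
ray start --address=<head-ip>:6379 --resources='{"large-node": 1}'
```

Tasks then request these resources via `.options(resources={...})`, as in the snippet above.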

The problem is that the code I’ve implemented above does not do what I want. It seems that the RayTaskError just blows through my error-handling logic, and I’ve been unable to come up with a way to force my task(s) to be retried on the larger nodegroup when they fail. I don’t believe that my use-case is all that uncommon, so I’m wondering if there is some way that Ray can support the more fine-grained control of task retries that I’m seeking?

I’d be happy to provide more details if needed, and am hoping that there’s a solution for this as I’m really excited about switching my data science team’s workloads from PySpark to Ray!

Hi @mrusso, thank you for the great writeup! I think the issue you’re running into is that, if the feat_extraction task fails, the error is raised when you fetch the result (ray.get()), not when you submit the task (.remote()). .remote() returns a future immediately after the task has been submitted to the cluster, so the error will never be raised there!

If you change your retry logic to resubmit at the block-and-fetch point (ray.get()), this should work as intended.

As an aside, have you considered using Ray Datasets for this? This kind of large-scale feature extraction on data sitting in a data lake aligns pretty well with our positioning!

We’re still working on our support for transparently handling large files and large data skew, both of which appear to be factors in your use case, but that’s on our short-term roadmap.

@Clark_Zinzow Thanks so much for your response!

Your solution makes total sense, and my only remaining question is whether there’s a way to figure out which specific tasks failed when calling ray.get(), so I can retry only those. As a simple POC I tried running:

out = []
try:
    out = ray.get([f.remote(idx) for idx in range(5)])
except ray.exceptions.RayTaskError as e:
    print(e)

This was for a function f that I designed to cause an OOM error when idx == 2. Not too surprisingly, the output I got was:


This means I would need to retry the entire list of tasks instead of only the one for idx == 2.

Will I need to use the ray.wait() + while loop construction here (so that I might only need to retry a smaller subset of the tasks when ray.wait() hits an exception) or is there a better approach that’s commonly used?

Also, I am very interested in trying out Ray Datasets! I’m still very new to Ray and decided to stick with a more vanilla approach for now (i.e. just understanding / using tasks and actors), but I would definitely be interested in trying it out once I have something working in the short term!


The typical pattern that we see is using ray.wait() to wait for any task to complete/fail and then using ray.get() on the completed/failed task, resubmitting the task if it failed. Something like this should work:

import ray

future_to_key = {f.remote(key): key for key in keys}
results = {}
while future_to_key:
    done, _ = ray.wait(list(future_to_key.keys()), num_returns=1)
    done = done[0]
    done_key = future_to_key.pop(done)
    try:
        results[done_key] = ray.get(done)
    except (ray.exceptions.RayTaskError, ray.exceptions.WorkerCrashedError) as e:
        future_to_key[f.remote(done_key)] = done_key

This is great, thanks so much for your help!