[Data] How to limit the number of retries from system failures for dataset.map?

How severely does this issue affect your experience of using Ray?

  • Medium: It causes significant difficulty in completing my task, but I can work around it.

I tried running `dataset.map` with `sys.exit` inside the UDF, and the tasks appear to be retried forever. Is there a way to limit these retries? This is not a true system failure, so it would also work if such failures were treated as application failures that count toward the maximum number of retries.

Minimal example:

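```python
import sys
import ray

# Ray Data is lazy, so the pipeline only runs once it is consumed (take_all).
ray.data.from_items(range(10)).map(lambda x: sys.exit(0)).take_all()
```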

This never ends.

The example is artificial, but some libraries may have a bug in native (non-Python) code that crashes the process without raising a Python exception. Ray shouldn't keep retrying forever in that case.
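`sys.exit` does not crash the interpreter directly: it raises `SystemExit`, which derives from `BaseException` rather than `Exception`, so ordinary `except Exception` handlers let it through and the worker process exits. For the `sys.exit` case only (a native crash kills the process before any Python handler runs), a minimal sketch is a hypothetical `exit_as_error` decorator, not part of Ray, that re-raises `SystemExit` as a `RuntimeError`. The assumption is that Ray would then treat the failure as an application error instead of a worker crash.

```python
import functools
import sys


def exit_as_error(fn):
    """Hypothetical helper: re-raise SystemExit from `fn` as a RuntimeError."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except SystemExit as exc:
            # SystemExit derives from BaseException, so `except Exception`
            # would not catch it; convert it into an ordinary exception.
            raise RuntimeError(f"UDF called sys.exit({exc.code!r})") from exc
    return wrapper


def crashing_udf(row):
    # Hypothetical stand-in for a UDF that calls sys.exit, as in the example.
    sys.exit(0)


safe_udf = exit_as_error(crashing_udf)
```

Under that assumption it would be applied as `ray.data.from_items(range(10)).map(exit_as_error(lambda x: sys.exit(0))).take_all()`, which should then fail with the `RuntimeError` rather than being retried as a crashed worker.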

```python
import sys
import ray


@ray.remote
def fail():
    sys.exit(0)


ray.get(fail.remote())
```

This raises `ray.exceptions.WorkerCrashedError: The worker died unexpectedly while executing this task. Check python-core-worker-*.log files for more information.` after a few retries (by default, Ray retries a plain task up to 3 times after a system failure). I would expect similar behavior from a Ray Dataset.
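The bounded behavior of plain tasks can be illustrated without Ray. The sketch below is a standard-library analogue, not Ray's implementation: it runs code in a fresh Python process, counts an attempt as a crash if the child exits non-zero or exits without printing a result (as `sys.exit(0)` does), and retries at most `max_retries` times. The names `run_with_retries` and `WorkerCrashedError` here are hypothetical.

```python
import subprocess
import sys


class WorkerCrashedError(RuntimeError):
    """Hypothetical stand-in for ray.exceptions.WorkerCrashedError."""


def run_with_retries(code: str, max_retries: int = 3) -> tuple[str, int]:
    """Run `code` in a child Python process, retrying crashes.

    Returns (stdout of the successful attempt, number of attempts used).
    Raises WorkerCrashedError after the initial attempt plus `max_retries`
    retries have all crashed.
    """
    attempts = 0
    while True:
        attempts += 1
        proc = subprocess.run(
            [sys.executable, "-c", code], capture_output=True, text=True
        )
        # Success means a clean exit that actually produced a result.
        if proc.returncode == 0 and proc.stdout.strip():
            return proc.stdout.strip(), attempts
        if attempts > max_retries:
            raise WorkerCrashedError(
                f"worker died unexpectedly after {attempts} attempts "
                f"(last exit code {proc.returncode})"
            )
```

For example, `run_with_retries("import sys; sys.exit(0)", max_retries=2)` runs the child 3 times and then raises, the same arithmetic as Ray's `max_retries` (the initial attempt plus the retries).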

In case anyone stumbles upon this: I found a workaround that wraps the function in a callable class, so that Ray Data runs it in an actor, and then checks the actor's `num_restarts` in the actor state. The following stops once the actor has been restarted more than 5 times. It was inspired by the implementation of `was_current_actor_reconstructed`.

```python
import sys

import ray
import ray.util.state


class Actor:
    def __init__(self):
        # Look up this actor's restart count through the Ray state API.
        actor_id = ray.get_runtime_context().get_actor_id()
        actors = ray.util.state.list_actors(
            detail=True, filters=[("actor_id", "=", actor_id)]
        )
        # Fail construction once the actor has been restarted too often.
        if actors and actors[0]["num_restarts"] > 5:
            raise ValueError("Too many restarts")

    def __call__(self, x):
        # Simulates the worker process exiting in the middle of a call.
        sys.exit(1)


def main():
    ray.init()
    # Passing a class to map() runs the UDF in an actor (one actor here).
    print(ray.data.from_items(range(10)).map(Actor, concurrency=1).take_all())


if __name__ == "__main__":
    main()
```
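The control flow of the actor workaround above can be simulated in plain Python to see exactly when it stops. In this sketch (all names hypothetical, not Ray APIs), a supervisor restarts a crashed worker forever, like an actor with unlimited restarts, and passes in the restart count that the real workaround reads from `list_actors(detail=True)`. Only the guard in `__init__` ends the loop.

```python
class TooManyRestarts(RuntimeError):
    pass


class GuardedWorker:
    """Mimics the Actor above: refuse to start once restarts exceed a limit."""

    MAX_RESTARTS = 5

    def __init__(self, num_restarts: int):
        # In the Ray workaround this value comes from the state API.
        if num_restarts > self.MAX_RESTARTS:
            raise TooManyRestarts(f"Too many restarts ({num_restarts})")

    def __call__(self, x):
        raise SystemExit(1)  # same effect as sys.exit(1)


def supervise(worker_cls, item):
    """Hypothetical supervisor: restart a crashed worker with no limit."""
    num_restarts = 0
    while True:
        worker = worker_cls(num_restarts)  # the guard may raise here
        try:
            return worker(item)
        except SystemExit:
            num_restarts += 1  # the "crash" triggers another restart
```

With a threshold of `> 5`, the worker crashes on its initial start and on 5 restarts, and the 6th restart fails in the constructor with `Too many restarts (6)`.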

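Setting `max_retries` explicitly also avoids the infinite retries, since `map()` passes extra keyword arguments through to the underlying Ray tasks as remote args:

```python
import sys
import ray

# Should give up after the initial attempt plus at most 3 retries.
ray.data.from_items(range(10)).map(lambda x: sys.exit(1), max_retries=3).take_all()
```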