Can I catch the original error in code outside train_func?

Hi, I want to catch the original error when trainer.fit() reports an error. Is it possible?

from typing import Any, Dict
from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig
from ray.air import RunConfig, FailureConfig
from ray.train.torch.config import TorchConfig
from ray.train.base_trainer import TrainingFailedError
import transformers

def train_func(config: Dict[str, Any]):
    model = transformers.AutoModelForCausalLM.from_pretrained("model_not_existed")

failure_config = FailureConfig(max_failures=0)

trainer = TorchTrainer(
    train_func,
    train_loop_config={},
    scaling_config=ScalingConfig(num_workers=1),
    torch_config=TorchConfig(),
    run_config=RunConfig(failure_config=failure_config)
)
try:
    results = trainer.fit()
    print(results.error)  # Never reached: fit() raises before returning a result
except EnvironmentError as e:
    print("I want to do something specific for this type of error")
except TrainingFailedError as e:
    print("but it looks like I can only catch this exception")
    print("#####", e)
    # Output: ##### The Ray Train run failed. Please inspect the previous error messages for a cause. After fixing the issue (assuming that the error is not caused by your own application logic, but rather an error such as OOM), you can restart the run from scratch or continue this run.
    # I can't do anything with this output

@KepingYan Do you mean you want to catch the first exception raised in training before it propagates up the stack?

cc: @justinvyu @matthewdeng Is it possible to catch the original cause? It seems there could be any number of exceptions up the training stack.

@KepingYan You can get this error via:

try:
    result = trainer.fit()
except TrainingFailedError as e:
    original_error = e.__cause__

It is unintuitive that result.error cannot really be used with trainer.fit(), since fit() raises without giving access to the result object. Would it make more sense if this “raise on error” behavior were configurable? For example, a “pass on error” option could let users handle result.error themselves.
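
Until such an option exists, a similar effect can be approximated in user code. The following is a minimal sketch, assuming only that the failure exception chains the original error through __cause__ as in the snippet above. The names fit_or_error, FakeTrainingFailedError, and fake_fit are hypothetical and exist only for illustration; they are not Ray APIs. In real code you would pass trainer.fit and TrainingFailedError instead.

```python
from typing import Any, Callable, Optional, Tuple, Type


def fit_or_error(
    fit: Callable[[], Any],
    failure_type: Type[BaseException] = Exception,
) -> Tuple[Optional[Any], Optional[BaseException]]:
    """Call fit() and return (result, error) instead of raising."""
    try:
        return fit(), None
    except failure_type as exc:
        # Prefer the chained cause; fall back to the wrapper itself.
        cause = exc.__cause__
        return None, cause if cause is not None else exc


class FakeTrainingFailedError(RuntimeError):
    """Hypothetical stand-in for TrainingFailedError."""


def fake_fit() -> None:
    """Hypothetical stand-in for trainer.fit() that always fails."""
    try:
        raise OSError("model_not_existed is not a valid model identifier")
    except OSError as original:
        # 'raise ... from ...' sets __cause__ on the new exception.
        raise FakeTrainingFailedError("The run failed.") from original


result, error = fit_or_error(fake_fit, FakeTrainingFailedError)
if isinstance(error, OSError):
    print("handling a missing-model error:", error)
```

The caller then branches on the type of the returned error rather than on the exception raised by fit().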

Thank you for your reply. In this example e.__cause__ is <class 'ray.exceptions.RayTaskError(OSError)'>, not EnvironmentError (the original error) as I expected. If the original error cannot be caught, your suggestion is a good choice. Thanks~

If you create the result with Result.from_path("/path/to/trial/dir"), then access the error from there, do you get the environment error you expect?

What is the OSError that you’re seeing? Where does that come up in the stack trace?

Yes, I can see the EnvironmentError in the stack trace, but the type of the error is RayTaskError(OSError).

<class 'types.RayTaskError(OSError)'>

ray::_Inner.train() (pid=97348, ip=10.1.0.133, actor_id=4dd08a0a937f0aa69f17695204000000, repr=TorchTrainer)
  File "/home/ykp/miniconda3/envs/LLM_release_2/lib/python3.8/site-packages/ray/tune/trainable/trainable.py", line 342, in train
    raise skipped from exception_cause(skipped)
  File "/home/ykp/miniconda3/envs/LLM_release_2/lib/python3.8/site-packages/ray/train/_internal/utils.py", line 43, in check_for_failure
    ray.get(object_ref)
ray.exceptions.RayTaskError(OSError): ray::_RayTrainWorker__execute.get_next() (pid=97482, ip=10.1.0.133, actor_id=1a24c4e0db1a2be5e9e3e82904000000, repr=<ray.train._internal.worker_group.RayTrainWorker object at 0x7fdcac29a520>)
  File "/home/ykp/miniconda3/envs/LLM_release_2/lib/python3.8/site-packages/ray/train/_internal/worker_group.py", line 33, in __execute
    raise skipped from exception_cause(skipped)
  File "/home/ykp/miniconda3/envs/LLM_release_2/lib/python3.8/site-packages/ray/train/_internal/utils.py", line 118, in discard_return_wrapper
    train_func(*args, **kwargs)
  File "inference/test_error.py", line 10, in train_func
    model = transformers.AutoModelForCausalLM.from_pretrained("model_not_existed")
  File "/home/ykp/miniconda3/envs/LLM_release_2/lib/python3.8/site-packages/transformers/models/auto/auto_factory.py", line 461, in from_pretrained
    config, kwargs = AutoConfig.from_pretrained(
  File "/home/ykp/miniconda3/envs/LLM_release_2/lib/python3.8/site-packages/transformers/models/auto/configuration_auto.py", line 983, in from_pretrained
    config_dict, unused_kwargs = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
  File "/home/ykp/miniconda3/envs/LLM_release_2/lib/python3.8/site-packages/transformers/configuration_utils.py", line 617, in get_config_dict
    config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
  File "/home/ykp/miniconda3/envs/LLM_release_2/lib/python3.8/site-packages/transformers/configuration_utils.py", line 672, in _get_config_dict
    resolved_config_file = cached_file(
  File "/home/ykp/miniconda3/envs/LLM_release_2/lib/python3.8/site-packages/transformers/utils/hub.py", line 433, in cached_file
    raise EnvironmentError(
OSError: model_not_existed is not a local folder and is not a valid model identifier listed on 'https://huggingface.co/models'
If this is a private repository, make sure to pass a token having permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass `use_auth_token=True`.

Creating the result from the path gives the same type.

>>> from ray.train import Result
>>> result_path = "/home/ykp/ray_results/TorchTrainer_2023-11-30_19-10-07/TorchTrainer_05017_00000_0_2023-11-30_19-10-07"
>>> result = Result.from_path(result_path)
>>> result
Result(
  error='RayTaskError(OSError)',
  metrics={},
  path='/home/ykp/ray_results/TorchTrainer_2023-11-30_19-10-07/TorchTrainer_05017_00000_0_2023-11-30_19-10-07',
  filesystem='local',
  checkpoint=None
)
>>> result.error
RayTaskError(OSError)(OSError("model_not_existed is not a local folder and is not a valid model identifier listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass `use_auth_token=True`."))
>>> type(result.error)
<class 'types.RayTaskError(OSError)'>
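
Two points may help interpret the output above. First, in Python 3 EnvironmentError is an alias of OSError, which is why the traceback ends in OSError even though the code raises EnvironmentError. Second, the type name RayTaskError(OSError) suggests a class built to derive from both the Ray wrapper and the original exception type. If that is the case (an assumption here, not checked against the Ray source), an isinstance check or an except OSError clause applied to e.__cause__ or result.error would still match. The sketch below reproduces that pattern in plain Python; WrapperError and make_dual_error are hypothetical stand-ins, not Ray APIs.

```python
class WrapperError(Exception):
    """Hypothetical stand-in for a framework wrapper such as RayTaskError."""


def make_dual_error(cause: BaseException) -> BaseException:
    """Wrap `cause` in a class that subclasses WrapperError and type(cause)."""
    cause_cls = type(cause)
    name = f"WrapperError({cause_cls.__name__})"
    # Build the dual class dynamically so it inherits from both parents.
    dual_cls = type(name, (WrapperError, cause_cls), {})
    wrapped = dual_cls(*cause.args)
    wrapped.cause = cause  # keep the original instance reachable
    return wrapped


# EnvironmentError has been an alias of OSError since Python 3.3.
assert EnvironmentError is OSError

original = OSError("model_not_existed is not a valid model identifier")
wrapped = make_dual_error(original)

try:
    raise wrapped
except OSError as err:
    # Matches even though type(err) is the dynamically built dual class.
    print(type(err).__name__, "caught as OSError; original:", repr(err.cause))
```

If the same holds for the real error object, checking isinstance(e.__cause__, OSError) inside the except TrainingFailedError block is enough to branch on the original error type.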