Ray job with flash_attn uses triple the GPU memory compared to running directly

The Ray job has 8 workers, each with 1 GPU and 10 CPUs.
When I run PyTorch (transformers) with DeepSpeed and flash_attn under Ray, it hits a GPU out-of-memory error; the same training runs fine without Ray. Initializing flash_attn under Ray seems to cost roughly triple the GPU memory.
I don't know where the problem is.
Can you give me any advice? The trainer setup is sketched below, and my running log follows.
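Roughly how the trainer is launched (a simplified sketch, not the exact code from train_ray.py; `train_func` stands in for the real per-worker training function that sets up transformers + DeepSpeed + flash_attn):

```python
# Simplified sketch of the launch code (placeholder names, not the actual train_ray.py).
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer


def train_func(config):
    # Per-worker training loop: transformers Trainer + DeepSpeed + flash_attn.
    ...


trainer = TorchTrainer(
    train_loop_per_worker=train_func,
    train_loop_config={"argv": [...]},  # the argv shown (truncated) in the config table below
    scaling_config=ScalingConfig(
        num_workers=8,                               # 8 workers in total
        use_gpu=True,
        resources_per_worker={"GPU": 1, "CPU": 10},  # 1 GPU and 10 CPUs per worker
    ),
)
result = trainer.fit()
```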


Training started with configuration:
╭───────────────────────────────────────────────╮
│ Training config                               │
├───────────────────────────────────────────────┤
│ train_loop_config/argv   ...me', 'debug-ray'] │
╰───────────────────────────────────────────────╯
(TrainTrainable pid=299096, ip=10.244.5.110) Trainable.setup took 36.999 seconds. If your trainable is slow to initialize, consider setting reuse_actors=True to reduce actor creation overheads.
(RayTrainWorker pid=304730, ip=10.244.5.110) Setting up process group for: env:// [rank=0, world_size=8]
(TorchTrainer pid=299096, ip=10.244.5.110) Started distributed worker processes:
(TorchTrainer pid=299096, ip=10.244.5.110) - (node_id=ea8881cb3df66a49a531bf3db2356ab46ba0a283e72cce08ef831c51, ip=10.244.5.110, pid=304730) world_rank=0, local_rank=0, node_rank=0
(TorchTrainer pid=299096, ip=10.244.5.110) - (node_id=ea8881cb3df66a49a531bf3db2356ab46ba0a283e72cce08ef831c51, ip=10.244.5.110, pid=304732) world_rank=1, local_rank=1, node_rank=0
(TorchTrainer pid=299096, ip=10.244.5.110) - (node_id=ea8881cb3df66a49a531bf3db2356ab46ba0a283e72cce08ef831c51, ip=10.244.5.110, pid=304733) world_rank=2, local_rank=2, node_rank=0
(TorchTrainer pid=299096, ip=10.244.5.110) - (node_id=ea8881cb3df66a49a531bf3db2356ab46ba0a283e72cce08ef831c51, ip=10.244.5.110, pid=304734) world_rank=3, local_rank=3, node_rank=0
(TorchTrainer pid=299096, ip=10.244.5.110) - (node_id=ea8881cb3df66a49a531bf3db2356ab46ba0a283e72cce08ef831c51, ip=10.244.5.110, pid=304735) world_rank=4, local_rank=4, node_rank=0
(TorchTrainer pid=299096, ip=10.244.5.110) - (node_id=ea8881cb3df66a49a531bf3db2356ab46ba0a283e72cce08ef831c51, ip=10.244.5.110, pid=304736) world_rank=5, local_rank=5, node_rank=0
(TorchTrainer pid=299096, ip=10.244.5.110) - (node_id=ea8881cb3df66a49a531bf3db2356ab46ba0a283e72cce08ef831c51, ip=10.244.5.110, pid=304737) world_rank=6, local_rank=6, node_rank=0
(TorchTrainer pid=299096, ip=10.244.5.110) - (node_id=ea8881cb3df66a49a531bf3db2356ab46ba0a283e72cce08ef831c51, ip=10.244.5.110, pid=304738) world_rank=7, local_rank=7, node_rank=0
(RayTrainWorker pid=304737, ip=10.244.5.110) [2024-10-23 19:40:44,340] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
(RayTrainWorker pid=304737, ip=10.244.5.110) [WARNING] Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
(RayTrainWorker pid=304737, ip=10.244.5.110) [WARNING] sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.1
(RayTrainWorker pid=304737, ip=10.244.5.110) [WARNING] using untested triton version (2.1.0), only 1.0.0 is known to be compatible
(RayTrainWorker pid=304734, ip=10.244.5.110) [2024-10-23 19:40:44,341] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect) [repeated 7x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)
(RayTrainWorker pid=304737, ip=10.244.5.110) [2024-10-23 19:41:14,023] [INFO] [comm.py:637:init_distributed] cdb=None
(RayTrainWorker pid=304737, ip=10.244.5.110) Current working directory: /
(RayTrainWorker pid=304737, ip=10.244.5.110) cpath /mnt/home/stang/workspace/ShareGPT4V/share4v/train/train_ray.py
(RayTrainWorker pid=304737, ip=10.244.5.110) npath /mnt/home/stang/workspace/ShareGPT4V
(RayTrainWorker pid=304734, ip=10.244.5.110) [WARNING] Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH [repeated 7x across cluster]
(RayTrainWorker pid=304734, ip=10.244.5.110) [WARNING] sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.1 [repeated 7x across cluster]
(RayTrainWorker pid=304734, ip=10.244.5.110) [WARNING] using untested triton version (2.1.0), only 1.0.0 is known to be compatible [repeated 7x across cluster]
(RayTrainWorker pid=304737, ip=10.244.5.110) ModelArguments(model_name_or_path='/mnt/netdata/Team/AI/weights/huggingface/lmsys/vicuna-7b-v1.5', version='plain', freeze_backbone=False, tune_mm_mlp_adapter=False, vision_tower='/mnt/netdata/Team/AI/weights/huggingface/openai/clip-vit-large-patch14-336', tune_vision_tower=False, tune_vit_from_layer=12, tune_entire_model=True, mm_vision_select_layer=-2, pretrain_mm_mlp_adapter='/mnt/netdata/Team/AI/weights/huggingface/liuhaotian/llava-v1.5-mlp2x-336px-pretrain-vicuna-7b-v1.5/mm_projector.bin', mm_projector_type='mlp2x_gelu', mm_use_im_start_end=False, mm_use_im_patch_token=False, mm_vision_select_feature='patch')
(RayTrainWorker pid=304737, ip=10.244.5.110) DataArguments(data_path='data/sharegpt4v/share-captioner_coco_lcs_sam_1246k_1107_fix_sam.json', lazy_preprocess=True, is_multimodal=False, image_folder='data', image_aspect_ratio='square', image_grid_pinpoints=None)
(RayTrainWorker pid=304737, ip=10.244.5.110) train_batch = 128, micro_batch=16,grad_acc=1,world_size=8
(RayTrainWorker pid=304738, ip=10.244.5.110) You are using a model of type llama to instantiate a model of type share4v. This is not supported for all configurations of models and can yield errors.
(RayTrainWorker pid=304730, ip=10.244.5.110) [2024-10-23 19:41:26,896] [INFO] [partition_parameters.py:345:__exit__] finished initializing model - num_params = 291, num_elems = 6.74B
(RayTrainWorker pid=304734, ip=10.244.5.110) [2024-10-23 19:41:14,021] [INFO] [comm.py:637:init_distributed] cdb=None [repeated 7x across cluster]
(RayTrainWorker pid=304734, ip=10.244.5.110) Current working directory: / [repeated 7x across cluster]
(RayTrainWorker pid=304734, ip=10.244.5.110) cpath /mnt/home/stang/workspace/ShareGPT4V/share4v/train/train_ray.py [repeated 7x across cluster]
(RayTrainWorker pid=304734, ip=10.244.5.110) npath /mnt/home/stang/workspace/ShareGPT4V [repeated 7x across cluster]
(RayTrainWorker pid=304734, ip=10.244.5.110) ModelArguments(model_name_or_path='/mnt/netdata/Team/AI/weights/huggingface/lmsys/vicuna-7b-v1.5', version='plain', freeze_backbone=False, tune_mm_mlp_adapter=False, vision_tower='/mnt/netdata/Team/AI/weights/huggingface/openai/clip-vit-large-patch14-336', tune_vision_tower=False, tune_vit_from_layer=12, tune_entire_model=True, mm_vision_select_layer=-2, pretrain_mm_mlp_adapter='/mnt/netdata/Team/AI/weights/huggingface/liuhaotian/llava-v1.5-mlp2x-336px-pretrain-vicuna-7b-v1.5/mm_projector.bin', mm_projector_type='mlp2x_gelu', mm_use_im_start_end=False, mm_use_im_patch_token=False, mm_vision_select_feature='patch') [repeated 7x across cluster]
(RayTrainWorker pid=304734, ip=10.244.5.110) DataArguments(data_path='data/sharegpt4v/share-captioner_coco_lcs_sam_1246k_1107_fix_sam.json', lazy_preprocess=True, is_multimodal=False, image_folder='data', image_aspect_ratio='square', image_grid_pinpoints=None) [repeated 7x across cluster]
(RayTrainWorker pid=304734, ip=10.244.5.110) train_batch = 128, micro_batch=16,grad_acc=1,world_size=8 [repeated 7x across cluster]
(RayTrainWorker pid=304734, ip=10.244.5.110)
Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]
(RayTrainWorker pid=304734, ip=10.244.5.110) You are using a model of type llama to instantiate a model of type share4v. This is not supported for all configurations of models and can yield errors. [repeated 7x across cluster]
(RayTrainWorker pid=304733, ip=10.244.5.110)
Loading checkpoint shards:  50%|█████     | 1/2 [00:57<00:57, 57.08s/it]
(RayTrainWorker pid=304730, ip=10.244.5.110)
Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s] [repeated 7x across cluster]
(RayTrainWorker pid=304738, ip=10.244.5.110)
Loading checkpoint shards: 100%|██████████| 2/2 [01:14<00:00, 33.52s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [01:14<00:00, 37.06s/it]
(RayTrainWorker pid=304730, ip=10.244.5.110)
Loading checkpoint shards:  50%|█████     | 1/2 [00:58<00:58, 58.19s/it] [repeated 7x across cluster]
(RayTrainWorker pid=304738, ip=10.244.5.110) Load vision tower from /mnt/netdata/Team/AI/weights/huggingface/openai/clip-vit-large-patch14-336
(RayTrainWorker pid=304737, ip=10.244.5.110) train_batch = 128, micro_batch=16,grad_acc=1,world_size=8
(RayTrainWorker pid=304730, ip=10.244.5.110) Load vision tower from /mnt/netdata/Team/AI/weights/huggingface/openai/clip-vit-large-patch14-336 [repeated 7x across cluster]
(RayTrainWorker pid=304730, ip=10.244.5.110) [2024-10-23 19:42:51,214] [INFO] [partition_parameters.py:345:__exit__] finished initializing model - num_params = 682, num_elems = 7.04B
(RayTrainWorker pid=304730, ip=10.244.5.110) Load mm_mlp_adapter from /mnt/netdata/Team/AI/weights/huggingface/liuhaotian/llava-v1.5-mlp2x-336px-pretrain-vicuna-7b-v1.5/mm_projector.bin
(RayTrainWorker pid=304730, ip=10.244.5.110) Tune entire model!
(RayTrainWorker pid=304730, ip=10.244.5.110) Tune the MLP! The LR of MLP is 2e-05
(RayTrainWorker pid=304730, ip=10.244.5.110) Tune the vision tower! LR for ViT is 2e-05.
(RayTrainWorker pid=304730, ip=10.244.5.110) Tune the vision tower from layer 12!
(RayTrainWorker pid=304730, ip=10.244.5.110) Formatting inputs...Skip in lazy mode
(RayTrainWorker pid=304730, ip=10.244.5.110) train_batch = 128, micro_batch=16,grad_acc=1,world_size=8 [repeated 7x across cluster]
(RayTrainWorker pid=304737, ip=10.244.5.110) Load mm_mlp_adapter from /mnt/netdata/Team/AI/weights/huggingface/liuhaotian/llava-v1.5-mlp2x-336px-pretrain-vicuna-7b-v1.5/mm_projector.bin [repeated 7x across cluster]
(RayTrainWorker pid=304730, ip=10.244.5.110) Parameter Offload: Total persistent parameters: 599040 in 312 params
(RayTrainWorker pid=304730, ip=10.244.5.110)
  0%|          | 0/9742 [00:00<?, ?it/s]
(RayTrainWorker pid=304730, ip=10.244.5.110)
Loading checkpoint shards: 100%|██████████| 2/2 [01:15<00:00, 34.22s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [01:15<00:00, 37.81s/it] [repeated 7x across cluster]
(RayTrainWorker pid=304737, ip=10.244.5.110) ####################################### torch.Size([16, 853, 4096])
(RayTrainWorker pid=304737, ip=10.244.5.110) ####################################### 16
(RayTrainWorker pid=304737, ip=10.244.5.110) train_batch = 128, micro_batch=16,grad_acc=1,world_size=8 [repeated 8x across cluster]
(RayTrainWorker pid=304738, ip=10.244.5.110) /mnt/net-cloud4/Team/AI/personal/ts/miniconda3/envs/share4v/lib/python3.10/site-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.
2024-10-23 19:43:33,629	ERROR tune_controller.py:1331 -- Trial task failed for trial TorchTrainer_0fc45_00000
Traceback (most recent call last):
  File "/mnt/net-cloud4/Team/AI/personal/ts/miniconda3/envs/share4v/lib/python3.10/site-packages/ray/air/execution/_internal/event_manager.py", line 110, in resolve_future
    result = ray.get(future)
  File "/mnt/net-cloud4/Team/AI/personal/ts/miniconda3/envs/share4v/lib/python3.10/site-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/mnt/net-cloud4/Team/AI/personal/ts/miniconda3/envs/share4v/lib/python3.10/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/mnt/net-cloud4/Team/AI/personal/ts/miniconda3/envs/share4v/lib/python3.10/site-packages/ray/_private/worker.py", line 2691, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
  File "/mnt/net-cloud4/Team/AI/personal/ts/miniconda3/envs/share4v/lib/python3.10/site-packages/ray/_private/worker.py", line 871, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(OutOfMemoryError): ray::_Inner.train() (pid=299096, ip=10.244.5.110, actor_id=8ebbb3b7e9f854e75049462345000000, repr=TorchTrainer)
  File "/mnt/net-cloud4/Team/AI/personal/ts/miniconda3/envs/share4v/lib/python3.10/site-packages/ray/tune/trainable/trainable.py", line 331, in train
    raise skipped from exception_cause(skipped)
  File "/mnt/net-cloud4/Team/AI/personal/ts/miniconda3/envs/share4v/lib/python3.10/site-packages/ray/train/_internal/utils.py", line 57, in check_for_failure
    ray.get(object_ref)
ray.exceptions.RayTaskError(OutOfMemoryError): ray::_RayTrainWorker__execute.get_next() (pid=304732, ip=10.244.5.110, actor_id=9a685290650a898de26c1a0b45000000, repr=<ray.train._internal.worker_group.RayTrainWorker object at 0x7ef4d0529030>)
  File "/mnt/net-cloud4/Team/AI/personal/ts/miniconda3/envs/share4v/lib/python3.10/site-packages/ray/train/_internal/worker_group.py", line 33, in __execute
    raise skipped from exception_cause(skipped)
  File "/mnt/net-cloud4/Team/AI/personal/ts/miniconda3/envs/share4v/lib/python3.10/site-packages/ray/train/_internal/utils.py", line 176, in discard_return_wrapper
    train_func(*args, **kwargs)
  File "/mnt/home/stang/workspace/ShareGPT4V/share4v/train/train_ray.py", line 1077, in train
    trainer.train()
  File "/mnt/net-cloud4/Team/AI/personal/ts/miniconda3/envs/share4v/lib/python3.10/site-packages/transformers/trainer.py", line 1539, in train
    return inner_training_loop(
  File "/mnt/net-cloud4/Team/AI/personal/ts/miniconda3/envs/share4v/lib/python3.10/site-packages/transformers/trainer.py", line 1809, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/mnt/net-cloud4/Team/AI/personal/ts/miniconda3/envs/share4v/lib/python3.10/site-packages/transformers/trainer.py", line 2654, in training_step
    loss = self.compute_loss(model, inputs)
  File "/mnt/net-cloud4/Team/AI/personal/ts/miniconda3/envs/share4v/lib/python3.10/site-packages/transformers/trainer.py", line 2679, in compute_loss
    outputs = model(**inputs)
  File "/mnt/net-cloud4/Team/AI/personal/ts/miniconda3/envs/share4v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/net-cloud4/Team/AI/personal/ts/miniconda3/envs/share4v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/net-cloud4/Team/AI/personal/ts/miniconda3/envs/share4v/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/mnt/net-cloud4/Team/AI/personal/ts/miniconda3/envs/share4v/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1846, in forward
    loss = self.module(*inputs, **kwargs)
  File "/mnt/net-cloud4/Team/AI/personal/ts/miniconda3/envs/share4v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/net-cloud4/Team/AI/personal/ts/miniconda3/envs/share4v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/mnt/home/stang/workspace/ShareGPT4V/share4v/model/language_model/share4v_llama.py", line 219, in forward
    outputs = self.model(
  File "/mnt/net-cloud4/Team/AI/personal/ts/miniconda3/envs/share4v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/net-cloud4/Team/AI/personal/ts/miniconda3/envs/share4v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/mnt/home/stang/workspace/ShareGPT4V/share4v/model/language_model/share4v_llama.py", line 138, in forward
    layer_outputs = torch.utils.checkpoint.checkpoint(
  File "/mnt/net-cloud4/Team/AI/personal/ts/miniconda3/envs/share4v/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
  File "/mnt/net-cloud4/Team/AI/personal/ts/miniconda3/envs/share4v/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "/mnt/net-cloud4/Team/AI/personal/ts/miniconda3/envs/share4v/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/mnt/net-cloud4/Team/AI/personal/ts/miniconda3/envs/share4v/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 451, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "/mnt/net-cloud4/Team/AI/personal/ts/miniconda3/envs/share4v/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/mnt/net-cloud4/Team/AI/personal/ts/miniconda3/envs/share4v/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 230, in forward
    outputs = run_function(*args)
  File "/mnt/home/stang/workspace/ShareGPT4V/share4v/model/language_model/share4v_llama.py", line 134, in custom_forward
    return module(*inputs, output_attentions, None)
  File "/mnt/net-cloud4/Team/AI/personal/ts/miniconda3/envs/share4v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/net-cloud4/Team/AI/personal/ts/miniconda3/envs/share4v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/mnt/net-cloud4/Team/AI/personal/ts/miniconda3/envs/share4v/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 408, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/mnt/net-cloud4/Team/AI/personal/ts/miniconda3/envs/share4v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/net-cloud4/Team/AI/personal/ts/miniconda3/envs/share4v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/mnt/home/stang/workspace/ShareGPT4V/share4v/train/llama_flash_attn_monkey_patch.py", line 88, in forward
    qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
  File "/mnt/net-cloud4/Team/AI/personal/ts/miniconda3/envs/share4v/lib/python3.10/site-packages/flash_attn/bert_padding.py", line 119, in unpad_input
    index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
  File "/mnt/net-cloud4/Team/AI/personal/ts/miniconda3/envs/share4v/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/mnt/net-cloud4/Team/AI/personal/ts/miniconda3/envs/share4v/lib/python3.10/site-packages/flash_attn/bert_padding.py", line 17, in forward
    return torch.gather(
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 253.62 GiB. GPU 1 has a total capacty of 94.99 GiB of which 78.92 GiB is free. Including non-PyTorch memory, this process has 16.07 GiB memory in use. Of the allocated memory 13.63 GiB is allocated by PyTorch, and 603.52 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF



When I use 1 worker with 4 GPUs, I also get a GPU out-of-memory error. Whenever I train through Ray Train, the GPU memory cost roughly triples.
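To compare against the non-Ray run, I can add a check like the sketch below at the start of the per-worker train function (a rough diagnostic, not yet part of train_ray.py; `report_gpu_state` is just an illustrative helper name):

```python
# Rough diagnostic sketch: print which GPUs each Ray worker actually sees and
# how much CUDA memory PyTorch has allocated so far on that worker.
import os

import ray
import torch
from ray import train


def report_gpu_state(tag: str) -> None:
    rank = train.get_context().get_world_rank()         # Ray Train world rank
    gpu_ids = ray.get_gpu_ids()                          # GPUs Ray assigned to this worker
    visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    allocated = torch.cuda.memory_allocated() / 2**30   # GiB allocated by PyTorch tensors
    reserved = torch.cuda.memory_reserved() / 2**30     # GiB held by the caching allocator
    print(f"[{tag}] rank={rank} ray_gpu_ids={gpu_ids} "
          f"CUDA_VISIBLE_DEVICES={visible} "
          f"allocated={allocated:.2f} GiB reserved={reserved:.2f} GiB")
```

If a worker reports more than one visible GPU, or the allocated numbers already diverge from the direct run before the first training step, that would at least narrow down where the extra memory comes from.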