Seeking recommendation for training Detectron2 with Ray Tune

Hi Team,

I am currently migrating our legacy Detectron2-based object detection training pipeline into a new unified training pipeline that leverages Ray Tune and Ray Train. For Phase 1, we want to integrate Detectron2 with Ray Tune so that we have a unified interface that uses Ray to train alongside our other frameworks. In Phase 2, we plan to use Ray Tune to auto-scale and tune parameters in Detectron2’s global shared config object. We are still at Phase 1.

We came across related GitHub issues from both the Ray project and Detectron2 (referenced at the end of this post), which indicate that, because Detectron2 encapsulates all training properties in a shared global config object, it may be incompatible with Ray’s distributed execution model. For example:

import os

# import some common detectron2 utilities
from detectron2.config import get_cfg
from detectron2.engine import DefaultTrainer
from detectron2 import model_zoo
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.DATASETS.TRAIN = ("balloon_train",)
cfg.DATASETS.TEST = ()
cfg.DATALOADER.NUM_WORKERS = 2
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")  # Let training initialize from model zoo
cfg.SOLVER.IMS_PER_BATCH = 2  # This is the real "batch size" commonly known to deep learning people
cfg.SOLVER.BASE_LR = 0.00025  # pick a good LR
cfg.SOLVER.MAX_ITER = 300    # 300 iterations seems good enough for this toy dataset; you will need to train longer for a practical dataset
cfg.SOLVER.STEPS = []        # do not decay learning rate
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128   # The "RoIHead batch size". 128 is faster, and good enough for this toy dataset (default: 512)
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1  # only has one class (balloon). (see https://detectron2.readthedocs.io/tutorials/datasets.html#update-the-config-for-new-datasets)
# NOTE: this config means the number of classes, but a few popular unofficial tutorials incorrectly use num_classes+1 here.

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = DefaultTrainer(cfg) 
trainer.resume_or_load(resume=False)
trainer.train()

Things get even more complicated with shared data, since we need to register the datasets globally:

# Code from official Detectron2 Tutorial: https://colab.research.google.com/drive/16jcaJoc6bCFAQ96jDe2HwtXj7BMD_-m5#scrollTo=PIbAM2pv-urF
 
for d in ["train", "val"]:
    DatasetCatalog.register("balloon_" + d, lambda d=d: get_balloon_dicts("balloon/" + d))
    MetadataCatalog.get("balloon_" + d).set(thing_classes=["balloon"])
balloon_metadata = MetadataCatalog.get("balloon_train")
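From what I understand (and what the linked issues suggest), this registration would have to be repeated inside every Ray worker process, since workers do not share the driver's globals. A rough sketch of what I mean, assuming a hypothetical train_func that Ray would run per worker and the get_balloon_dicts helper from the tutorial:

from detectron2.data import DatasetCatalog, MetadataCatalog

def train_func(config):  # hypothetical Ray Train per-worker function
    # Re-register the datasets inside the worker process; the driver-side
    # registration above is not visible here because each Ray worker is a
    # separate process with its own copy of the global catalogs.
    for d in ["train", "val"]:
        name = "balloon_" + d
        if name not in DatasetCatalog.list():
            DatasetCatalog.register(name, lambda d=d: get_balloon_dicts("balloon/" + d))
            MetadataCatalog.get(name).set(thing_classes=["balloon"])
    # ... build cfg and the Detectron2 trainer here ...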

Any recommendations for how to construct a Ray Train training function with Detectron2? We want to reuse the legacy code where we create a Detectron2 trainer with DefaultTrainer. For now, in Phase 1, we just want to integrate Detectron2 with Ray Tune so that we can use Ray Checkpoints across the end-to-end, multi-stage training pipeline.

GitHub issue references:

Ray project GitHub issue:
[core] modifications to global variable has no effect #15946

Detectron2 GitHub issue:
Dataset is not registered when performing hyperparameter tuning #3057

I realize this is more of a brainstorming request; any input from the team would be much appreciated.

Thank you,
Heng

In addition, I am currently exploring extracting the trainer from our legacy code. Since it uses the basic form of Detectron2's trainer, SimpleTrainer, behind the scenes, it seems to fit the requirements of Ray Train's TorchTrainer.

Hi everyone, I have uploaded my first attempt at training Detectron2 models with Ray Train and pushed the demo code to GitHub. It is a naive integration approach that meets the goal of our Phase 1 development: leveraging Ray Tune to train Detectron2 models. The plumbing for using Detectron2's SimpleTrainer inside Ray's TorchTrainer is working: I can train Detectron2 models, get Ray checkpoints, and log training progress in TensorBoard. However, I also noticed that I cannot adequately leverage Tune's scaling and tuning capabilities with my current naive implementation. For example, I don't think I can scale to more than 1 worker in the ScalingConfig.
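To make the discussion concrete, here is a trimmed-down sketch of the wiring I am describing (names like train_detectron2 are placeholders rather than the exact code in the demo repo):

from ray.air.config import ScalingConfig
from ray.train.torch import TorchTrainer

def train_detectron2(config):
    # Runs on each Ray Train worker: register datasets, build the cfg,
    # wrap Detectron2's SimpleTrainer, then report metrics/checkpoints
    # back to Ray (e.g. via ray.air.session.report).
    ...

trainer = TorchTrainer(
    train_loop_per_worker=train_detectron2,
    train_loop_config={"num_workers": 1},
    scaling_config=ScalingConfig(num_workers=1, use_gpu=True),
)
result = trainer.fit()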

I am seeing the following error when I try to run with 2 workers:

    _LOCAL_PROCESS_GROUP is not None
AssertionError: Local process group is not created! Please use launch() to spawn processes!
2023-01-10 11:23:45,881 ERROR tune.py:758 -- Trials did not complete: [TorchTrainer_08b5d_00000]
Result(metrics={'trial_id': '08b5d_00000'}, error=RayTaskError(AssertionError)(AssertionError('Local process group is not created! Please use launch() to spawn processes!')), log_dir=PosixPath('/heng/output/RayDetectron2/ray_results/Detector_Training_Demo/TorchTrainer_08b5d_00000_0_2023-01-10_11-22-53'))

With some code examples in hand, I hope I can get some pointers from the Ray team on how to move forward with incorporating Ray Tune to auto-scale and tune parameters for training Detectron2 models in Ray.

Hi @heng2j, can you share more about what you see from ray status to describe your current cluster setup, and the configs you used when attempting to kick off multiple trials?

Hi @Jiao_Dong, thank you for getting back to me. I set all the configs within these few lines of code. What caused the AssertionError about the local process group is that I tried to use more than 1 worker in the ScalingConfig for the TorchTrainer here.
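Concretely, the only change between the working and failing runs is the worker count in the scaling config, roughly:

# works
scaling_config = ScalingConfig(num_workers=1, use_gpu=True)

# fails with the AssertionError about the local process group
scaling_config = ScalingConfig(num_workers=2, use_gpu=True)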

Here is what I see from ray status when I use only 1 worker:

1 Worker Ray Status output (when it is working):

heng$ ray status
======== Autoscaler status: 2023-01-12 09:07:39.373416 ========
Node status
---------------------------------------------------------------
Healthy:
 1 node_91316ec5765fa1c534650fd9bdcb675f9c825779ea7c472de028dfd3
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Usage:
 1.0/64.0 CPU (1.0 used of 1.0 reserved in placement groups)
 1.0/2.0 GPU (1.0 used of 1.0 reserved in placement groups)
 0.0/1.0 accelerator_type:G
 0.00/72.001 GiB memory
 0.00/34.849 GiB object_store_memory

Demands:
 (no resource demands)

2 Worker Ray Status output (when it is not working):

heng$ ray status
======== Autoscaler status: 2023-01-12 09:10:15.209237 ========
Node status
---------------------------------------------------------------
Healthy:
 1 node_f222d3752c8e866e00ce34aec8e0c15a59f4dd9da4af6b20ed35195b
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Usage:
 1.0/64.0 CPU (1.0 used of 1.0 reserved in placement groups)
 0.0/2.0 GPU (0.0 used of 2.0 reserved in placement groups)
 0.0/1.0 accelerator_type:G
 0.00/71.777 GiB memory
 0.00/34.753 GiB object_store_memory

Demands:
 {'CPU': 1.0}: 1+ pending tasks/actors (1+ using placement groups)
 {'CPU': 1.0} * 1, {'GPU': 1.0} * 2 (PACK): 1+ pending placement groups


...
# view ray status a few seconds later, while we keep seeing the AssertionError

heng$ ray status
======== Autoscaler status: 2023-01-12 09:10:20.216686 ========
Node status
---------------------------------------------------------------
Healthy:
 1 node_f222d3752c8e866e00ce34aec8e0c15a59f4dd9da4af6b20ed35195b
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Usage:
 1.0/64.0 CPU (1.0 used of 1.0 reserved in placement groups)
 2.0/2.0 GPU (2.0 used of 2.0 reserved in placement groups)
 0.0/1.0 accelerator_type:G
 0.00/71.777 GiB memory
 0.00/34.753 GiB object_store_memory

Demands:
 {'GPU': 1.0}: 2+ pending tasks/actors (2+ using placement groups)

I hope this info helps. Please feel free to let me know if there is anything else I can provide.

Detectron2 expects you to start processes with its launch function, which will set up a local process group. But Ray expects processes to be started by Ray.

In provide {create,get}_local_process_group by ppwwyyxx · Pull Request #4742 · facebookresearch/detectron2 · GitHub, I exposed a function so that users who do not use the launch function can also set up this local process group. In that case, users have to call create_local_process_group manually.

Hi @ppwwyyxx, thank you so much for your quick turnaround on the PR from the Detectron2 side and for your recommendation to map the number of workers (i.e., GPUs) from the Ray configs to Detectron2 with the new create_local_process_group() function.

I just want to make sure I understand the use case of create_local_process_group().

Is it ideal to add the following lines before I set up the Detectron2 trainer, as in my example at line 168?


    # Set default training params
    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))

    ...
    # Newly added lines for creating the local process group for Detectron2
    from detectron2.utils.comm import create_local_process_group  # showing the import here for reference
    create_local_process_group(config.get("num_workers", 1))

    ...
    # Set up the Detectron2 trainer
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
    trainer = DefaultTrainer(cfg)

Is it ideal to add the following lines before I set up the Detectron2 trainer, as in my example at line 168?

Yes, that should be fine. If the same error comes up, just add it earlier (before get_cfg()).

Hi @ppwwyyxx and @Jiao_Dong, I pulled in the changes from this PR and tried calling create_local_process_group(num_workers_per_machine=2) in my example at line 168, and also before get_cfg().

However, I am now seeing the following RayTaskError:

2023-01-13 11:30:28,433 ERROR trial_runner.py:993 -- Trial TorchTrainer_72b61_00000: Error processing event.
ray.exceptions.RayTaskError(RuntimeError): ray::_Inner.train() (pid=269287, ip=10.1.60.61, repr=TorchTrainer)
  File "/heng/env/miniconda3/envs/dev_env/lib/python3.10/site-packages/ray/tune/trainable/trainable.py", line 355, in train
    raise skipped from exception_cause(skipped)
  File "/heng/env/miniconda3/envs/dev_env/lib/python3.10/site-packages/ray/train/_internal/utils.py", line 54, in check_for_failure
    ray.get(object_ref)
ray.exceptions.RayTaskError(RuntimeError): ray::RayTrainWorker._RayTrainWorker__execute() (pid=269568, ip=10.1.60.61, repr=<ray.train._internal.worker_group.RayTrainWorker object at 0x7f2e1cc28460>)
  File "/heng/env/miniconda3/envs/dev_env/lib/python3.10/site-packages/ray/train/_internal/worker_group.py", line 31, in __execute
    raise skipped from exception_cause(skipped)
  File "/heng/env/miniconda3/envs/dev_env/lib/python3.10/site-packages/ray/train/_internal/utils.py", line 129, in discard_return_wrapper
    train_func(*args, **kwargs)
  File "/heng/code/ray_detectron2/ray_detectron2/ray_detectron2.py", line 176, in train_func
    trainer = DefaultTrainer(cfg)
  File "/heng/code/detectron2/detectron2/engine/defaults.py", line 288, in __init__
    model = DistributedDataParallel(
  File "/heng/env/miniconda3/envs/dev_env/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 646, in __init__
    _verify_param_shape_across_processes(self.process_group, parameters)
  File "/heng/env/miniconda3/envs/dev_env/lib/python3.10/site-packages/torch/distributed/utils.py", line 89, in _verify_param_shape_across_processes
    return dist._verify_params_across_processes(process_group, tensors, logger)
RuntimeError: NCCL error in: ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1191, invalid usage, NCCL version 2.10.3
ncclInvalidUsage: This usually reflects invalid usage of NCCL library (such as too many async ops, too many collectives at once, mixing streams in a group, etc).
2023-01-13 11:30:28,547 ERROR tune.py:773 -- Trials did not complete: [TorchTrainer_72b61_00000]
Result(metrics={'trial_id': '72b61_00000'}, error=RayTaskError(RuntimeError)(RuntimeError('NCCL error in: ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1191, invalid usage, NCCL version 2.10.3\nncclInvalidUsage: This usually reflects invalid usage of NCCL library (such as too many async ops, too many collectives at once, mixing streams in a group, etc).')), log_dir=PosixPath('/heng/output/RayDetectoron2/ray_results/Detector_Training_Demo/TorchTrainer_72b61_00000_0_2023-01-13_11-29-27'))

And here is my ray status output:

heng$ ray status
======== Autoscaler status: 2023-01-13 11:36:13.301919 ========
Node status
---------------------------------------------------------------
Healthy:
 1 node_40d4234e640ff9905914e83fe4188fde1896001b3e55ae7b986d9c39
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Usage:
 1.0/64.0 CPU (1.0 used of 1.0 reserved in placement groups)
 2.0/2.0 GPU (2.0 used of 2.0 reserved in placement groups)
 0.0/1.0 accelerator_type:G
 0.00/67.221 GiB memory
 0.00/32.800 GiB object_store_memory

Demands:
 {'GPU': 1.0}: 2+ pending tasks/actors (2+ using placement groups)

...

heng$ ray status
======== Autoscaler status: 2023-01-13 11:36:18.311808 ========
Node status
---------------------------------------------------------------
Healthy:
 1 node_40d4234e640ff9905914e83fe4188fde1896001b3e55ae7b986d9c39
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Usage:
 1.0/64.0 CPU (1.0 used of 1.0 reserved in placement groups)
 2.0/2.0 GPU (2.0 used of 2.0 reserved in placement groups)
 0.0/1.0 accelerator_type:G
 0.00/67.221 GiB memory
 0.00/32.800 GiB object_store_memory

Demands:
 (no resource demands)

...

heng$ ray status
======== Autoscaler status: 2023-01-13 11:36:38.351966 ========
Node status
---------------------------------------------------------------
Healthy:
 1 node_40d4234e640ff9905914e83fe4188fde1896001b3e55ae7b986d9c39
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Usage:
 1.0/64.0 CPU (1.0 used of 1.0 reserved in placement groups)
 0.0/2.0 GPU (0.0 used of 2.0 reserved in placement groups)
 0.0/1.0 accelerator_type:G
 0.00/67.221 GiB memory
 0.00/32.800 GiB object_store_memory

Demands:
 {'CPU': 1.0}: 1+ pending tasks/actors (1+ using placement groups)

...

======== Autoscaler status: 2023-01-13 11:54:07.628662 ========
Node status
---------------------------------------------------------------
Healthy:
 1 node_1e68ba84b2443de2f877fc31cb6267916a128b09206ad5d936d09f6c
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Usage:
 1.0/64.0 CPU (1.0 used of 1.0 reserved in placement groups)
 0.0/2.0 GPU (0.0 used of 2.0 reserved in placement groups)
 0.0/1.0 accelerator_type:G
 0.00/66.886 GiB memory
 0.00/32.657 GiB object_store_memory

Demands:
 {'CPU': 1.0}: 1+ pending tasks/actors (1+ using placement groups)
 {'CPU': 1.0} * 1, {'GPU': 1.0} * 2 (PACK): 1+ pending placement groups


FYI, using a single worker still works with these latest changes.

You’re seeing NCCL errors at the PyTorch DDP level. Can you follow debugging instructions similar to ncclInvalidUsage of torch.nn.parallel.DistributedDataParallel - PyTorch Forums and show us what the output looks like with export NCCL_DEBUG=INFO?
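If you kick the job off from a driver script, one way to make sure the variable reaches the Ray worker processes (just a sketch, assuming you call ray.init yourself) is to pass it through the runtime environment:

import ray

# Propagate NCCL_DEBUG to every worker process spawned by Ray.
ray.init(runtime_env={"env_vars": {"NCCL_DEBUG": "INFO"}})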

Hi @Jiao_Dong,

Thank you for the recommendation! NCCL debugging is new to me. Here is my NCCL_DEBUG=INFO output:

(RayTrainWorker pid=409108) _______:409108:409198 [0] NCCL INFO Bootstrap : Using enp70s0:10.1.60.61<0>
(RayTrainWorker pid=409108) _______:409108:409198 [0] NCCL INFO NET/Plugin : No plugin found (libnccl-net.so), using internal implementation
(RayTrainWorker pid=409108) _______:409108:409198 [0] NCCL INFO NET/IB : No device found.
(RayTrainWorker pid=409108) _______:409108:409198 [0] NCCL INFO NET/Socket : Using [0]enp70s0:10.1.60.61<0>
(RayTrainWorker pid=409108) _______:409108:409198 [0] NCCL INFO Using network Socket
(RayTrainWorker pid=409108) 
(RayTrainWorker pid=409108) _______:409108:409465 [0] init.cc:521 NCCL WARN Duplicate GPU detected : rank 1 and rank 0 both on CUDA device 1000
(RayTrainWorker pid=409108) _______:409108:409465 [0] NCCL INFO init.cc:904 -> 5
(RayTrainWorker pid=409108) _______:409108:409465 [0] NCCL INFO group.cc:72 -> 5 [Async thread]
(RayTrainWorker pid=409107) _______:409107:409199 [0] NCCL INFO Bootstrap : Using enp70s0:10.1.60.61<0>
(RayTrainWorker pid=409107) _______:409107:409199 [0] NCCL INFO NET/Plugin : No plugin found (libnccl-net.so), using internal implementation
(RayTrainWorker pid=409107) _______:409107:409199 [0] NCCL INFO NET/IB : No device found.
(RayTrainWorker pid=409107) _______:409107:409199 [0] NCCL INFO NET/Socket : Using [0]enp70s0:10.1.60.61<0>
(RayTrainWorker pid=409107) _______:409107:409199 [0] NCCL INFO Using network Socket
(RayTrainWorker pid=409107) NCCL version 2.10.3+cuda11.3
(RayTrainWorker pid=409107) 
(RayTrainWorker pid=409107) _______:409107:409464 [0] init.cc:521 NCCL WARN Duplicate GPU detected : rank 0 and rank 1 both on CUDA device 1000
(RayTrainWorker pid=409107) _______:409107:409464 [0] NCCL INFO init.cc:904 -> 5
(RayTrainWorker pid=409107) _______:409107:409464 [0] NCCL INFO group.cc:72 -> 5 [Async thread]


I have a similar issue to the one shown in that post. I guess I will have to manually assign the model and ranks to devices? I could use some suggestions on how to achieve this in Ray with Detectron2 models.

This is the cause of your NCCL error. I’m not super familiar with the detectron2 codebase, but Ray Train/Tune typically launches workers with one process per GPU. Can you try with the same config as when you launch from detectron2? For example, in a way that only one CUDA device is visible to each process.
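One quick way to check what each worker actually sees (a diagnostic sketch, not code from your repo) is to print the device info at the top of your training function:

import os
import torch

def debug_devices():
    # If both ranks end up on the same CUDA device, DDP raises the
    # "Duplicate GPU detected" NCCL warning shown in your log.
    print("CUDA_VISIBLE_DEVICES:", os.environ.get("CUDA_VISIBLE_DEVICES"))
    print("visible device count:", torch.cuda.device_count())
    print("current device:", torch.cuda.current_device())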

Thank you, @Jiao_Dong, for the suggestion. With the new create_local_process_group(), we bypass launch(), where we used to set num_gpus_per_machine and num_machines, since we now set num_workers_per_machine in create_local_process_group() as shown below.

We will probably also need a way to map those values back to world_size, since get_world_size() is returning 1 for me now and I can't pass the assertion:

@functools.lru_cache()
def create_local_process_group(num_workers_per_machine: int) -> None:
    """
    Create a process group that contains ranks within the same machine.

    Detectron2's launch() in engine/launch.py will call this function. If you start
    workers without launch(), you'll have to also call this. Otherwise utilities
    like `get_local_rank()` will not work.

    This function contains a barrier. All processes must call it together.

    Args:
        num_workers_per_machine: the number of worker processes per machine. Typically
          the number of GPUs.
    """
    global _LOCAL_PROCESS_GROUP
    assert _LOCAL_PROCESS_GROUP is None
    assert get_world_size() % num_workers_per_machine == 0 

This is my sample training code for how I train with Detectron2 on its own (outside Ray):


...

from detectron2.engine import DefaultTrainer
from detectron2.engine import hooks

from detectron2.utils.comm import create_local_process_group  # showing the import here for reference
create_local_process_group(num_workers_per_machine=2)

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.OUTPUT_DIR = "./detector_output"
cfg.DATASETS.TRAIN = ("license_plates_train",)
cfg.DATASETS.TEST =  ("license_plates_val",)
cfg.DATALOADER.NUM_WORKERS = 2
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")  # Let training initialize from model zoo
cfg.SOLVER.IMS_PER_BATCH = 2  # This is the real "batch size" commonly known to deep learning people
cfg.SOLVER.BASE_LR = 0.00025  # pick a good LR
cfg.SOLVER.MAX_ITER = 300    # 300 iterations seems good enough for this toy dataset; you will need to train longer for a practical dataset
cfg.SOLVER.STEPS = []        # do not decay learning rate
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128   # The "RoIHead batch size". 128 is faster, and good enough for this toy dataset (default: 512)
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1  # only has one class (Vehicle registration plate). (see https://detectron2.readthedocs.io/tutorials/datasets.html#update-the-config-for-new-datasets)
# NOTE: this config means the number of classes, but a few popular unofficial tutorials incorrectly use num_classes+1 here.

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = DefaultTrainer(cfg) 
trainer.resume_or_load(resume=False)
trainer.train()

@ppwwyyxx, any recommendations?

That would be wrong. I don't know much about Ray, but you'll need to check whether Ray initializes the PyTorch process group for you (try print(torch.distributed.is_initialized())).
If Ray initializes the process group for users, then it's doing it wrong here, because the group should have a world_size of 2.
If Ray does not, then you'll need to do it yourself by calling torch.distributed.init_process_group with the "right" parameters. In theory, Ray is supposed to provide enough info to obtain these parameters (especially the init_method in a multi-machine setting), but if it does not, you can check detectron2's launch and do it yourself.
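A minimal check you could drop at the top of whatever function each process runs (just a sketch, independent of ray and detectron2):

import torch.distributed as dist

def debug_process_group():
    print("initialized:", dist.is_initialized())
    if dist.is_initialized():
        # With 2 GPUs on one machine this should print world_size=2 and
        # ranks 0/1; otherwise the process group is not set up the way
        # DDP (and detectron2's comm utilities) expect.
        print("world_size:", dist.get_world_size())
        print("rank:", dist.get_rank())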

Thank you, @ppwwyyxx. The example above used only Detectron2; I didn't run that training pipeline with Ray. Thanks to your recommendation, I am now looking into how detectron2's launch() constructs the parameters it passes to torch.distributed.init_process_group.

Is there any update on the above issue? Is it working, @heng2j?
I want to start using Ray and have been following your topic. I'd like to know whether there has been any success with this or not.
@ppwwyyxx