huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Error in object detection example script #32525

Open ohhan777 opened 1 month ago

ohhan777 commented 1 month ago

System Info

Who can help?

No response

Information

Tasks

Reproduction

When running ./examples/pytorch/object-detection/run_object_detection_no_trainer.py from the transformers source tree, the script works fine on a single GPU, but on multiple GPUs an error occurs at metric.compute() in the evaluation_loop() function.

Run: accelerate launch run_object_detection_no_trainer.py --ignore_mismatched_sizes
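
For context, here is a minimal, self-contained sketch (not the example script itself) of the pattern that seems to trigger the failure: the script moves predictions and targets to CPU via nested_to_cpu() before updating torchmetrics' MeanAveragePrecision, so the metric states live on CPU; under accelerate multi-GPU, compute() then tries to all_gather those CPU tensors over the NCCL process group, which only handles CUDA tensors. The tensors below are made up for illustration.

import torch
from torchmetrics.detection.mean_ap import MeanAveragePrecision

metric = MeanAveragePrecision(box_format="xyxy", class_metrics=True)

# Hypothetical CPU predictions/targets, analogous to what nested_to_cpu() produces
preds = [{
    "boxes": torch.tensor([[10.0, 10.0, 50.0, 50.0]]),
    "scores": torch.tensor([0.9]),
    "labels": torch.tensor([1]),
}]
targets = [{
    "boxes": torch.tensor([[12.0, 8.0, 48.0, 52.0]]),
    "labels": torch.tensor([1]),
}]

metric.update(preds, targets)

# Single process: works. Multi-GPU (NCCL group set up by accelerate):
# compute() syncs the CPU state tensors across ranks and raises
# "RuntimeError: No backend type associated with device type cpu".
metrics = metric.compute()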

Expected behavior

[rank0]: Traceback (most recent call last):
[rank0]:   File "/gpfs/home/ohhan/ai/PyProjects/transformers/examples/pytorch/object-detection/run_object_detection_no_trainer.py", line 782, in <module>
[rank0]:     main()
[rank0]:   File "/gpfs/home/ohhan/ai/PyProjects/transformers/examples/pytorch/object-detection/run_object_detection_no_trainer.py", line 708, in main
[rank0]:     metrics = evaluation_loop(model, image_processor, accelerator, valid_dataloader, id2label)
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/gpfs/home/ohhan/ai/PyProjects/transformers/examples/pytorch/object-detection/run_object_detection_no_trainer.py", line 211, in evaluation_loop
[rank0]:     metrics = metric.compute()
[rank0]:               ^^^^^^^^^^^^^^^^
[rank0]:   File "/gpfs/home/ohhan/miniconda3/envs/ohhan-ai/lib/python3.11/site-packages/torchmetrics/metric.py", line 628, in wrapped_func
[rank0]:     with self.sync_context(
[rank0]:   File "/gpfs/home/ohhan/miniconda3/envs/ohhan-ai/lib/python3.11/contextlib.py", line 137, in __enter__
[rank0]:     return next(self.gen)
[rank0]:            ^^^^^^^^^^^^^^
[rank0]:   File "/gpfs/home/ohhan/miniconda3/envs/ohhan-ai/lib/python3.11/site-packages/torchmetrics/metric.py", line 599, in sync_context
[rank0]:     self.sync(
[rank0]:   File "/gpfs/home/ohhan/miniconda3/envs/ohhan-ai/lib/python3.11/site-packages/torchmetrics/metric.py", line 548, in sync
[rank0]:     self._sync_dist(dist_sync_fn, process_group=process_group)
[rank0]:   File "/gpfs/home/ohhan/miniconda3/envs/ohhan-ai/lib/python3.11/site-packages/torchmetrics/detection/mean_ap.py", line 1029, in _sync_dist
[rank0]:     super()._sync_dist(dist_sync_fn=dist_sync_fn, process_group=process_group)  # type: ignore[arg-type]
[rank0]:     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/gpfs/home/ohhan/miniconda3/envs/ohhan-ai/lib/python3.11/site-packages/torchmetrics/metric.py", line 452, in _sync_dist
[rank0]:     output_dict = apply_to_collection(
[rank0]:                   ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/gpfs/home/ohhan/miniconda3/envs/ohhan-ai/lib/python3.11/site-packages/lightning_utilities/core/apply_func.py", line 72, in apply_to_collection
[rank0]:     return _apply_to_collection_slow(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/gpfs/home/ohhan/miniconda3/envs/ohhan-ai/lib/python3.11/site-packages/lightning_utilities/core/apply_func.py", line 104, in _apply_to_collection_slow
[rank0]:     v = _apply_to_collection_slow(
[rank0]:         ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/gpfs/home/ohhan/miniconda3/envs/ohhan-ai/lib/python3.11/site-packages/lightning_utilities/core/apply_func.py", line 125, in _apply_to_collection_slow
[rank0]:     v = _apply_to_collection_slow(
[rank0]:         ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/gpfs/home/ohhan/miniconda3/envs/ohhan-ai/lib/python3.11/site-packages/lightning_utilities/core/apply_func.py", line 96, in _apply_to_collection_slow
[rank0]:     return function(data, *args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/gpfs/home/ohhan/miniconda3/envs/ohhan-ai/lib/python3.11/site-packages/torchmetrics/utilities/distributed.py", line 127, in gather_all_tensors
[rank0]:     torch.distributed.all_gather(local_sizes, local_size, group=group)
[rank0]:   File "/gpfs/home/ohhan/miniconda3/envs/ohhan-ai/lib/python3.11/site-packages/torch/distributed/c10d_logger.py", line 79, in wrapper
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/gpfs/home/ohhan/miniconda3/envs/ohhan-ai/lib/python3.11/site-packages/torch/distributed/distributed_c10d.py", line 3108, in all_gather
[rank0]:     work = group.allgather([tensor_list], [tensor])
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: No backend type associated with device type cpu
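
For what it's worth, the final RuntimeError seems to come from torch.distributed being initialised with the NCCL backend only, which cannot gather CPU tensors. A quick hypothetical check inside the launched script:

import torch.distributed as dist

# NCCL has no CPU support, so any all_gather of CPU tensors over this
# group fails with "No backend type associated with device type cpu".
if dist.is_initialized():
    print(f"rank {dist.get_rank()}: backend = {dist.get_backend()}")  # expected: nccl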
qubvel commented 1 month ago

Hi, @ohhan777, thanks for reporting the issue!

Could you try adding a gather_for_metrics call in the script before using nested_to_cpu? https://huggingface.co/docs/accelerate/v0.33.0/en/package_reference/accelerator#accelerate.Accelerator.gather_for_metrics

Something like:

outputs = accelerator.gather_for_metrics(outputs)
labels = accelerator.gather_for_metrics(batch["labels"])
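
For reference, roughly where that would sit in evaluation_loop() (a sketch, not a tested patch; variable names are assumed to match the script):

# inside the eval loop, after the forward pass:
with torch.no_grad():
    outputs = model(**batch)

# gather predictions and labels from all processes before any CPU post-processing,
# so each rank updates the metric with the full evaluation batch
outputs = accelerator.gather_for_metrics(outputs)
labels = accelerator.gather_for_metrics(batch["labels"])

# then proceed with post-processing / nested_to_cpu / metric.update as before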

Let me know if that helps, thanks!

ohhan777 commented 1 month ago

I made the suggested modifications, but the same issue persists. The simplest solution seems to be to keep the data on the GPUs instead of moving it to the CPU. I resolved the issue by commenting out the nested_to_cpu() calls and making a one-line modification in the convert_bbox_yolo_to_pascal() function, as follows:

boxes = boxes * torch.tensor([[width, height, width, height]]).to(boxes.device)
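
For completeness, a paraphrased sketch of the function with that one-line change applied (center_to_corners_format comes from transformers.image_transforms; the rest of the body is reconstructed, so treat it as approximate):

import torch
from transformers.image_transforms import center_to_corners_format

def convert_bbox_yolo_to_pascal(boxes, image_size):
    # YOLO format: (x_center, y_center, width, height), normalized to [0, 1]
    # Pascal VOC format: (x_min, y_min, x_max, y_max), absolute pixel coordinates
    boxes = center_to_corners_format(boxes)

    # Scale to absolute coordinates, building the scale tensor on the same
    # device as `boxes` so the whole evaluation can stay on the GPU
    height, width = image_size
    boxes = boxes * torch.tensor([[width, height, width, height]]).to(boxes.device)

    return boxes

With the metric inputs kept on GPU, torchmetrics' cross-rank sync runs over NCCL on CUDA tensors, so the CPU-backend error no longer applies.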