pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Implement `torchvision.ops.nms` in torch_xla2. #8278

Closed: qihqi closed this issue 3 weeks ago

qihqi commented 1 month ago

🚀 Feature

nms (https://pytorch.org/vision/stable/generated/torchvision.ops.nms.html) is an op used by many torchvision models, notably the object-detection models. Implementing it in torch_xla2 would let those models run there.
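For reference, the torchvision docs describe the semantics as: take boxes in `(x1, y1, x2, y2)` format, visit them in decreasing score order, and drop any box whose IoU with an already-kept, higher-scoring box is greater than `iou_threshold`; the result is the kept indices sorted by decreasing score. The pure-Python sketch below only illustrates those semantics (it is not torchvision's implementation), and `box_iou` / `reference_nms` are hypothetical names. Something like it could serve as a correctness oracle when testing a torch_xla2 lowering.

```python
from typing import List, Sequence, Tuple

# A box is (x1, y1, x2, y2) with 0 <= x1 < x2 and 0 <= y1 < y2.
Box = Tuple[float, float, float, float]


def box_iou(a: Box, b: Box) -> float:
    """Intersection-over-union of two boxes in (x1, y1, x2, y2) format."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def reference_nms(boxes: Sequence[Box], scores: Sequence[float],
                  iou_threshold: float) -> List[int]:
    """Greedy NMS: keep a box unless a higher-scoring kept box overlaps it
    with IoU > iou_threshold. Returns kept indices, highest score first."""
    # Visit boxes from highest to lowest score.
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep: List[int] = []
    for i in order:
        if all(box_iou(boxes[i], boxes[k]) <= iou_threshold for k in keep):
            keep.append(i)
    return keep


if __name__ == "__main__":
    boxes = [(0.0, 0.0, 10.0, 10.0), (1.0, 1.0, 11.0, 11.0), (20.0, 20.0, 30.0, 30.0)]
    # Boxes 0 and 1 have IoU 81/119 (about 0.68), so box 1 is suppressed at 0.5.
    print(reference_nms(boxes, [0.9, 0.8, 0.7], 0.5))  # [0, 2]
```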

Here is a JAX implementation of it from Google's MLPerf Training v0.7 SSD submission: https://raw.githubusercontent.com/mlperf/training_results_v0.7/refs/heads/master/Google/benchmarks/ssd/implementations/ssd-research-JAX-tpu-v3-4096/nms.py
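One wrinkle for an XLA backend is that `torchvision.ops.nms` returns a data-dependent number of indices, while XLA wants static shapes. A minimal JAX sketch, assuming a padded output is acceptable: compute the full pairwise IoU matrix, run the greedy pass in a `jax.lax.fori_loop`, and return `N` indices padded with `-1` plus the count of kept boxes. `pairwise_iou` and `nms_padded` are hypothetical names; this is not the MLPerf code linked above, and it does not show how the op would be registered in torch_xla2.

```python
import jax
import jax.numpy as jnp


def pairwise_iou(boxes):
    """(N, 4) boxes in (x1, y1, x2, y2) format -> (N, N) IoU matrix."""
    x1, y1, x2, y2 = jnp.split(boxes, 4, axis=1)  # each has shape (N, 1)
    area = (x2 - x1) * (y2 - y1)                   # (N, 1)
    # Broadcasting (N, 1) against (1, N) gives every pair of boxes.
    iw = jnp.maximum(jnp.minimum(x2, x2.T) - jnp.maximum(x1, x1.T), 0.0)
    ih = jnp.maximum(jnp.minimum(y2, y2.T) - jnp.maximum(y1, y1.T), 0.0)
    inter = iw * ih
    union = area + area.T - inter
    return jnp.where(union > 0, inter / union, 0.0)


@jax.jit
def nms_padded(boxes, scores, iou_threshold):
    """Greedy NMS with static output shapes.

    Returns (indices, num_kept): indices has shape (N,), holds the kept
    original indices in decreasing score order, and is padded with -1.
    """
    n = boxes.shape[0]
    order = jnp.argsort(-scores)          # highest score first
    iou = pairwise_iou(boxes[order])      # IoU in sorted order

    def body(i, keep):
        # keep[j] is only set for j < i so far, so this checks exactly the
        # already-kept, higher-scoring boxes.
        suppressed = jnp.any(keep & (iou[i] > iou_threshold))
        return keep.at[i].set(~suppressed)

    keep_sorted = jax.lax.fori_loop(0, n, body, jnp.zeros((n,), dtype=bool))

    # Compact kept positions to the front with a static output size.
    (pos,) = jnp.nonzero(keep_sorted, size=n, fill_value=0)
    num_kept = keep_sorted.sum()
    indices = jnp.where(jnp.arange(n) < num_kept, order[pos], -1)
    return indices, num_kept
```

Example: with boxes `[[0, 0, 10, 10], [1, 1, 11, 11], [20, 20, 30, 30]]`, scores `[0.9, 0.8, 0.7]` and `iou_threshold=0.5`, `nms_padded` returns indices `[0, 2, -1]` and `num_kept == 2`; slicing `indices[:num_kept]` outside of `jit` recovers torchvision's variable-length result. The IoU matrix costs O(N²) memory and the loop runs N sequential steps, so large inputs would likely need a tiled or otherwise more TPU-friendly variant.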