pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Implement `torchvision.ops.roi_align` in torchxla2 #8288

Open qihqi opened 1 month ago

qihqi commented 1 month ago

🚀 Feature

https://pytorch.org/vision/stable/generated/torchvision.ops.roi_align.html?highlight=roi_align#torchvision.ops.roi_align

A few ideas:

  1. Use the torch decomposition here: https://github.com/pytorch/vision/blob/main/torchvision/ops/roi_align.py#L115 . I tried this and found that JAX OOMs, pointing here: https://github.com/pytorch/vision/blob/main/torchvision/ops/roi_align.py#L74 , so the issue seems to be that the advanced indexing used there creates large intermediates. The torch side needed a "loop-less" implementation to help Inductor; on the JAX side we could instead rewrite it using jax.vmap and jax.lax.fori_loop.
  2. Start from this JAX implementation: https://github.com/google-research/scenic/blob/74225e8e71ba27a76abd62e6bc56e8a64c4cc19e/scenic/projects/baselines/centernet/modeling/roi_align.py#L103 . However, it takes output_size as an int rather than a tuple of ints (i.e. it assumes width and height are equal), so it will need some modification.
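As a starting point for idea 1, here is a rough sketch of what a vmap-based `roi_align` could look like in JAX: bilinear sampling at per-bin sample points, vectorized with `jax.vmap` over the sample grid and over ROIs, so no large advanced-indexing intermediates are materialized. This is a simplified sketch, not a drop-in replacement: it assumes `aligned=False` semantics without the half-pixel offset, uses plain border clamping, and takes a fixed positive `sampling_ratio`; all function names here are hypothetical.

```python
import jax
import jax.numpy as jnp


def bilinear_sample(feat, y, x):
    """Bilinearly sample feat (C, H, W) at one (y, x) point, clamping at borders."""
    H, W = feat.shape[1], feat.shape[2]
    y0 = jnp.clip(jnp.floor(y), 0, H - 1)
    x0 = jnp.clip(jnp.floor(x), 0, W - 1)
    y1 = jnp.clip(y0 + 1, 0, H - 1)
    x1 = jnp.clip(x0 + 1, 0, W - 1)
    wy, wx = y - y0, x - x0
    y0i, y1i = y0.astype(jnp.int32), y1.astype(jnp.int32)
    x0i, x1i = x0.astype(jnp.int32), x1.astype(jnp.int32)
    return (feat[:, y0i, x0i] * (1 - wy) * (1 - wx)
            + feat[:, y0i, x1i] * (1 - wy) * wx
            + feat[:, y1i, x0i] * wy * (1 - wx)
            + feat[:, y1i, x1i] * wy * wx)


def roi_align_single(feat, box, out_h, out_w, spatial_scale=1.0, sampling_ratio=2):
    """ROI-align one box (x1, y1, x2, y2) on feat (C, H, W) -> (C, out_h, out_w)."""
    x1, y1, x2, y2 = box * spatial_scale
    bin_h = (y2 - y1) / out_h
    bin_w = (x2 - x1) / out_w
    # Flattened sample points: sampling_ratio points per bin, at bin-relative
    # offsets (k + 0.5) / sampling_ratio, matching the usual averaged sampling.
    ys = y1 + (jnp.arange(out_h * sampling_ratio) + 0.5) / sampling_ratio * bin_h
    xs = x1 + (jnp.arange(out_w * sampling_ratio) + 0.5) / sampling_ratio * bin_w
    # vmap over the y grid (outer) and x grid (inner); no explicit Python loops.
    sample = jax.vmap(jax.vmap(bilinear_sample, in_axes=(None, None, 0)),
                      in_axes=(None, 0, None))
    vals = sample(feat, ys, xs)  # (out_h * s, out_w * s, C)
    C = feat.shape[0]
    vals = vals.reshape(out_h, sampling_ratio, out_w, sampling_ratio, C)
    # Average the samples within each bin, then move channels first.
    return vals.mean(axis=(1, 3)).transpose(2, 0, 1)


# vmap over the ROI axis: (C, H, W), (K, 4) -> (K, C, out_h, out_w)
roi_align = jax.vmap(roi_align_single, in_axes=(None, 0, None, None))
```

The output layout `(K, C, out_h, out_w)` matches torchvision's; `out_h`/`out_w` are passed separately, so the `output_size`-as-tuple issue from idea 2 does not arise. The inner sampling could equally be expressed with `jax.lax.fori_loop` over bins if the fully vmapped grid still allocates too much on large feature maps.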

Motivation

Pitch

Alternatives

Additional context