SHI-Labs / NATTEN

Neighborhood Attention Extension. Bringing attention to a neighborhood near you!
https://shi-labs.com/natten/

Incorporating into mainline pytorch? #149

Open jxtps opened 2 months ago

jxtps commented 2 months ago

Hi, excellent project, love NeighborhoodAttention2D, this should be a foundational building block in computer vision, right next to Conv2d etc. So to that end: what needs to happen to get this incorporated into mainline pytorch?
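For readers unfamiliar with the primitive being discussed: below is a hypothetical, minimal NumPy sketch of 2D neighborhood attention, not NATTEN's implementation. It uses a single head, no learned Q/K/V projections and no relative position bias (query = key = value = the input, for brevity), and follows NA's convention of clamping the window at borders so every query attends to exactly `k * k` keys.

```python
import numpy as np

def neighborhood_attention_2d(x, k):
    """Naive 2D neighborhood attention sketch.

    x: (H, W, C) feature map; k: odd neighborhood size (k <= H, k <= W).
    Each query attends to a k x k neighborhood whose center is clamped
    at the borders, so every query sees exactly k*k keys.
    """
    H, W, C = x.shape
    r = k // 2
    out = np.empty_like(x)
    for i in range(H):
        # clamp the window center so the k-wide window stays in bounds
        ci = min(max(i, r), H - 1 - r)
        for j in range(W):
            cj = min(max(j, r), W - 1 - r)
            nbr = x[ci - r:ci + r + 1, cj - r:cj + r + 1].reshape(-1, C)
            scores = nbr @ x[i, j] / np.sqrt(C)       # (k*k,) attention logits
            w = np.exp(scores - scores.max())          # stable softmax
            w /= w.sum()
            out[i, j] = w @ nbr                        # weighted sum of neighbors
    return out
```

When `k` equals the full spatial extent, this reduces to ordinary global self-attention over all positions, which is one way to sanity-check an implementation.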

The reason I ask is that while the project in its current form is amazing, it is very hard to use it in a production context: we develop on windows + python, train on linux + python, and deploy to linux + java/c++ ( https://github.com/bytedeco/javacpp-presets/tree/master/pytorch ), which makes incorporating a custom op like this a little tricky.

So what do we need to do and how can I help out?

(Don't mean to barge in, would just really like to be able to use it ;)

CC: @zou3519

alihassanijr commented 2 months ago

Thank you for your interest.

We agree! We think Neighborhood Attention is a powerful and flexible primitive and we'd certainly like to see it adopted in more applications. We've seen some pretty great applications of NA, but unfortunately not as wide as we had hoped.

Our understanding is that until Neighborhood Attention is used in more applications, it would be too specific a feature to request in PyTorch. We're certainly happy to help drive that integration of course, but it is not entirely up to us unfortunately.

That said, there will be challenges in doing so: certain design choices in the kernels NATTEN packages may make integration more difficult, and even our binary size and symbol count are too large for a package as simple as NATTEN. But those are just details.

> The reason I ask is that while the project in its current form is amazing, it is very hard to use it in a production context: we develop on windows + python, train on linux + python, and deploy to linux + java/c++ ( https://github.com/bytedeco/javacpp-presets/tree/master/pytorch ), which makes incorporating a custom op like this a little tricky.

You're very right; it is tricky to use NA without NATTEN, and NATTEN is limited to being a PyTorch extension, so frameworks that are merely interoperable with PyTorch, which seems to be your case, would likely run into issues using it.

If the issue is something related to NATTEN, we'd certainly love to try and help, just drop the details in an issue.

> So what do we need to do and how can I help out?

> (Don't mean to barge in, would just really like to be able to use it ;)

Not at all; I'm actually unsure. I'll CC @Chillee as well, since other than @zou3519 he's the only person I've been in contact with regarding NATTEN and PyTorch, to see if he has any ideas.

Chillee commented 2 months ago

@alihassanijr I would probably agree that it is unlikely NATTEN gets added into core (at least, yet!). I think that the new attention API we're adding should allow NATTEN to be expressed using core PyTorch ops with reasonable perf.
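The API referenced here appears to be FlexAttention (`torch.nn.attention.flex_attention`), which lets users express sparse attention patterns as a mask over (query, key) index pairs. As a hedged illustration (not NATTEN's actual integration), a 1D sliding-window band can be written as a `mask_mod`; note that exact neighborhood attention additionally clamps the window at sequence borders, which is omitted here for brevity, and `WINDOW` is an arbitrary example radius:

```python
WINDOW = 3  # example window radius (hypothetical choice)

def sliding_window_mask(b, h, q_idx, kv_idx):
    # Keep (q, kv) pairs within a band of radius WINDOW around the diagonal.
    # Written with plain arithmetic so it works on Python ints and on the
    # index tensors FlexAttention passes in.
    return abs(q_idx - kv_idx) <= WINDOW

# Intended usage, assuming a recent PyTorch release that ships FlexAttention:
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
# block_mask = create_block_mask(sliding_window_mask, B=None, H=None,
#                                Q_LEN=1024, KV_LEN=1024)
# out = flex_attention(q, k, v, block_mask=block_mask)
```

Because the mask is just a function of indices, the 2D neighborhood case would map (H, W) positions onto the flattened sequence index inside the mask, at the cost of some extra index arithmetic.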

jxtps commented 2 months ago

> the new attention API we're adding

I'm a bit out of the loop here: is there a neighborhood/sliding-window 2D attention-style API about to land in mainline PyTorch? That would be huge! Any pointers to where it is being discussed and developed?

Is it FlexAttention?

alihassanijr commented 2 months ago

Thanks @Chillee -- yes, exactly. For most of NATTEN's lifetime, the kernels it provided were definitely not production level. That changed with CUTLASS (more so FNA than implicit GEMM, but effectively both); now the challenge, like I said, is both binary size (NATTEN is still stuck in the no-JIT era) and figuring out autotuning, etc.