LaurentMazare / tch-rs

Rust bindings for the C++ API of PyTorch.

SDPA Flash kernel #848

Open finnkauski opened 8 months ago

finnkauski commented 8 months ago

Hi,

The 2.2 release of Torch added a great feature for speeding up transformer attention-based architectures:

FlashAttentionV2 backend for scaled dot product attention

An example of how it is intended to be used in the Python API is this context manager.

Following down the rabbit hole we land here:

torch._C._set_sdp_use_flash(enabled)

and following this we land here

There are also the related mem_efficient and math backends, which are closely linked and would be worth exposing as well.

NOTE: not sure if this is already in the bindings and I just can't find it?
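
In case it helps, here is a rough sketch of what exposing these toggles could look like on the tch side. This is purely hypothetical: the at_set_sdp_use_* shims below do not exist in torch-sys today, and the names are made up; they would presumably have to forward to at::globalContext().setSDPUseFlash and friends on the libtorch side (which, as far as I can tell, is what torch._C._set_sdp_use_flash ends up calling).

```rust
use std::os::raw::c_int;

// Hypothetical C shims that torch-sys would need to provide; these names do
// not exist today and would presumably forward to
// at::globalContext().setSDPUseFlash / .setSDPUseMemEfficient / .setSDPUseMath.
extern "C" {
    fn at_set_sdp_use_flash(enabled: c_int);
    fn at_set_sdp_use_mem_efficient(enabled: c_int);
    fn at_set_sdp_use_math(enabled: c_int);
}

/// Enable or disable the FlashAttention backend for scaled_dot_product_attention.
pub fn set_sdp_use_flash(enabled: bool) {
    unsafe { at_set_sdp_use_flash(c_int::from(enabled)) }
}

/// Enable or disable the memory-efficient attention backend.
pub fn set_sdp_use_mem_efficient(enabled: bool) {
    unsafe { at_set_sdp_use_mem_efficient(c_int::from(enabled)) }
}

/// Enable or disable the math (reference) backend.
pub fn set_sdp_use_math(enabled: bool) {
    unsafe { at_set_sdp_use_math(c_int::from(enabled)) }
}
```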

LaurentMazare commented 8 months ago

I don't think this is available yet, and for these very specific operations I'm a bit dubious that we will have good support anytime soon. If you care about using flash-attention, I would recommend switching to candle instead, where this is already well supported and gives you access to flash-attention v2, which is significantly faster than the first version.

finnkauski commented 8 months ago

The way I've read the info about this, it seems to be some form of switch that gets flicked (I don't know how it actually works under the hood, but judging from the calls used to enable flash attention it feels like global state) so that Tensor::scaled_dot_product_attention dispatches different kernels under the hood. I was hoping it would be something they would expose so that we could also toggle it. This isn't a pressing issue right now, but I'm thinking ahead about what I might need, hence the question.

If you think we can't expose this here then I think we can close this issue.

Side note: I'm a big fan of candle, but I'm unfortunately working backwards on my project from candle to torch, as candle is missing a few operations for my use-case, such as CUDA ConvTranspose1d, and my naive kernel for it has been embarrassingly slow! Even discounting the time spent in that kernel, I couldn't get performance to the level I could with torch for now. I've got the codebase in two branches now, candle and tch, for reference, so I might be able to contribute some insight into the comparison for my use case in the future.

LaurentMazare commented 8 months ago

What are the ops that you're missing on the candle side? If it's just a matter of Conv1d and ConvTranspose1d, we could have a look at hooking up cudnn here; this should bring you the best available kernels for these ops, and you would benefit from all the "modern" aspects of candle as an ML framework which tch doesn't have.

finnkauski commented 8 months ago

It was indeed mainly those two for this part of the project. I'm sure other bits might hit some walls too.

Essentially, here's the issue on the candle side: it's ConvTranspose1d that I'm missing, and obviously the faster the regular Conv1d is, the better for my case.

I assumed you folks are swamped with TODOs, so I was just going to wait until this gets picked up before exploring candle further.

LaurentMazare commented 8 months ago

Ah, hooking conv1d to use cudnn is probably the easiest thing to do but won't help the convtranspose1d bit here. I'll have to dig a bit to see how to do the transposed versions in cudnn (I think that's what pytorch does).

LaurentMazare commented 8 months ago

Just to mention that I merged a naive conv-transpose1d cuda implementation on the candle side. It's certainly on the very slow side, so it may well be a bottleneck for your use case, but hopefully I can make a cudnn version of it too.
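
If it helps for testing, a minimal usage sketch on the candle side is below, assuming the Tensor::conv_transpose1d method takes (kernel, padding, output_padding, stride, dilation); the exact parameter list (e.g. a later groups argument) may differ between candle versions, so treat this as an approximation.

```rust
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let dev = Device::cuda_if_available(0)?;
    // Input: (batch, c_in, length); kernel: (c_in, c_out, kernel_size).
    let x = Tensor::randn(0f32, 1f32, (1, 16, 128), &dev)?;
    let w = Tensor::randn(0f32, 1f32, (16, 32, 4), &dev)?;
    // Assumed argument order: padding, output_padding, stride, dilation.
    let y = x.conv_transpose1d(&w, 0, 0, 2, 1)?;
    println!("{:?}", y.shape());
    Ok(())
}
```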

jquesnelle commented 1 month ago

Would like to see this included. I think the torch-sys module, at the minimum, is missing the scaled_dot_product_attention binding because it has the optional scalar dropout_p that the bindings are filtering out, see https://github.com/LaurentMazare/tch-rs/blob/a4e9362e4acbbde54ab9503ab9e37a10835e7547/gen/gen.ml#L85

Could the gen.ml file be updated to include SDPA 🙏 ?

LaurentMazare commented 1 month ago

I would have thought that scaled_dot_product_attention is already generated, e.g. see here (fwiw, optional scalars should not result in the functions being discarded; rather, those values are just not settable in the Rust API when there is a default value).
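
For reference, calling the generated binding from tch looks roughly like the sketch below. This is an approximation: the exact trailing arguments (the optional scale, and in newer releases enable_gqa) depend on the tch-rs / libtorch version, so check the generated docs for the exact signature.

```rust
use tch::{Device, Kind, Tensor};

fn main() {
    // Shapes: (batch, heads, seq_len, head_dim). The flash backend only
    // kicks in for CUDA fp16/bf16 inputs; Float is used here just to keep
    // the sketch simple.
    let opts = (Kind::Float, Device::cuda_if_available());
    let q = Tensor::randn(&[2, 8, 128, 64], opts);
    let k = Tensor::randn(&[2, 8, 128, 64], opts);
    let v = Tensor::randn(&[2, 8, 128, 64], opts);

    // No attention mask, no dropout, causal masking, default scale. The
    // argument list follows the aten schema (attn_mask, dropout_p,
    // is_causal, scale) and may differ slightly between tch-rs versions.
    let out = q.scaled_dot_product_attention(&k, &v, None::<Tensor>, 0.0, true, None);
    println!("{:?}", out.size());
}
```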