Open finnkauski opened 8 months ago
I don't think this is available yet, and for these very specific operations I'm a bit dubious we will have good support anytime soon. If you care about using flash-attention, I would recommend switching to candle instead, where this is already well supported and gives you access to flash-attention v2, which is significantly faster than the first version.
The way I've read the info about this, it seems like some form of switch that gets flicked (I don't know how it actually works under the hood, but the calls to enable flash attention feel like they toggle some global state) so that `Tensor::scaled_dot_product_attention` dispatches different kernels under the hood. I was hoping it would be something they would expose for us to toggle as well. This isn't a pressing issue right now, but I'm thinking ahead to what I might need, hence the question.
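For context on what those kernels compute: the flash, mem_efficient, and math backends are all implementations of the same function, softmax(QK^T / sqrt(d)) V; the switch only picks which fused kernel runs it. Below is a minimal pure-Python sketch of the unfused "math" path for a single head. The function name and list-of-row-vectors layout are my own for illustration, not any binding's API:

```python
import math

def scaled_dot_product_attention(q, k, v):
    """Reference "math"-backend computation of SDPA for one head.

    q, k, v are lists of row vectors (seq_len x d). The fused
    flash / mem_efficient kernels compute exactly this, just faster
    and without materialising the full score matrix.
    """
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for qi in q:
        # attention scores of this query against every key
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)  # subtract the max for numerical stability
        exps = [math.exp(s - m) for s in scores]
        z = sum(exps)
        weights = [e / z for e in exps]
        # output row: weights-weighted sum of the value rows
        out.append([sum(w * vj[c] for w, vj in zip(weights, v))
                    for c in range(len(v[0]))])
    return out
```

Each output row is a convex combination of the value rows, so its entries sum to 1 whenever the value rows are one-hot, which makes the sketch easy to sanity-check.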
If you think we can't expose this here then I think we can close this issue.
Side note:
I'm a big fan of candle, but I'm unfortunately working backwards on my project from candle to torch, as candle is missing a few operations such as CUDA `ConvTranspose1d` for my use-case, and my naive kernel for it has been embarrassingly slow! Even discounting it (i.e. subtracting the time spent in that kernel), I couldn't get performance to the level I could with torch for now. I've got the codebase in two branches now, `candle` and `tch`, for reference, so I might be able to contribute some insight into the comparison for my use case in the future.
What are the ops that you're missing on the candle side? If it's just a matter of Conv1d and ConvTranspose1d, we could have a look at hooking cudnn here, this should bring you the best available kernels for these ops and you would benefit from all the "modern" aspects of candle as an ML framework which tch doesn't have.
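For readers less familiar with the op being discussed, here is a minimal sketch of a valid (no padding) 1-D convolution for a single channel, in pure Python. The function name and signature are illustrative only, not cudnn's or candle's API; real kernels batch over channels and use far better memory access patterns, which is exactly what hooking cudnn would buy:

```python
def conv1d(x, w, stride=1):
    """Naive valid 1-D convolution (cross-correlation, single channel).

    Slides the kernel w across x with the given stride and takes a
    dot product at each position. Output length is
    (len(x) - len(w)) // stride + 1.
    """
    n = (len(x) - len(w)) // stride + 1
    return [sum(x[i * stride + j] * wj for j, wj in enumerate(w))
            for i in range(n)]
```

For example, `conv1d([1.0, 2.0, 3.0, 4.0], [1.0, 1.0])` sums each adjacent pair, giving `[3.0, 5.0, 7.0]`.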
It was indeed those two mainly for this part of the project. I'm sure other bits might hit some walls too.
Essentially here's the issue on the candle side, and it's `ConvTranspose1d` that I'm missing; obviously, the faster the regular `Conv1d` is, the better for my case.
I assumed you folks are swamped with TODOs, so I was just going to wait until this gets picked up before exploring candle further.
Ah, hooking conv1d to use cudnn is probably the easiest thing to do but won't help the convtranspose1d bit here. I'll have to dig a bit to see how to do the transposed versions in cudnn (I think that's what pytorch does).
Just to mention that I merged a naive conv-transpose1d cuda implementation on the candle side. It's certainly on the very slow side, so it may well be a bottleneck for your use case, but hopefully I can make a cudnn version of it too.
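For anyone following along, this is roughly what a naive conv-transpose1d kernel does for a single channel with no padding (names and layout here are my own illustration, not the merged candle implementation): each input element scatters a scaled copy of the kernel into the output, and overlapping contributions are summed.

```python
def conv_transpose1d(x, w, stride=1):
    """Naive 1-D transposed convolution (single channel, no padding).

    Each input element x[i] adds x[i] * w into the output starting at
    position i * stride; overlaps accumulate. Output length is
    (len(x) - 1) * stride + len(w).
    """
    out = [0.0] * ((len(x) - 1) * stride + len(w))
    for i, xi in enumerate(x):
        for j, wj in enumerate(w):
            out[i * stride + j] += xi * wj
    return out
```

The scatter-with-accumulation pattern is what makes a naive CUDA port slow: concurrent writes to overlapping output positions need atomics or a gather-style reformulation, which is the kind of thing cudnn's kernels already handle well.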
Would like to see this included -- I think the `torch-sys` module at the minimum is missing the `scaled_dot_product_attention` binding, because it has the optional scalar for `dropout_p` that the bindings are filtering out, see https://github.com/LaurentMazare/tch-rs/blob/a4e9362e4acbbde54ab9503ab9e37a10835e7547/gen/gen.ml#L85
Could the gen.ml file be updated to include SDPA 🙏 ?
I would have thought that `scaled_dot_product_attention` is already generated, e.g. see here (fwiw, optional scalars should not result in discarding the functions; rather, these values are not settable in the rust api if there is a default value).
Hi,
The 2.2 release of Torch added a great integration to speed up transformer attention-based architectures.
An example of how it is intended to be used in the Python API is this context manager.
Following down the rabbit hole we land here, and following this we land here.
There are the other `mem_efficient` and `math` optimisations which are closely linked and worth exposing.

NOTE: not sure if this is already in the bindings and I just can't find it?