coreylowman / cudarc

Safe rust wrapper around CUDA toolkit
Apache License 2.0

More cudnn ops #178

Open LaurentMazare opened 1 year ago

LaurentMazare commented 1 year ago

Thanks for this amazing crate, it's been instrumental to candle. We've recently added a feature to use the cudnn conv2d, which sped things up a lot compared to our handcrafted kernel, and we would like to have cudnn support for more ops. Mostly interested in:

- the backward pass for conv2d
- conv1d
- flash attention

coreylowman commented 1 year ago

You can already do conv2d backprop with the existing methods (See https://github.com/coreylowman/dfdx/blob/main/src/tensor_ops/conv2d/cudnn_kernel.rs#L91).
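For readers piecing this together: in cuDNN the conv2d backward pass is split across three calls, `cudnnConvolutionBackwardData` (gradient w.r.t. the input), `cudnnConvolutionBackwardFilter` (gradient w.r.t. the weights), and `cudnnConvolutionBackwardBias` (gradient w.r.t. the bias). A minimal shape sketch of what those calls produce for NCHW tensors; this is not cudarc's API and the helper name is made up for illustration:

```rust
/// Shapes of the three conv2d backprop outputs in cuDNN (NCHW layout).
/// - cudnnConvolutionBackwardData   -> dx, same shape as the input x
/// - cudnnConvolutionBackwardFilter -> dw, same shape as the filter w
/// - cudnnConvolutionBackwardBias   -> db, shape [1, c_out, 1, 1]
fn conv2d_backprop_shapes(
    x: [usize; 4], // input  [n, c_in, h, w]
    w: [usize; 4], // filter [c_out, c_in, kh, kw]
) -> ([usize; 4], [usize; 4], [usize; 4]) {
    (x, w, [1, w[0], 1, 1])
}

fn main() {
    // batch of 8 RGB 32x32 images, 16 filters of 3x3
    let (dx, dw, db) = conv2d_backprop_shapes([8, 3, 32, 32], [16, 3, 3, 3]);
    println!("dx={:?} dw={:?} db={:?}", dx, dw, db);
}
```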

Conv1d - I'm not sure this exists in cudnn? Not sure flash attention exists either (or at least it didn't the last time I checked cudnn).

But yes open to any contributions here!

LaurentMazare commented 1 year ago

Ah great that the conv2d backward step is already there, we'll add it to candle.

For conv1d, I think there is some support for Nd convolution, e.g. in the cudarc ffi, so hopefully having a safe api around that would enable 1d convolutions.
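For context, cuDNN's Nd convolution descriptors take per-dimension pad/stride/dilation arrays, and `cudnnGetConvolutionNdForwardOutputDim` documents the output-dimension formula a safe wrapper would need to reproduce. A small sketch of that formula (the function name is hypothetical, not part of cudarc):

```rust
/// Output spatial dims for an Nd convolution, following the formula cuDNN
/// documents for cudnnGetConvolutionNdForwardOutputDim:
///   out = 1 + (in + 2*pad - (dilation*(kernel-1) + 1)) / stride
fn conv_nd_out_dims(
    input: &[usize],    // spatial dims only, e.g. [l] for 1d, [h, w] for 2d
    kernel: &[usize],
    pad: &[usize],
    stride: &[usize],
    dilation: &[usize],
) -> Vec<usize> {
    (0..input.len())
        .map(|i| 1 + (input[i] + 2 * pad[i] - (dilation[i] * (kernel[i] - 1) + 1)) / stride[i])
        .collect()
}

fn main() {
    // 1d: length 100, kernel 5, pad 2, stride 1, dilation 1 -> [100]
    println!("{:?}", conv_nd_out_dims(&[100], &[5], &[2], &[1], &[1]));
    // 2d: 32x32 input, 3x3 kernel, pad 1, stride 2 -> [16, 16]
    println!("{:?}", conv_nd_out_dims(&[32, 32], &[3, 3], &[1, 1], &[2, 2], &[1, 1]));
}
```

The same code path would cover 1d, 2d, and 3d, which is presumably why a safe Nd api would be enough to unlock conv1d.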

For flash attention, I meant this fused flash attn fprop, though I've actually never used it.

coreylowman commented 1 year ago

Oh sweet I missed the Nd convolution, nice! Should be able to add that.

If I'm understanding the flash attn thing, it seems like that is something detected at runtime if you're using the cudnn graph API?

kckeiks commented 4 months ago

Pooling would be great to have as well. It seems like it could have a similar api to conv.
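The similarity to conv is visible in the shape math: cuDNN's Nd pooling descriptor (`cudnnSetPoolingNdDescriptor` / `cudnnGetPoolingNdForwardOutputDim`) is essentially the conv case with the window in place of the kernel and no dilation. A sketch of the output-dimension formula, with a made-up function name:

```rust
/// Output spatial dims for Nd pooling, following the formula cuDNN documents
/// for its pooling output-dim queries:
///   out = 1 + (in + 2*pad - window) / stride
fn pool_nd_out_dims(
    input: &[usize],  // spatial dims only
    window: &[usize],
    pad: &[usize],
    stride: &[usize],
) -> Vec<usize> {
    (0..input.len())
        .map(|i| 1 + (input[i] + 2 * pad[i] - window[i]) / stride[i])
        .collect()
}

fn main() {
    // classic 2x2 max-pool with stride 2 on a 32x32 feature map -> [16, 16]
    println!("{:?}", pool_nd_out_dims(&[32, 32], &[2, 2], &[0, 0], &[2, 2]));
}
```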

kckeiks commented 4 months ago

@coreylowman would you take a PR for adding nd pooling?

coreylowman commented 4 months ago

@kckeiks of course, any and all prs welcome