LaurentMazare opened 1 year ago
Thanks for this amazing crate, it's been instrumental to candle. We've recently added a feature that uses the cudnn conv2d, which sped things up a lot compared to our handcrafted kernel, and we'd like to have cudnn support for more ops. Mostly interested in:

- conv2d backprop
- conv1d
- flash attention
You can already do conv2d backprop with the existing methods (see https://github.com/coreylowman/dfdx/blob/main/src/tensor_ops/conv2d/cudnn_kernel.rs#L91 - a rough sketch of the underlying calls is below).
Conv1d - I'm not sure this exists in cudnn? Not sure flash attention exists there either (or at least it didn't when I last checked cudnn).
But yes open to any contributions here!
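(For context, here is a rough sketch of what that backward path boils down to at the cuDNN level. It's written against cudarc's raw `sys` bindings, which mirror the cuDNN C API, rather than the safe wrappers dfdx actually uses; exact binding paths and enum spellings vary between cudarc versions, and status codes / error handling are omitted, so treat it as an illustration rather than a drop-in snippet.)

```rust
use cudarc::cudnn::sys::*;
use cudarc::driver::{CudaDevice, DevicePtr, DevicePtrMut};
use std::ffi::c_void;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let dev = CudaDevice::new(0)?;
    // Shapes: x [1, 3, 32, 32], w [8, 3, 3, 3] => y/dy [1, 8, 30, 30].
    let w = dev.alloc_zeros::<f32>(8 * 3 * 3 * 3)?;
    let dy = dev.alloc_zeros::<f32>(8 * 30 * 30)?;
    let mut dx = dev.alloc_zeros::<f32>(3 * 32 * 32)?;

    unsafe {
        let mut handle: cudnnHandle_t = std::ptr::null_mut();
        cudnnCreate(&mut handle);

        // Tensor descriptors for the grad-output (dy) and grad-input (dx).
        let mut dy_desc: cudnnTensorDescriptor_t = std::ptr::null_mut();
        cudnnCreateTensorDescriptor(&mut dy_desc);
        cudnnSetTensor4dDescriptor(
            dy_desc,
            cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
            cudnnDataType_t::CUDNN_DATA_FLOAT,
            1, 8, 30, 30,
        );
        let mut dx_desc: cudnnTensorDescriptor_t = std::ptr::null_mut();
        cudnnCreateTensorDescriptor(&mut dx_desc);
        cudnnSetTensor4dDescriptor(
            dx_desc,
            cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
            cudnnDataType_t::CUDNN_DATA_FLOAT,
            1, 3, 32, 32,
        );

        // Filter + convolution descriptors (3x3, no padding, stride 1).
        let mut w_desc: cudnnFilterDescriptor_t = std::ptr::null_mut();
        cudnnCreateFilterDescriptor(&mut w_desc);
        cudnnSetFilter4dDescriptor(
            w_desc,
            cudnnDataType_t::CUDNN_DATA_FLOAT,
            cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
            8, 3, 3, 3,
        );
        let mut conv: cudnnConvolutionDescriptor_t = std::ptr::null_mut();
        cudnnCreateConvolutionDescriptor(&mut conv);
        cudnnSetConvolution2dDescriptor(
            conv, 0, 0, 1, 1, 1, 1, // pad_h, pad_w, stride_h, stride_w, dil_h, dil_w
            cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
            cudnnDataType_t::CUDNN_DATA_FLOAT,
        );

        // Query and allocate the scratch space the chosen algorithm needs.
        let algo = cudnnConvolutionBwdDataAlgo_t::CUDNN_CONVOLUTION_BWD_DATA_ALGO_0;
        let mut ws_size = 0usize;
        cudnnGetConvolutionBackwardDataWorkspaceSize(
            handle, w_desc, dy_desc, conv, dx_desc, algo, &mut ws_size,
        );
        let ws = dev.alloc_zeros::<u8>(ws_size.max(1))?;

        // dx = conv_backward_data(w, dy): the gradient wrt the conv input.
        let (alpha, beta) = (1.0f32, 0.0f32);
        cudnnConvolutionBackwardData(
            handle,
            &alpha as *const f32 as *const c_void,
            w_desc, *w.device_ptr() as *const c_void,
            dy_desc, *dy.device_ptr() as *const c_void,
            conv, algo,
            *ws.device_ptr() as *mut c_void, ws_size,
            &beta as *const f32 as *const c_void,
            dx_desc, *dx.device_ptr_mut() as *mut c_void,
        );
        // cudnnConvolutionBackwardFilter is the mirror-image call for dw.
    }
    Ok(())
}
```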
Ah great that the conv2d backward step is already there, we'll add it to candle.
For conv1d, I think there is some support for Nd convolution, e.g. in the cudarc ffi, so hopefully having a safe api around this would enable 1d convolutions (a rough sketch of the descriptor setup is below).
For flash attention, I meant this fused flash attn fprop, though I've actually never used it.
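(In case it helps, here is one way the descriptor setup could look through those ffi bindings: phrase conv1d as a 2d conv with a dummy spatial dim of size 1. `cudnnSetTensorNdDescriptor`, `cudnnSetFilterNdDescriptor`, and `cudnnSetConvolutionNdDescriptor` are the actual cuDNN entry points; the helper function and its parameter names are made up for illustration, and cleanup/error handling is omitted.)

```rust
use cudarc::cudnn::sys::*;
use std::ptr;

/// Hypothetical helper: build cuDNN descriptors that express a conv1d over
/// [n, c_in, l] as a 2d conv over [n, c_in, 1, l] with a [c_out, c_in, 1, k]
/// filter. Returns (x_desc, w_desc, conv_desc).
unsafe fn conv1d_as_nd(
    n: i32, c_in: i32, c_out: i32, l: i32, k: i32, pad: i32, stride: i32,
) -> (cudnnTensorDescriptor_t, cudnnFilterDescriptor_t, cudnnConvolutionDescriptor_t) {
    let mut x_desc: cudnnTensorDescriptor_t = ptr::null_mut();
    cudnnCreateTensorDescriptor(&mut x_desc);
    let dims = [n, c_in, 1, l];
    // Contiguous strides for the 4d view of a packed [n, c_in, l] tensor.
    let strides = [c_in * l, l, l, 1];
    cudnnSetTensorNdDescriptor(
        x_desc,
        cudnnDataType_t::CUDNN_DATA_FLOAT,
        4, dims.as_ptr(), strides.as_ptr(),
    );

    let mut w_desc: cudnnFilterDescriptor_t = ptr::null_mut();
    cudnnCreateFilterDescriptor(&mut w_desc);
    let w_dims = [c_out, c_in, 1, k];
    cudnnSetFilterNdDescriptor(
        w_desc,
        cudnnDataType_t::CUDNN_DATA_FLOAT,
        cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
        4, w_dims.as_ptr(),
    );

    let mut conv: cudnnConvolutionDescriptor_t = ptr::null_mut();
    cudnnCreateConvolutionDescriptor(&mut conv);
    // Two "spatial" dims: the dummy one (never padded or strided) and length.
    let pads = [0, pad];
    let conv_strides = [1, stride];
    let dilations = [1, 1];
    cudnnSetConvolutionNdDescriptor(
        conv, 2, pads.as_ptr(), conv_strides.as_ptr(), dilations.as_ptr(),
        cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
        cudnnDataType_t::CUDNN_DATA_FLOAT,
    );

    (x_desc, w_desc, conv)
}
```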
Oh sweet I missed the Nd convolution, nice! Should be able to add that.
If I'm understanding the flash attn thing correctly, it seems like that's something detected at runtime if you're using the cudnn graph API?
Pooling would be great to have as well. It seems like it could have a similar api to conv (see the sketch after this comment).
@coreylowman would you take a PR for adding nd pooling?
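(For reference, the pooling descriptors in cuDNN do mirror the conv ones: `cudnnSetPoolingNdDescriptor` to describe the op, then `cudnnPoolingForward` / `cudnnPoolingBackward` to run it. A hedged sketch against the raw `sys` bindings follows; the helper function itself is hypothetical.)

```rust
use cudarc::cudnn::sys::*;
use std::ptr;

/// Hypothetical helper: a 2d max-pooling descriptor with a square window.
/// The window/padding/stride triple is the same shape as the conv
/// descriptors above, which is why a pooling API could look much like
/// the existing conv one.
unsafe fn max_pool2d_desc(win: i32, pad: i32, stride: i32) -> cudnnPoolingDescriptor_t {
    let mut pool: cudnnPoolingDescriptor_t = ptr::null_mut();
    cudnnCreatePoolingDescriptor(&mut pool);
    let window = [win, win];
    let padding = [pad, pad];
    let strides = [stride, stride];
    cudnnSetPoolingNdDescriptor(
        pool,
        cudnnPoolingMode_t::CUDNN_POOLING_MAX,
        cudnnNanPropagation_t::CUDNN_NOT_PROPAGATE_NAN,
        2, window.as_ptr(), padding.as_ptr(), strides.as_ptr(),
    );
    pool
}

// The forward pass is then a single call:
//   cudnnPoolingForward(handle, pool, &alpha, x_desc, x, &beta, y_desc, y)
// with cudnnPoolingBackward as its gradient counterpart.
```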
@kckeiks of course, any and all prs welcome