Dao-AILab / causal-conv1d

Causal depthwise conv1d in CUDA, with a PyTorch interface
BSD 3-Clause "New" or "Revised" License

Questions Regarding Kernel Size and Data Type Support #8

Closed · radarFudan closed this issue 11 months ago

radarFudan commented 11 months ago

Hello! Thank you so much for this helpful repo for Mamba.

Questions:

  1. Kernel Size Limitation: I noticed that the kernel size is set to a maximum of 4. Could you provide insight into why this limitation exists? Is it due to storage efficiency, or are there compatibility issues with larger kernel sizes?
  2. Data Type Support:
     a. Currently, the code supports float32, bfloat16, and float16. Is there a plan to include support for float64 in the future?
     b. Given that Mamba can work with complex weights, I was wondering if there is potential for the causal convolution to support complex kernels as well.
tridao commented 11 months ago

I've updated the README.

  1. Kernel sizes 2, 3, and 4 are supported. Larger kernel sizes are just more annoying to implement: each thread reads in 4 or 8 input elements, so capping the kernel size at 4 keeps the implementation simple.
  2. Dtypes fp32, fp16, and bf16 are supported. There's no plan to support float64 or complex types. (See the sketch after this list for a call that stays within these limits.)
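
For concreteness, here is a minimal sketch of a call that stays within that supported envelope, assuming the package's `causal_conv1d_fn` entry point (the exact signature may differ across versions):

```python
import torch
from causal_conv1d import causal_conv1d_fn  # CUDA fast path

batch, dim, seqlen, width = 2, 64, 1024, 4  # width must be 2, 3, or 4
# Dtype must be fp32, fp16, or bf16, and tensors must live on the GPU.
x = torch.randn(batch, dim, seqlen, device="cuda", dtype=torch.bfloat16)
weight = torch.randn(dim, width, device="cuda", dtype=torch.bfloat16)
bias = torch.randn(dim, device="cuda", dtype=torch.bfloat16)
out = causal_conv1d_fn(x, weight, bias)  # (batch, dim, seqlen)
```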

The depthwise conv1d implemented here is equivalent to one line of PyTorch (see README), so you can always fall back to plain PyTorch for the cases not supported here. The goal of this repo is just to make it fast for some cases used in Mamba (and a few other model architectures).
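
For the unsupported cases (e.g. float64 or width > 4), a minimal sketch of that PyTorch fallback, assuming the (batch, dim, seqlen) input layout and (dim, width) weight layout used by this repo:

```python
import torch
import torch.nn.functional as F

def causal_conv1d_ref(x, weight, bias=None):
    # Depthwise (groups=dim) conv, left-padded by width - 1 so the output
    # at time t depends only on inputs at times <= t (causality).
    dim, width = weight.shape
    out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
    return out[..., : x.shape[-1]]

# Handles dtypes and widths the CUDA kernel doesn't cover:
x = torch.randn(2, 16, 128, dtype=torch.float64)
weight = torch.randn(16, 7, dtype=torch.float64)
y = causal_conv1d_ref(x, weight)
assert y.shape == x.shape
```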