UDC-GAC / venom

A Vectorized N:M Format for Unleashing the Power of Sparse Tensor Cores

Batched Sparse Matrix - Batched Sparse Matrix multiplication #2

Open iminfine opened 8 months ago

iminfine commented 8 months ago

Thanks for your great work!

Is it possible to make Spatha support Batched Sparse Matrix - Batched Sparse Matrix multiplication?

I am currently developing a novel cross-attention mechanism for transformers. Consider q and k with shape (b, n, l, d), where b is the batch size, n is the number of heads, l is the sequence length, and d is the feature dimension.

The attention scores are computed by applying the F.relu activation to both q and k, reshaping, and multiplying. Specifically:

```python
F.relu(q).reshape(b, n * l, d) @ F.relu(k).reshape(b, n * l, d).transpose(-2, -1)
```

However, the reshape turns this into a multiplication of two large matrices, which incurs substantial GPU memory cost and computation time. This is a notable drawback in terms of computational efficiency and resource utilization.
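For concreteness, here is a minimal sketch of the dense formulation with illustrative dimensions (the concrete sizes below are only an example, not my actual model configuration):

```python
import torch
import torch.nn.functional as F

# Illustrative shapes only: b=8, n=12, l=1024, d=64.
b, n, l, d = 8, 12, 1024, 64
device = "cuda" if torch.cuda.is_available() else "cpu"
q = torch.randn(b, n, l, d, device=device)
k = torch.randn(b, n, l, d, device=device)

# Dense formulation from above: flatten heads into the sequence axis,
# then form the full (n*l) x (n*l) score matrix per batch element.
scores = F.relu(q).reshape(b, n * l, d) @ F.relu(k).reshape(b, n * l, d).transpose(-2, -1)

# scores has shape (b, n*l, n*l) = (8, 12288, 12288):
# 8 * 12288 * 12288 * 4 bytes ≈ 4.8 GB in fp32 for the score matrix alone.
print(scores.shape)
```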

Because ReLU zeroes out a large fraction of the entries in q and k, this attention computation can be viewed as a product of two sparse matrices. To address the memory and compute issue, I would like to leverage a Batched Sparse Matrix - Batched Sparse Matrix multiplication built on your state-of-the-art technique. Such a feature would speed up the computation and relieve the strain on GPU memory, making the implementation more efficient and scalable.
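As a rough illustration of the operation I have in mind, here is a naive reference in plain PyTorch that loops over the batch dimension with torch.sparse.mm (this assumes a PyTorch build where torch.sparse.mm accepts two sparse COO operands). It is only a functional baseline, not the Spatha kernel; the request is for a fused, batched kernel exploiting the vectorized N:M format instead of this per-batch loop:

```python
import torch
import torch.nn.functional as F

def batched_sparse_attention_scores(q, k):
    """Reference sketch: per-batch sparse x sparse score computation.

    q, k: dense tensors of shape (b, n, l, d). ReLU makes most entries
    exact zeros, so the COO conversion stores few values.
    """
    b, n, l, d = q.shape
    scores = []
    for i in range(b):
        q_s = F.relu(q[i]).reshape(n * l, d).to_sparse()
        k_s = F.relu(k[i]).reshape(n * l, d).transpose(0, 1).to_sparse()
        # Sparse x sparse matmul; result is a sparse (n*l, n*l) matrix.
        scores.append(torch.sparse.mm(q_s, k_s))
    return scores  # list of b sparse (n*l, n*l) matrices

# Small usage example with illustrative shapes.
q = torch.randn(2, 4, 256, 64)
k = torch.randn(2, 4, 256, 64)
out = batched_sparse_attention_scores(q, k)
print(out[0].shape)  # torch.Size([1024, 1024])
```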