daphne-eu / daphne

DAPHNE: An Open and Extensible System Infrastructure for Integrated Data Analysis Pipelines
Apache License 2.0

Codegen for sparse-dense cross-entropy #919

Open AlexRTer opened 3 days ago

AlexRTer commented 3 days ago

This PR adds a codegen pass that lowers the expression sum(CSRMat * ln(denseLhs @ t(denseRhs))) to a fused operator that exploits sparsity. To run it, enable the codegen pipeline (--mlir-codegen) and make sure the left-hand side of the elementwise multiplication is a CSRMatrix (currently via --select-matrix-repr). Because the sum is computed directly, the pass avoids materializing the potentially large dense result of the matrix multiplication on the right-hand side, and it computes only the dot products that correspond to non-zero entries of the CSRMatrix. Thus, apart from the inputs, it needs only constant additional memory and reduces runtime significantly.
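
The fused evaluation described above can be illustrated with a minimal Python sketch. This is not the generated MLIR or the kernel from this PR; the function name, the argument names, and the plain-list CSR layout (values, col_idx, row_ptr) are assumptions made for illustration only. It also assumes every dot product it visits is strictly positive, so that the logarithm is defined.

```python
import math


def fused_sparse_cross_entropy(values, col_idx, row_ptr, dense_u, dense_v):
    """Return sum(S * ln(U @ t(V))) without materializing U @ t(V).

    S is given in CSR form (values, col_idx, row_ptr). U and V are dense
    matrices stored as lists of rows with the same number of columns.
    All names here are hypothetical and chosen for this sketch.
    """
    total = 0.0
    for row in range(len(row_ptr) - 1):
        u_row = dense_u[row]
        # Visit only the non-zero entries stored for this row of S.
        for k in range(row_ptr[row], row_ptr[row + 1]):
            v_row = dense_v[col_idx[k]]
            # Entry (row, col_idx[k]) of U @ t(V) is a single dot product.
            dot = sum(a * b for a, b in zip(u_row, v_row))
            total += values[k] * math.log(dot)
    return total


# S = [[0, 2, 0],
#      [0, 0, 0],
#      [1, 0, 0]] in CSR form
values = [2.0, 1.0]
col_idx = [1, 0]
row_ptr = [0, 1, 1, 2]

dense_u = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # 3 x 2
dense_v = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 x 2

result = fused_sparse_cross_entropy(values, col_idx, row_ptr, dense_u, dense_v)
print(result)
```

Only two dot products are evaluated here, one per stored entry of S, and the only state kept besides the inputs is the running total. For this input the result equals 2 * ln(2) + ln(5).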

An example script to test the results (--explain mlir_codegen is optional and prints the generated IR):

// RUN: ./bin/daphne --select-matrix-repr --mlir-codegen --explain mlir_codegen ./fileName.daphne

seed = 1;
sparsity = 1e-6;
sparseRows = 10_000;
sparseCols = 10_000;
hiddenDim = 20;

startGeneratingMatrices = now();
sparseLhs = rand(sparseRows, sparseCols, 0.0, 1.0, sparsity, seed);

DenseU = rand(sparseRows, hiddenDim, 0.0, 1.0, 1.0, seed + 1); // sparsity: 1.0
DenseV = rand(sparseCols, hiddenDim, 0.0, 1.0, 1.0, seed + 2); // sparsity: 1.0
endGeneratingMatrices = now();

startCalc = now();
res = sum(sparseLhs * ln(DenseU @ t(DenseV)));
endCalc = now();

print(res);
print("sparse dim: " + sparseRows + "x" + sparseCols + " (sparsity: " + sparsity + "), dense dim: " + sparseRows + "x" + hiddenDim + "@" + hiddenDim + "x" + sparseCols + "->" + sparseRows + "x" + sparseCols);
print("time to generate matrices: " + as.f64(endGeneratingMatrices - startGeneratingMatrices) * 1e-9 + ", comp. time: " + as.f64(endCalc - startCalc) * 1e-9);
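
For a rough sense of scale, the sizes used in the script above can be plugged into a back-of-the-envelope estimate. This is arithmetic on the script's parameters, not a measurement. It assumes that the sparsity argument of rand is the expected fraction of non-zero entries and that values are stored as 64-bit floats, and it counts only the multiplications of the matrix product.

```python
# Parameters copied from the example script.
sparse_rows = 10_000
sparse_cols = 10_000
hidden_dim = 20
sparsity = 1e-6

BYTES_PER_F64 = 8  # assumption: values are 64-bit floats

# Size of the dense intermediate DenseU @ t(DenseV) if it were materialized.
dense_entries = sparse_rows * sparse_cols
dense_result_bytes = dense_entries * BYTES_PER_F64

# Expected number of non-zeros in sparseLhs (assumption: sparsity is the
# expected fraction of non-zero entries).
expected_nnz = round(dense_entries * sparsity)

# Multiplications needed for the matrix product alone.
unfused_mults = dense_entries * hidden_dim
fused_mults = expected_nnz * hidden_dim

print("dense intermediate:", dense_result_bytes / 1e6, "MB")
print("expected non-zeros:", expected_nnz)
print("multiplications, unfused vs. fused:", unfused_mults, "vs.", fused_mults)
```

Under these assumptions the unfused expression would materialize a dense intermediate of about 800 MB and perform 2 * 10^9 multiplications, while the fused form touches roughly 100 entries and performs about 2,000.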

A more thorough description will be given once some tests have been added.