Implements sparse masks in 3 different ways:

- **masked linear**: compute- and memory-inefficient, as it has to update sparse weights that are kept in dense format.
- **scattered sparse**: more memory- and compute-efficient; keeps only the sparse weights and uses `torch.scatter_add` to update them (see the sketch after this list).
- **sparse linear**: uses spops kernels to make things even faster. This also supports structured operations, so block sparsity should be fast out of the box (not 100% sure, will double-check). This does not work on some GPUs (spops is compiled for sm_80 architectures such as the A100).
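
A minimal sketch of the scattered-sparse idea (the class name and structure here are illustrative, not the repo's actual API): only the nonzero values and their flat indices are stored as trainable state, and `torch.scatter_add` rebuilds the dense delta inside `forward`, so gradients and optimizer state stay sparse-sized.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ScatteredSparseLinearSketch(nn.Module):
    """Hypothetical sketch: dense base weight is frozen, only masked entries train."""

    def __init__(self, weight: torch.Tensor, mask: torch.Tensor):
        super().__init__()
        self.register_buffer("base_weight", weight)
        # flat indices of the masked (trainable) positions
        self.register_buffer("idx", mask.flatten().nonzero().squeeze(1))
        # only these values receive gradients and optimizer state
        self.sparse_values = nn.Parameter(torch.zeros_like(self.idx, dtype=weight.dtype))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        delta = torch.zeros_like(self.base_weight).flatten()
        # scatter the trainable values into their dense positions
        delta = delta.scatter_add(0, self.idx, self.sparse_values)
        return F.linear(x, self.base_weight + delta.view_as(self.base_weight))
```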
Also implements mask updates. Currently, only the SNIP updater is implemented; SPieL is in the pipeline.
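
For reference, a hedged sketch of the SNIP saliency criterion (the function name and signature are made up for illustration; the repo's updater may differ): each weight is scored by |w · ∂L/∂w|, and the mask keeps the top-k most salient entries.

```python
import torch

def snip_mask(weight: torch.Tensor, grad: torch.Tensor, density: float) -> torch.Tensor:
    """Hypothetical helper: keep the top `density` fraction of weights by |w * g|."""
    scores = (weight * grad).abs().flatten()
    k = max(1, int(density * scores.numel()))
    keep = torch.topk(scores, k).indices
    mask = torch.zeros_like(scores, dtype=torch.bool)
    mask[keep] = True
    return mask.view_as(weight)
```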
TODOs:

- [x] Implement tests
- [ ] When updating the mask periodically with SNIP, should we accumulate all weight updates for all masks used so far on CPU (as masked linear does by default)?
- [x] Do some profiling
- [x] Make sure the block structure is leveraged
- [ ] SPieL mask updater
Currently, the manual profiler gives the following (for GPT-Neo 125M with 0.5% sparsity):
- `SparseLinearModule` (spops) with regular sparsity - Runtime: 0.066590s, Allocated Memory: 4552.14MB, Reserved Memory: 4645.19MB
- `SparseLinearModule` (spops) with block sparsity - Runtime: 0.067642s, Allocated Memory: 4553.58MB, Reserved Memory: 4645.19MB
- `ScatteredSparseLinearModule` with block sparsity - Runtime: 0.052826s, Allocated Memory: 4734.14MB, Reserved Memory: 4817.16MB
- `ScatteredSparseLinearModule` with regular sparsity - Runtime: 0.052953s, Allocated Memory: 4734.66MB, Reserved Memory: 4817.16MB
- `MaskedLinear` with regular sparsity - Runtime: 0.056629s, Allocated Memory: 4892.71MB, Reserved Memory: 4970.25MB
- `MaskedLinear` with block sparsity - Runtime: 0.055440s, Allocated Memory: 4889.36MB, Reserved Memory: 4978.64MB

So `ScatteredSparseLinearModule` is currently the fastest, but the spops `SparseLinearModule` uses the least memory.

Block-sparse multiplication was also profiled with `profile_block_sparcity.py`: stk and triton block-sparse matmuls outperform a naive `torch.matmul`.
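
The numbers above come from a manual profiler; a minimal sketch of how such runtime and memory figures can be measured (this helper is an assumption, not the repo's actual profiling code):

```python
import time
import torch

def profile_module(module: torch.nn.Module, x: torch.Tensor,
                   warmup: int = 5, iters: int = 20) -> None:
    """Hypothetical helper: average forward runtime plus peak CUDA memory."""
    for _ in range(warmup):
        module(x)
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    start = time.perf_counter()
    for _ in range(iters):
        module(x)
    torch.cuda.synchronize()
    runtime = (time.perf_counter() - start) / iters
    alloc = torch.cuda.max_memory_allocated() / 2**20
    reserved = torch.cuda.max_memory_reserved() / 2**20
    print(f"Runtime: {runtime:.6f}s, Allocated Memory: {alloc:.2f}MB, "
          f"Reserved Memory: {reserved:.2f}MB")
```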