microsoft / mttl

Building modular LMs with parameter-efficient fine-tuning.
MIT License

Sparse masks #108

Open oleksost opened 1 week ago


This PR implements sparse masks in three different ways:

It also implements mask updates: currently only a SNIP-based updater is implemented, with SPieL in the pipeline.
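For reference, SNIP scores each weight by the magnitude of weight times gradient and keeps the top fraction. A minimal numpy sketch of that scoring (the function name and signature are hypothetical illustrations, not the updater API in this PR):

```python
import numpy as np

def snip_mask(weights, grads, keep_ratio):
    """Keep the top keep_ratio fraction of weights by SNIP saliency |w * g|."""
    saliency = np.abs(weights * grads)
    k = max(1, int(keep_ratio * saliency.size))
    # Threshold at the k-th largest saliency; ties may keep slightly more.
    threshold = np.partition(saliency.ravel(), -k)[-k]
    return (saliency >= threshold).astype(weights.dtype)
```

With `keep_ratio=0.5` on a 2x2 weight matrix, the two highest-saliency entries survive and the rest are zeroed.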

TODOs:

ScatteredSparseLinearModule is currently the fastest, while the spops-based SparseLinearModule uses the least memory.
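The speed/memory trade-off between the two approaches reads roughly like this (a minimal numpy sketch with hypothetical function names, not the actual module implementations in the PR):

```python
import numpy as np

# Dense-mask variant: keeps a full weight matrix plus a 0/1 mask.
# The elementwise multiply is fast, but memory scales with the dense shape.
def masked_linear(x, weight, mask, bias=None):
    out = x @ (weight * mask).T
    return out if bias is None else out + bias

# Index-based variant: stores only surviving entries (rows, cols, vals),
# so memory is proportional to the number of nonzeros, at the cost of a
# scatter step before the matmul.
def scattered_linear(x, rows, cols, vals, out_features, bias=None):
    weight = np.zeros((out_features, x.shape[-1]))
    weight[rows, cols] = vals  # scatter nonzeros into a dense buffer
    out = x @ weight.T
    return out if bias is None else out + bias
```

Both variants produce the same output when given consistent weights; they differ only in how the sparse weights are stored.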


Profiled block-sparse multiplication with profile_block_sparcity.py: the stk and Triton block-sparse kernels outperform naive torch.matmul (benchmark plot attached).
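As a rough illustration of where the block-sparse speedup comes from (the actual stk/Triton kernels are fused GPU kernels; this numpy sketch is hypothetical and only shows the arithmetic that zero blocks let you skip):

```python
import numpy as np

def block_sparse_matmul(x, blocks, out_cols, block_size):
    """Multiply x of shape (m, k) by a block-sparse (k, out_cols) matrix
    given as a dict {(block_row, block_col): (block_size, block_size) array}."""
    out = np.zeros((x.shape[0], out_cols))
    for (br, bc), blk in blocks.items():
        rs, cs = br * block_size, bc * block_size
        # Only stored (nonzero) blocks contribute; zero blocks cost nothing.
        out[:, cs:cs + block_size] += x[:, rs:rs + block_size] @ blk
    return out
```

A dense matmul always pays for every block; at low block density the loop above touches only a small fraction of them, which is the effect the profiling script measures with real kernels.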