pytorch / ao

PyTorch native quantization and sparsity for training and inference
BSD 3-Clause "New" or "Revised" License

[RFC] Plans for sparsity #143

Open jcaip opened 4 months ago

jcaip commented 4 months ago

Summary

Sparsity, like quantization, offers increased model performance at the expense of some model quality. However, it is not as widely used / researched as a technique, despite offering similar performance benefits. With the recent explosion in model sizes in GenAI, and with quantization pushing 1-bit limits, there has been renewed interest in sparsity, specifically for GPU backend sparsity patterns.

The parallel nature of GPU backends makes accelerating unstructured sparsity difficult. However, there are specific sparsity patterns (block-wise, semi-structured) that are more amenable to acceleration on GPUs. Over the last year, we’ve integrated these fast sparse kernels into PyTorch Core, so that all users can see up to a [...] speedup with just a few lines of code:

Our goal for torchao.sparsity is to drive research / adoption of these GPU sparsity patterns.
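To make the semi-structured pattern concrete: 2:4 sparsity keeps at most 2 nonzero values in every contiguous group of 4 weights, a regular structure that the GPU sparse kernels can exploit. The sketch below is a framework-free illustration on plain Python lists. `prune_2_4` is a hypothetical helper, not a PyTorch or torchao API. In recent PyTorch releases, a weight that already satisfies this pattern can be converted for the GPU kernels with `torch.sparse.to_sparse_semi_structured`.

```python
def prune_2_4(row):
    """Hypothetical helper: zero the 2 smallest-|w| entries in every group of 4."""
    if len(row) % 4 != 0:
        raise ValueError("row length must be a multiple of 4")
    pruned = []
    for start in range(0, len(row), 4):
        group = row[start:start + 4]
        # Positions of the two largest-magnitude values in this group.
        keep = sorted(range(4), key=lambda i: abs(group[i]), reverse=True)[:2]
        pruned.extend(v if i in keep else 0.0 for i, v in enumerate(group))
    return pruned


weights = [0.9, -0.1, 0.05, -0.7, 0.2, 0.3, -0.4, 0.01]
print(prune_2_4(weights))  # [0.9, 0.0, 0.0, -0.7, 0.0, 0.3, -0.4, 0.0]
```

Every group of 4 in the result has exactly two zeros, which is the invariant the 2:4 kernels rely on.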

We feel that the main problem current researchers / users face is fragmentation. Researchers rightfully aim to show end-to-end results, but this means a lot of time is spent figuring out how to integrate with PyTorch and answering implementation questions such as: When should I mask? When and how should I store the compressed representation? Do I want in-place or out-of-place mask updates? How can I call a sparse matmul instead of a dense one?

We hope to change that by providing tutorials and APIs for both sparse kernels (tensor subclassing) and pruning algorithms (torch.ao.pruning.Sparsifier) that users can extend. We believe the problems above can be solved once, in torchao, letting researchers focus on pushing sparse kernel performance or developing more accurate pruning algorithms.

We're also hoping to create a new extension point by releasing the workflows we have designed with xFormers that enable accelerated sparse training, not just sparse inference. As such, we plan on launching torchao.sparsity with the following features in v0.2:

However, we’d like feedback from the community to set the longer-term vision for sparsity. Also, feel free to chime in with any other thoughts you want to share!


Pruning Algorithms

We plan to host a set of OSS pruning algorithms in torchao. These pruning algorithms should extend the torch.ao.pruning.BaseSparsifier class, like WandaSparsifier. We welcome community contributions for pruning algorithms, provided they extend the BaseSparsifier.
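The extension pattern described here can be sketched without PyTorch: a base class owns the masks and the step loop, and a subclass only implements how one layer's mask is recomputed, mirroring the role `update_mask()` plays for `torch.ao.pruning.BaseSparsifier`. `MiniSparsifier` and `WandaStyleSparsifier` are hypothetical names that operate on nested lists rather than `nn.Module` parametrizations. The scoring follows the Wanda criterion: |weight| times the L2 norm of the corresponding input feature, compared within each output row.

```python
from abc import ABC, abstractmethod


class MiniSparsifier(ABC):
    """Hypothetical stand-in for the BaseSparsifier pattern: the base class
    owns masks and the step loop; subclasses implement update_mask()."""

    def __init__(self, sparsity_level):
        self.sparsity_level = sparsity_level
        self.layers = {}
        self.masks = {}

    def prepare(self, layers):
        # layers: layer name -> weight matrix as a list of rows (out x in)
        self.layers = layers
        self.masks = {n: [[1] * len(r) for r in w] for n, w in layers.items()}

    def step(self):
        for name in self.layers:
            self.masks[name] = self.update_mask(name)

    @abstractmethod
    def update_mask(self, name):
        """Return the new 0/1 mask for layer `name`."""


class WandaStyleSparsifier(MiniSparsifier):
    """Scores w_ij by |w_ij| * ||x_j||_2 and prunes the lowest-scoring
    fraction of each output row (Wanda-style criterion)."""

    def __init__(self, sparsity_level, input_norms):
        super().__init__(sparsity_level)
        self.input_norms = input_norms  # layer name -> per-input-feature norms

    def update_mask(self, name):
        norms = self.input_norms[name]
        mask = []
        for row in self.layers[name]:
            scores = [abs(w) * n for w, n in zip(row, norms)]
            n_prune = int(len(row) * self.sparsity_level)
            pruned = set(sorted(range(len(row)), key=scores.__getitem__)[:n_prune])
            mask.append([0 if j in pruned else 1 for j in range(len(row))])
        return mask


sparsifier = WandaStyleSparsifier(0.5, input_norms={"fc": [10.0, 0.1, 1.0, 1.0]})
sparsifier.prepare({"fc": [[0.2, 3.0, 0.5, -0.4]]})
sparsifier.step()
print(sparsifier.masks["fc"])  # [[1, 0, 1, 0]]
```

Plain magnitude pruning would have removed column 0 here (|0.2| is small), while its large input norm of 10.0 keeps it under the Wanda-style score.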

Open Questions:

Recipes / Benchmarks

We have often found pruning to be very model specific, with little generalization across domains. As such, we hope to land sparse training recipes for specific models / datasets that show how different pruning algorithms can be used. We are specifically interested in recipes that compose with quantization.

Additionally, we hope that these benchmark numbers can help first-time users of sparsity better understand the tradeoffs involved and encourage researchers to contribute SOTA pruning algorithms.

Open Questions:

Accelerated Sparse Training

While much work has been done on sparsity for inference, sparsity for training has remained much more challenging. Thanks to the work done by xFormers, we’ve upstreamed fast runtime semi-structured sparsification kernels into PyTorch Core, which allow prune -> compress -> sparse_mm to run faster than a dense matmul. We also aim to release an example of accelerated sparse training for the OSS community to extend.
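The prune -> compress -> sparse_mm data flow can be sketched for a matrix-vector product in pure Python. Each group of 4 keeps its 2 largest-magnitude values plus their column positions; the positions stand in for the compact metadata that real 2:4 formats store. The function names are hypothetical, and the sketch illustrates only the data flow, not the speed of the upstreamed kernels.

```python
def compress_2_4(row):
    """Prune + compress one row: keep the 2 largest-|w| entries of each group
    of 4 and return (values, column_indices) for the survivors."""
    if len(row) % 4 != 0:
        raise ValueError("row length must be a multiple of 4")
    values, indices = [], []
    for start in range(0, len(row), 4):
        group = row[start:start + 4]
        top2 = sorted(range(4), key=lambda i: abs(group[i]), reverse=True)[:2]
        for i in sorted(top2):  # keep original column order
            values.append(group[i])
            indices.append(start + i)
    return values, indices


def sparse_mv(compressed_rows, x):
    """y = W @ x touching only stored values: half the multiply-adds of dense."""
    return [sum(v * x[j] for v, j in zip(vals, idx)) for vals, idx in compressed_rows]


def dense_mv(rows, x):
    return [sum(w * xj for w, xj in zip(row, x)) for row in rows]


W = [[0.9, -0.1, 0.05, -0.7], [0.2, 0.3, -0.4, 0.01]]
x = [1.0, 2.0, 3.0, 4.0]

compressed = [compress_2_4(row) for row in W]
print(compressed)  # [([0.9, -0.7], [0, 3]), ([0.3, -0.4], [1, 2])]

# Reference: the same pruning applied densely, with zeros kept in place.
masked = [[w if j in idx else 0.0 for j, w in enumerate(row)]
          for row, (_, idx) in zip(W, compressed)]
print(sparse_mv(compressed, x), dense_mv(masked, x))
```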

Performant Sparse Kernels

There are additional sparsity patterns that may be supported on GPUs, which would require additional fast sparse kernels. We hope that torchao can be a staging ground for these kernels. We plan to upstream these kernels to Core as we see fit, depending on adoption.

Some initial options are:

Open Questions:

cc @supriyar @cpuhrsch @msaroufim @pytorch-labs/team-superblock @danthe3rd @mklasby @ngc92 @hgyhungry

ngc92 commented 4 months ago

Maybe the easiest, and also more general, way to get some feedback on load-sparse, compute-dense is to not change kernels at all, but to use losslessly compressed memory (on GPUs that support it). If those results look promising, then we can actually tackle custom kernels.

The advantage is that, I think, compressed memory should just work automatically with all existing kernels, because it is handled at the hardware/driver level. The disadvantage, of course, is that it is a very recent hardware feature.

mklasby commented 4 months ago

Are there changes that need to be made to Sparsifier?

Currently, the sparsifier has a somewhat convoluted and brittle load_state function, as it stores state for the parametrized masks. IMO this state should stay with the parametrization, which I think will simplify loading model checkpoints. We can always call squash_mask() if we need to convert to a dense weight matrix with zeros instead of dense+mask. We've talked about this previously, but I suspect it's good to document here.
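One way to picture keeping mask state with the parametrization: a dense+mask wrapper whose own state dict carries the mask, so a checkpoint round-trips without any sparsifier bookkeeping, plus a squash step that folds the mask into a plain dense weight, as squash_mask() does. `MaskedWeight` and its methods are hypothetical and use plain lists; they are not the FakeSparsity implementation.

```python
class MaskedWeight:
    """Hypothetical dense+mask ("fake sparsity") weight that owns its mask."""

    def __init__(self, weight, mask=None):
        self.weight = list(weight)
        self.mask = list(mask) if mask is not None else [1] * len(self.weight)
        if len(self.mask) != len(self.weight):
            raise ValueError("mask and weight must have the same length")

    def effective_weight(self):
        # What the layer computes with: masked positions read as zero.
        return [w if m else 0.0 for w, m in zip(self.weight, self.mask)]

    def state_dict(self):
        # The mask is saved alongside the weight, not in the sparsifier's state.
        return {"weight": list(self.weight), "mask": list(self.mask)}

    @classmethod
    def from_state_dict(cls, state):
        return cls(state["weight"], state["mask"])

    def squash(self):
        # Like squash_mask(): return a plain dense weight with zeros baked in.
        return self.effective_weight()


layer = MaskedWeight([0.5, -1.0, 2.0, 0.1], mask=[1, 0, 1, 0])
restored = MaskedWeight.from_state_dict(layer.state_dict())  # no sparsifier needed
print(restored.squash())  # [0.5, 0.0, 2.0, 0.0]
```

Because the dense values survive in `weight`, a later mask update can still reactivate a pruned position until `squash()` is called.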

Pruning uses parametrizations (FakeSparsity), should we switch to MaskedTensor?

AFAIK, MaskedTensor doesn't support some important tensor operations, such as matmuls. I think MaskedTensor may be the preferred route in the long term, since its unspecified-element semantics are more expressive than FakeSparsity. However, some DST algorithms (e.g., RigL) require access to dense weights/grads, and for these, parametrization is currently the more natural solution.
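RigL illustrates why dense information matters: its update drops the active weights with the smallest magnitude and regrows the inactive connections with the largest gradient magnitude, so it needs gradients at positions that are currently masked out. The sketch below is a simplified single-layer version on flat lists; `rigl_update` is a hypothetical name, and real RigL also anneals the update fraction over the course of training.

```python
def rigl_update(weights, grads, mask, n_update):
    """Simplified single-layer RigL mask update on flat lists.

    Drop: deactivate the n_update active weights with the smallest |w|.
    Grow: among positions inactive after the drop, activate the n_update
    with the largest |grad|. This reads gradients at masked-out positions,
    i.e. it needs dense gradients that a compressed format would not keep.
    """
    active = [i for i, m in enumerate(mask) if m]
    drop = sorted(active, key=lambda i: abs(weights[i]))[:n_update]
    new_mask = list(mask)
    for i in drop:
        new_mask[i] = 0
    candidates = [i for i, m in enumerate(new_mask) if not m]
    grow = sorted(candidates, key=lambda i: abs(grads[i]), reverse=True)[:n_update]
    new_weights = list(weights)
    for i in grow:
        new_mask[i] = 1
        new_weights[i] = 0.0  # RigL initializes regrown connections to zero
    return new_weights, new_mask


weights = [0.9, 0.05, 0.0, 0.0, -0.6, 0.0]
mask = [1, 1, 0, 0, 1, 0]
grads = [0.1, 0.2, 0.05, 0.8, 0.3, 0.01]  # dense: defined at every position
new_weights, new_mask = rigl_update(weights, grads, mask, n_update=1)
print(new_mask)  # [1, 0, 0, 1, 1, 0]
```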

Global mask updates are difficult to support, is this something researchers care about?

If we create an abstract method _step() (or similar) in BaseSparsifier instead of update_mask(), we can easily extend the class to support either layerwise or global sparsity distributions. In my mind, update_mask() is currently too opinionated; for a layerwise algorithm, users can simply call update_mask() from within _step().
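The difference between layerwise and global sparsity distributions is easy to see with magnitude pruning. Layerwise pruning hits the target sparsity in every layer, while a single global threshold can leave some layers dense and prune others entirely, which is why it needs a hook that sees all layers at once. `magnitude_masks` is a hypothetical helper on plain lists, not part of BaseSparsifier.

```python
def magnitude_masks(layers, sparsity, global_scope):
    """Return {name: 0/1 mask} pruning the smallest-|w| fraction of weights.

    global_scope=False: each layer is pruned to `sparsity` independently,
    which a per-layer update_mask() expresses naturally.
    global_scope=True: one threshold across all layers, so per-layer
    sparsity can differ; this needs a step that sees every layer at once.
    """
    if global_scope:
        everything = [(abs(w), name, i)
                      for name, ws in layers.items() for i, w in enumerate(ws)]
        n_prune = int(len(everything) * sparsity)
        pruned = {(name, i) for _, name, i in sorted(everything)[:n_prune]}
        return {name: [0 if (name, i) in pruned else 1 for i in range(len(ws))]
                for name, ws in layers.items()}
    masks = {}
    for name, ws in layers.items():
        n_prune = int(len(ws) * sparsity)
        pruned = set(sorted(range(len(ws)), key=lambda i: abs(ws[i]))[:n_prune])
        masks[name] = [0 if i in pruned else 1 for i in range(len(ws))]
    return masks


layers = {"a": [0.01, 0.02, 0.03, 0.04], "b": [1.0, 2.0, 3.0, 4.0]}
print(magnitude_masks(layers, 0.5, global_scope=False))  # {'a': [0, 0, 1, 1], 'b': [0, 0, 1, 1]}
print(magnitude_masks(layers, 0.5, global_scope=True))   # {'a': [0, 0, 0, 0], 'b': [1, 1, 1, 1]}
```

Both settings remove 4 of the 8 weights; only the distribution across layers changes.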

What benchmarks / datasets are interesting to the community? It looks like ViT on ImageNet is the most common architecture.

As a benchmark, ResNet-50/ImageNet is widely cited in the literature and would be nice to support out of the box as well. I understand we do not currently have acceleration support for conv layers, but at minimum we could show that the sparsifier algorithms converge to the same performance as the values tabulated in the literature.

Does the community feel there is value in having a suite of sparse microbenchmarks for the different sparsity patterns or just E2E results?

Yes, this is valuable as we often trade off sparse structure for generalization performance. Defining these microbenchmarks would help users find the generalization / latency trade off that is required for their use case.
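Latency microbenchmarks need real kernels, but the structure side of the trade-off can be probed cheaply. As a rough, hypothetical proxy for how much a pattern constrains pruning, the sketch compares the fraction of total weight magnitude each pattern keeps at the same 50% sparsity on random Gaussian weights. Unstructured top-k is the upper bound, so more constrained patterns can only keep as much or less. All helper names are illustrative.

```python
import random


def retained_magnitude(ws, mask):
    """Fraction of total |w| that survives the mask (1.0 = nothing lost)."""
    return sum(abs(w) for w, m in zip(ws, mask) if m) / sum(abs(w) for w in ws)


def unstructured_mask(ws, sparsity):
    n_prune = int(len(ws) * sparsity)
    pruned = set(sorted(range(len(ws)), key=lambda i: abs(ws[i]))[:n_prune])
    return [0 if i in pruned else 1 for i in range(len(ws))]


def n_m_mask(ws, n, m):
    """Keep the n largest-|w| entries in each contiguous group of m."""
    mask = []
    for start in range(0, len(ws), m):
        group = ws[start:start + m]
        keep = set(sorted(range(len(group)), key=lambda i: abs(group[i]), reverse=True)[:n])
        mask.extend(1 if i in keep else 0 for i in range(len(group)))
    return mask


def block_mask(ws, block, sparsity):
    """Prune the contiguous blocks with the smallest L1 norm."""
    starts = list(range(0, len(ws), block))
    n_prune = int(len(starts) * sparsity)
    by_norm = sorted(starts, key=lambda s: sum(abs(w) for w in ws[s:s + block]))
    pruned = set(by_norm[:n_prune])
    return [0 if (i - i % block) in pruned else 1 for i in range(len(ws))]


rng = random.Random(0)
ws = [rng.gauss(0.0, 1.0) for _ in range(1024)]
for name, mask in [("unstructured", unstructured_mask(ws, 0.5)),
                   ("2:4", n_m_mask(ws, 2, 4)),
                   ("block-16", block_mask(ws, 16, 0.5))]:
    print(f"{name:>12}: {retained_magnitude(ws, mask):.3f} of |w| kept at 50% sparsity")
```

Retained magnitude is only a stand-in for accuracy; the real trade-off still has to be measured end to end.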

The masking mechanism is different from the torch.ao.pruning masking mechanism (FakeSparsity), should we unify the two?

If I understand correctly, the prune -> compress -> sparse_mm framework would work well for sparse training algorithms that do not depend on dense weight or gradient information. However, if dense weight/grad info is required, the compressed sparse tensor representation would be missing it. A potential workaround is to add a feature for quickly converting between the compressed sparse tensor and the dense+mask (fake sparsity) representation. If this conversion can be done performantly, we can amortize its cost over many mini-batch iterations during training, i.e., prune -> compress -> train for n mini-batches -> decompress to dense+mask and take a single mini-batch step as required -> update mask -> compress -> train, and so on.

We also have different pruning algorithms in pytorch/pytorch under torch.nn.utils.prune. Ideally, we should strive to integrate these various pruning approaches where possible. I do think FakeSparsity (dense+mask) is still worth supporting as a feature until we have a better workaround for DST algorithms that need dense grad/weight info, even if only intermittently.
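The amortized schedule proposed above can be written as a simple event generator: train on the compressed form for most mini-batches, and every `dense_every` steps round-trip through dense+mask for one dense step and a mask update. `sparse_training_schedule` and `dense_every` are hypothetical names; the point is only that the decompress/compress cost is paid once per `dense_every` steps rather than every step.

```python
def sparse_training_schedule(total_steps, dense_every):
    """Yield (step, actions) for the amortized compressed / dense+mask scheme.

    Hypothetical helper: most steps train on the compressed representation;
    every `dense_every` steps we pay for one decompress/compress round trip
    to take a dense+mask step and update the mask.
    """
    yield 0, ["prune", "compress"]
    for step in range(1, total_steps + 1):
        if step % dense_every == 0:
            yield step, ["decompress", "dense_step", "update_mask", "compress"]
        else:
            yield step, ["sparse_step"]


actions = [a for _, acts in sparse_training_schedule(1000, dense_every=100) for a in acts]
print(actions.count("sparse_step"), actions.count("decompress"))  # 990 10
```

Raising `dense_every` lowers the conversion overhead but makes mask updates, and the dense information they use, less frequent.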

M:N / Sparse fan-in kernel - These are similar to 2:4 sparse kernels, but generalized to N:M. While they do not offer the same hardware acceleration as 2:4 sparsity, you can still get memory speedups by sending a compressed representation.

I may be a bit biased here, but definitely +1 from me. Perhaps we should add sparse kernel support once a training recipe is PR'd that demonstrates high generalization performance for a given sparsity pattern?
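A back-of-the-envelope estimate shows where the memory savings of a general N:M format come from. Assume (hypothetically) that each kept value is stored at full precision alongside a ceil(log2 M)-bit position index, so one group of M weights shrinks from M values to N values plus N indices. The helper names are illustrative and do not describe any specific kernel's layout.

```python
import math


def nm_bytes_per_group(n, m, value_bits=16):
    """Bytes for one group of m weights in a hypothetical N:M format:
    n kept values, each with a ceil(log2(m))-bit position index."""
    index_bits = math.ceil(math.log2(m))
    return n * (value_bits + index_bits) / 8


def nm_compression_ratio(n, m, value_bits=16):
    """Compressed bytes as a fraction of the dense bytes for the same group."""
    dense_bytes = m * value_bits / 8
    return nm_bytes_per_group(n, m, value_bits) / dense_bytes


for n, m in [(2, 4), (1, 4), (2, 8), (4, 16)]:
    print(f"{n}:{m} -> {nm_compression_ratio(n, m):.4f} of dense fp16 bytes")
```

Under these assumptions, 2:4 at fp16 needs 56.25% of the dense bytes, and sparser patterns such as 1:4 or 2:8 need under 30%, since the index overhead grows only logarithmically in M.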

What about load as sparse, compute as dense kernels?

+1 here for sure. I think these would integrate very well with quantization out of the box.
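A load-as-sparse, compute-as-dense kernel keeps weights compressed in memory and expands them only on the way into an ordinary dense compute path. The pure-Python sketch below mimics that data flow with hypothetical helpers `nm_compress` and `load_dense`. In a real kernel the expansion would happen in registers or shared memory, and the stored values could themselves be quantized, which is why this is expected to compose well with quantization.

```python
def nm_compress(row, n, m):
    """Keep the n largest-|w| entries of each group of m as (value, index) pairs."""
    if len(row) % m != 0:
        raise ValueError("row length must be a multiple of m")
    packed = []
    for start in range(0, len(row), m):
        group = row[start:start + m]
        top = sorted(range(m), key=lambda i: abs(group[i]), reverse=True)[:n]
        packed.append([(group[i], i) for i in sorted(top)])
    return packed


def load_dense(packed, m):
    """'Load as sparse': scatter packed values back into a dense row that an
    unmodified dense kernel can consume."""
    row = []
    for group in packed:
        dense_group = [0.0] * m
        for value, i in group:
            dense_group[i] = value
        row.extend(dense_group)
    return row


def dense_dot(a, b):
    return sum(x * y for x, y in zip(a, b))


w = [0.5, -0.1, 0.0, 2.0, 0.3, -1.5, 0.2, 0.05]
x = [1.0] * len(w)
packed = nm_compress(w, n=2, m=4)   # what would live in (slow) memory
w_dense = load_dense(packed, m=4)   # expanded close to the compute units
print(w_dense)  # [0.5, 0.0, 0.0, 2.0, 0.3, -1.5, 0.0, 0.0]
print(dense_dot(w_dense, x))
```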

What about other backends like COO CPU kernels for unstructured sparsity? We believe that we should focus on these M:N / block-sparse GPU patterns in particular.

Currently, DeepSparse is the SOTA in this area. It includes a number of additional innovations for acceleration, including depth-wise asynchronous execution, pre-loading input data to hide latency via CPU pipelining, compressing sparse activations into CSR format on the fly, and keeping convolutional kernels in L2 cache. However, its license is non-permissive for commercial use.

Since CPU inference is still a very important use case, I think it makes sense to provide support here as an OSS alternative to Neural Magic, though I agree it is lower priority than fast GPU kernels.
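Compressing sparse activations into CSR on the fly pays off because ReLU outputs are often mostly zeros, and a CSR x dense product skips every zero activation. The sketch below shows the generic CSR layout (values, column indices, row pointers) and the matching matmul in pure Python. It illustrates the format only; the helper names are hypothetical and it makes no claim about how DeepSparse implements it.

```python
def relu(rows):
    return [[max(0.0, v) for v in row] for row in rows]


def to_csr(matrix):
    """Compress a dense row-major matrix into CSR (values, col_indices, row_ptr)."""
    values, col_indices, row_ptr = [], [], [0]
    for row in matrix:
        for j, v in enumerate(row):
            if v != 0:
                values.append(v)
                col_indices.append(j)
        row_ptr.append(len(values))
    return values, col_indices, row_ptr


def csr_matmul(csr, dense):
    """(sparse activations in CSR) @ (dense weights), skipping zero activations."""
    values, col_indices, row_ptr = csr
    n_out = len(dense[0])
    out = []
    for r in range(len(row_ptr) - 1):
        acc = [0.0] * n_out
        for k in range(row_ptr[r], row_ptr[r + 1]):
            a, j = values[k], col_indices[k]
            for c in range(n_out):
                acc[c] += a * dense[j][c]
        out.append(acc)
    return out


activations = relu([[1.0, -2.0, 0.5], [-1.0, -0.3, 3.0]])  # ReLU output is sparse
weights = [[1.0, 0.0], [2.0, 1.0], [0.0, 4.0]]
csr = to_csr(activations)
print(csr)                       # ([1.0, 0.5, 3.0], [0, 2, 2], [0, 2, 3])
print(csr_matmul(csr, weights))  # [[1.0, 2.0], [0.0, 12.0]]
```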