StellaAthena opened this issue 4 years ago (Open)
I thought about this a little bit, and here's how I think it should be implemented: define a SparseModel subclass of base.Model with a default pruning Strategy and a sparsify() method that returns a PrunedModel built according to a specified strategy (or the default).
The benefit is that if we wanted to create a SparseLeNet class, for example, we could do it through multiple inheritance: subclass both SparseModel and LeNet, and nothing else needs to be done.
At the same time, we should introduce a new "dummy" pruning strategy in pruning.registry that does nothing, since no further pruning is needed once the model has been sparsified.
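Here is a minimal, self-contained sketch of the idea. The class and method names (SparseModel, sparsify(), PrunedModel, the dummy strategy) follow the proposal above, but the bodies, the random default strategy, and the toy LeNet are illustrative stand-ins rather than open_lth's actual implementation:

```python
import torch
import torch.nn as nn


def random_pruning(model: nn.Module, fraction: float = 0.5) -> dict:
    """Illustrative default strategy: keep a random `fraction` of each weight tensor."""
    return {name: (torch.rand_like(p) < fraction).float()
            for name, p in model.named_parameters()}


def dummy_pruning(model: nn.Module) -> dict:
    """The proposed 'do nothing' strategy: an all-ones mask, i.e. no further pruning."""
    return {name: torch.ones_like(p) for name, p in model.named_parameters()}


class PrunedModel(nn.Module):
    """Stand-in for open_lth's PrunedModel: a model plus a fixed 0/1 mask.
    Pruned weights are zeroed here; the mask must be re-applied after each
    optimizer step (see apply_mask) because nothing enforces it automatically."""

    def __init__(self, model: nn.Module, mask: dict):
        super().__init__()
        self.model = model
        self.mask = mask
        self.apply_mask()

    def apply_mask(self):
        with torch.no_grad():
            for name, p in self.model.named_parameters():
                p.mul_(self.mask[name])

    def forward(self, x):
        return self.model(x)


class SparseModel(nn.Module):
    """Mixin that adds sparsify(); subclasses can override default_strategy."""

    default_strategy = staticmethod(random_pruning)

    def sparsify(self, strategy=None) -> PrunedModel:
        strategy = strategy or self.default_strategy
        return PrunedModel(self, strategy(self))


class LeNet(nn.Module):
    """Tiny LeNet-style MLP, only here to show the multiple-inheritance pattern."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(nn.Linear(784, 300), nn.ReLU(), nn.Linear(300, 10))

    def forward(self, x):
        return self.fc(x.flatten(1))


class SparseLeNet(SparseModel, LeNet):
    """Nothing else needed: layers come from LeNet, sparsify() from SparseModel."""
    pass
```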
Our experiment would look like:
```python
sparse_models = [model.sparsify() for i in range(num_sparse_models)]
```
and then train each of the sparse models in isolation.
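Spelled out against the sketch above, a (hypothetical) driver that trains each sparse model in isolation might look like the following; the random data is only there to keep the snippet self-contained, standing in for the usual MNIST loader:

```python
import torch
from torch import nn, optim

num_sparse_models = 3  # hypothetical experiment size

for run in range(num_sparse_models):
    sparse = SparseLeNet().sparsify()          # fresh init, pruned before any training
    opt = optim.SGD(sparse.parameters(), lr=0.1)
    loss_fn = nn.CrossEntropyLoss()
    for step in range(100):                    # stand-in for a full training run
        x = torch.randn(64, 1, 28, 28)         # fake MNIST-shaped batch
        y = torch.randint(0, 10, (64,))
        opt.zero_grad()
        loss_fn(sparse(x), y).backward()
        opt.step()
        sparse.apply_mask()                    # keep pruned weights at zero
```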
One issue that may arise from this is runtime. I haven't looked into how a PrunedModel is evaluated, i.e. whether it knows to skip the zeroed-out weights, but there doesn't seem to be a way around that short of rewriting their internal implementations.
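To illustrate the concern with a toy example (not open_lth's code): a masked dense layer still performs the full dense multiply, so zeroed weights cost the same as nonzero ones unless sparse kernels are used.

```python
import torch

w = torch.randn(300, 784)
mask = (torch.rand_like(w) < 0.1).float()   # keep ~10% of weights

x = torch.randn(64, 784)
dense_out = x @ w.t()             # full dense matmul
masked_out = x @ (w * mask).t()   # same number of multiply-adds, just with zeros in w
# To actually skip the zeros, the weights would need a sparse representation
# (e.g. torch.sparse) and kernels that exploit it.
```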
I made a commit (https://github.com/dtch1997/open_lth/commit/fd0c9883303fbdc0f45014f5f9cbc63aa5523fcd) demonstrating this idea. Let me know what you think.
We wish to be able to find lottery tickets in sparse versions of existing models.