FluxML / FluxML-Community-Call-Minutes

Pruning methods/library #13

Closed darsnack closed 2 years ago

darsnack commented 3 years ago

I'm going to start looking at creating a pruning package with Flux. Will update this soon with more details.

DrChainsaw commented 3 years ago

For structured pruning, this is pretty powerful as it handles size alignment between layers in an optimal fashion (through JuMP and Cbc): https://github.com/DrChainsaw/NaiveNASflux.jl

There are only two built-in methods as of now, but the general API is to just provide a function which returns the value per neuron when given a vertex (i.e. a layer) of the model. See the first example.

darsnack commented 3 years ago

Looks like this is a substantial framework for activation-based pruning where the mask is not known a priori. Is my understanding correct?

DrChainsaw commented 3 years ago

I would say the framework itself is something like a more general version of network morphisms, and the pruning metrics themselves are not really core functionality.

When it comes to selecting which neurons to keep when you decrease the size of at least one layer (e.g. the pruning use case), you can supply any metric as a vector and the core functionality will try to maximize it for the whole network, given the constraint that things need to fit together. I guess this can be thought of as creating the mask.

ActivationContribution happened to be a suitable wrapper for an implementation of one metric, but if you for example choose to not wrap layers in it, the default metric will just use the absolute magnitude of the weights instead.
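
To make the "magnitude of the weights" default concrete, here is a minimal sketch in plain Julia. It does not use NaiveNASflux; the function name `neuron_magnitude`, the choice of the mean (rather than, say, the sum) and the `(nout, nin)` weight layout are assumptions made for illustration only.

```julia
# Sketch only: per-neuron value as the mean absolute weight of each output neuron.
# Assumes a Dense-style layout where the weight matrix has size (nout, nin).
neuron_magnitude(W::AbstractMatrix) = vec(sum(abs, W; dims=2)) ./ size(W, 2)

W = [ 0.5 -0.5  1.0;
      0.0  0.0  0.0;
     -2.0  1.0  0.0]

vals = neuron_magnitude(W)

# Neurons ordered from least to most valuable; the all-zero row comes first
ranking = sortperm(vals)
```

A metric function of this shape (one number per output neuron) is the kind of vector the comment above describes handing to the framework.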

darsnack commented 3 years ago

Thanks for the clarification. I'll spend some time playing around with the package to learn more about it.

DrChainsaw commented 3 years ago

Awesome! Please post issues if you find anything wonky or difficult to comprehend!

DrChainsaw commented 3 years ago

Link to the pruning example from the ONNX issue: https://github.com/FluxML/ML-Coordination-Tracker/issues/10#issuecomment-717583714

DrChainsaw commented 3 years ago

Here is a take on a very simple “Package” for structured pruning using NaiveNASflux.

I have made some updates to NaiveNASlib so it is possible to just "mark" neurons to be pruned by giving them a negative value. Note that a negative value by itself is not a guarantee for pruning, as things like residual connections could make the net value of keeping a negative-valued neuron positive due to how elementwise addition ties neurons from multiple layers together.
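
A toy illustration of why a negative value is not a guarantee, in plain Julia. The numbers are made up, and summing the values per index is a deliberate simplification: NaiveNASlib solves an optimization problem over the whole graph rather than this per-index rule.

```julia
# Hypothetical per-neuron values for two layers whose outputs are added elementwise
# (e.g. a residual connection). Index i must be kept or dropped in both layers together.
branch_a = [0.8, -0.3, -0.5, 0.2]   # negative = "marked" for pruning
branch_b = [0.1,  0.6,  0.1, 0.4]

# Net value of keeping index i in both layers
net = branch_a .+ branch_b

# Indices worth keeping under this simplified rule
keep = findall(>(0), net)
```

Neuron 2 of `branch_a` was marked with a negative value, but it survives because the neuron it is tied to in `branch_b` is valuable enough to outweigh it. Neuron 3 is dropped from both.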

It has two main functions:

  1. `prune(model, fraction)`, which prunes approximately `fraction` of the neurons by removing the corresponding parameters
  2. `pruneuntil(model, accept; depth=6)`, which uses `prune` to try to find the largest amount of pruning for which `accept(pruned_model)` returns true, using `depth` attempts.

Module with some extra code for experiments:

```julia
(PruningExample) pkg> add Statistics, MLDatasets, NaiveNASflux, https://github.com/DrChainsaw/ONNXmutable.jl

module PruningExample

using ONNXmutable, NaiveNASflux, Statistics
import MLDatasets

export prune, pruneuntil

function pruning_metric(v, offs)
    val = neuron_value(v) # neuron_value defaults to magnitude of parameters along activation dimension
    ismissing(val) && return zeros(nout_org(v)) # Layers with no parameters return missing by default
    return val .- offs
end

function prune(g::CompGraph, fraction)
    @assert 0 < fraction < 1 "Fraction of neurons to prune must be between 0 and 1"
    gnew = copy(g)

    # We don't want to change the number of outputs from the model, so exclude all layers
    # for which a change in number of neurons leads to a change in model output size
    pvs = prunable(gnew)

    # Find the cutoff value below which the requested fraction of neurons falls
    allvals = mapreduce(neuron_value, vcat, pvs) |> skipmissing |> collect
    cutoff = partialsort(allvals, round(Int, fraction*length(allvals)))

    # Prune the model
    Δoutputs(OutSelectRelaxed() |> ApplyAfter, gnew, v -> v in pvs ? pruning_metric(v, cutoff) : fill(100, nout(v)))

    return gnew
end

prunable(g::CompGraph) = mapreduce(prunable, vcat, g.outputs)
function prunable(v, ok = false)
    vs = mapreduce(vcat, inputs(v); init=AbstractVertex[]) do vi
        prunable(vi, ok || isabsorb(v))
    end
    ok ? unique(vcat(v, vs)) : unique(vs)
end
isabsorb(v) = isabsorb(trait(v))
isabsorb(t::DecoratingTrait) = isabsorb(base(t))
isabsorb(t::SizeAbsorb) = true
isabsorb(t::MutationSizeTrait) = false

function pruneuntil(g::CompGraph, accept; depth = 6)
    # Binary search how much we can prune and still meet the acceptance criterion
    step = 1
    fraction = 0.5
    gaccepted = g
    faccepted = 0.0
    while step < 2^depth
        @info "Prune $fraction of parameters"
        g′ = prune(g, fraction)
        step *= 2
        if accept(g′)
            faccepted = fraction
            gaccepted = g′
            fraction += fraction / step
        else
            fraction -= fraction / step
        end
    end
    return gaccepted
end

# Auxiliary stuff to run the experiment

export resnet, faccept, nparams

const resnetfile = Ref{String}("")
function resnet()
    if !isfile(resnetfile[])
        # I couldn't find any SOTA ONNX models for CIFAR10 online.
        # This is my not very successful attempt at replicating these experiments: https://github.com/davidcpage/cifar10-fast/blob/master/experiments.ipynb
        # Test accuracy is around 92% iirc
        resnetfile[] = download("https://github.com/DrChainsaw/NaiveGAExperiments/raw/master/lamarckism/pretrained/resnet.onnx")
    end
    return CompGraph(resnetfile[])
end

function cifar10accuracy(model, batchsize=16; nbatches=cld(10000, batchsize))
    x, y = MLDatasets.CIFAR10.testdata()
    itr = Flux.Data.DataLoader((x, Flux.onehotbatch(y, sort(unique(y)))); batchsize)

    xm = mean(x) |> Float32
    xs = std(x; mean=xm) |> Float32

    mean(Iterators.take(itr, nbatches)) do (xb, yb)
        xb_std = @. (Float32(xb) - xm) / xs
        sum(Flux.onecold(model(xb_std)) .== Flux.onecold(yb))
    end / batchsize
end

function faccept(model)
    # I don't have a GPU on this computer so in this example I'll just use a small subset of the test set
    acc = cifar10accuracy(model, 32; nbatches=10)
    @info "\taccuracy: $acc"
    return acc > 0.9
end

nparams(m) = mapreduce(prod ∘ size, +, params(m))

end
```
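
The cutoff-and-offset trick inside `prune` and `pruning_metric` can be tried in isolation. The sketch below uses only Base Julia and made-up neuron values; it is not part of the module and does not touch a real model.

```julia
vals = Float64[9, 1, 5, 3, 7, 2, 8, 4, 6, 10]  # made-up neuron values
fraction = 0.3

# k-th smallest value, found without fully sorting the vector
k = round(Int, fraction * length(vals))
cutoff = partialsort(vals, k)

# Shift so that values below the cutoff become negative, i.e. marked for pruning
shifted = vals .- cutoff
marked = findall(<(0), shifted)
```

With this toy data the neuron sitting exactly at the cutoff ends up with value zero rather than a negative one, so two of the ten neurons are marked. That is one reason the module only promises to prune approximately `fraction` of the neurons.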

Example which prunes an imported (quite poor tbh) CIFAR10 model as much as possible while staying above 90% accuracy on the test set:

julia> using PruningExample

julia> f = resnet();

julia> nparams(f)
6575370

julia> f′ = pruneuntil(f, faccept);
[ Info: Prune 0.5 of parameters
[ Info:         accuracy: 0.478125
[ Info: Prune 0.25 of parameters
[ Info:         accuracy: 0.9125
[ Info: Prune 0.3125 of parameters
[ Info:         accuracy: 0.890625
[ Info: Prune 0.2734375 of parameters
[ Info:         accuracy: 0.909375
[ Info: Prune 0.29052734375 of parameters
[ Info:         accuracy: 0.909375
[ Info: Prune 0.2996063232421875 of parameters
[ Info:         accuracy: 0.89375

julia> nparams(f′) # About 25% fewer parameters, after asking to prune about 29% of the neurons
4926056
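
The fractions in the log follow directly from the update rule in `pruneuntil`. As a sanity check, this sketch replays only that rule with a scripted accept/reject sequence read off the log above; `fraction_schedule` is a hypothetical helper and no model is involved.

```julia
# Replays the fraction updates of `pruneuntil` for a given list of accept decisions
function fraction_schedule(decisions; start = 0.5)
    fraction = start
    step = 1
    tried = Float64[]
    for accepted in decisions
        push!(tried, fraction)
        step *= 2
        # Move up after an accepted attempt, down after a rejected one
        fraction += (accepted ? 1 : -1) * fraction / step
    end
    return tried
end

decisions = [false, true, false, true, true, false]  # was accuracy > 0.9 in each attempt?
tried = fraction_schedule(decisions)

# Largest fraction that was accepted
best = maximum(tried[decisions])
```

Note that the step is relative to the current fraction (`fraction / step`), so this is a shrinking-step search rather than a textbook bisection of a fixed interval.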

Note how this is one slippery slope on the dark path to NAS, as one immediately starts thinking things like "what if I retrain the model just a little after pruning" and then "maybe I should try to increase the size and see if things get better, or why not add some more layers and remove the ones which perform poorly and...". Disclaimer: I toy with NAS in my spare time as a hobby and I'm not wasting anyone's money (except my own electricity bill) on it.

darsnack commented 3 years ago

Currently being worked on in MaskedArrays.jl and FluxPrune.jl.

DrChainsaw commented 3 years ago

Looks very cool!

Just to expose my ignorance on the subject: What is the end goal after masking the parameters?

For example, it is not clear to me if there are any benefits of having some fraction of parameters masked. Or should one convert them to sparse arrays once the amount of masking is above some threshold where it becomes beneficial? Does that give benefits on a GPU? Or does one just learn that a smaller model works and then build and retrain the smaller model? That does not seem to be doable with unstructured pruning, or is it?

darsnack commented 3 years ago

Yeah that part is currently missing, but once you have finished pruning, you have a "freeze" step that turns all the masked arrays into a compressed form. For unstructured that could be a sparse array. For structured that could be "re-building" the model after dropping channels.

Probably, I will make the "freeze" step call MaskedArrays.freeze which just turns each masked array into the corresponding Array with zeros. I might include some common helpers, but I will leave the decision of going to sparse formats, etc. to the user. The reason is that AFAIK there is no standard next step. How to most effectively take advantage of zeros is highly hardware dependent. So, it is up to the person pruning a model to decide how best to take advantage of the pruned result.
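
A rough sketch of what such a "freeze" step amounts to, using plain arrays and the SparseArrays standard library. The `weights`/`mask` pair is a hypothetical stand-in for a masked array; this is not the MaskedArrays.jl API.

```julia
using SparseArrays  # standard library

# Hypothetical stand-in for a masked weight: values plus a Bool mask (true = keep)
weights = [0.5 -1.2 0.0; 2.0 0.3 -0.7]
mask    = Bool[1 0 1; 0 1 1]

# "Freeze" to a plain dense Array with zeros in the masked-out positions
frozen = ifelse.(mask, weights, 0.0)

# One possible next step, left to the user: a sparse representation
frozen_sparse = sparse(frozen)
stored = nnz(frozen_sparse)
```

Whether the sparse form is actually faster than the dense array with zeros depends on the sparsity level and the hardware, which is the point made above about there being no standard next step.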

DrChainsaw commented 3 years ago

Makes sense to me, guess I wasn't too far off base then.

> For structured that could be "re-building" the model after dropping channels.

Just a word of warning to keep you from going insane: This is harder than it seems, as the change propagates to the next layer, and if that layer is something like a concatenation and/or an elementwise operation, things get out of hand quickly. This is basically what I tried to do with NaiveNASlib after having some success with a simple version in another project, and it ended with me throwing in the towel and reaching for JuMP when trying to make use of it in a more or less unrestricted NAS setting.
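
The simplest case of that propagation, two stacked dense layers, can be sketched with plain matrices. The `(nout, nin)` weight layout and all names here are assumptions for illustration; nothing below uses NaiveNASlib or Flux.

```julia
# Two stacked dense layers, weights stored as (nout, nin)
W1, b1 = randn(4, 3), randn(4)   # 3 inputs -> 4 hidden neurons
W2, b2 = randn(2, 4), randn(2)   # 4 hidden -> 2 outputs

keep = [1, 3]                    # hidden neurons that survive pruning

# Dropping rows of layer 1 forces dropping the matching columns of layer 2
W1p, b1p = W1[keep, :], b1[keep]
W2p      = W2[:, keep]

x = randn(3)
y = W2p * (W1p * x .+ b1p) .+ b2   # shapes still line up after the rebuild
```

With an elementwise addition, every layer feeding the addition has to use the same `keep` indices, and with a concatenation the indices have to be offset per input. That bookkeeping is what grows quickly in less restricted architectures.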

darsnack commented 3 years ago

Yeah, the structured pruning literature special-cases the handling of skip connections for this reason. They don't provide a general algorithm for all kinds of operations, and I don't intend to come up with one as a first pass. For now, I am just going to implement what's already out there, and hopefully throw a useful error when a user does some pruning that would result in mismatched dimensions.
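
One way such an error could look for a plain chain of dense layers is sketched below. `check_chain` is a hypothetical helper, not something FluxPrune.jl provides, and it assumes weights stored as `(nout, nin)`.

```julia
# Checks that consecutive weight matrices (layout: nout × nin) still fit together
function check_chain(weights)
    for i in 1:length(weights)-1
        nout = size(weights[i], 1)
        nin  = size(weights[i + 1], 2)
        nout == nin || throw(DimensionMismatch(
            "layer $i produces $nout outputs but layer $(i + 1) expects $nin inputs"))
    end
    return true
end

ok_chain  = [zeros(4, 3), zeros(2, 4)]
bad_chain = [zeros(3, 3), zeros(2, 4)]   # first layer pruned, second left untouched

check_chain(ok_chain)                    # passes

# Capture the error instead of letting it propagate, to inspect the message
err = try
    check_chain(bad_chain)
catch e
    e
end
```

Here `err` is a `DimensionMismatch` naming the two layers that no longer agree, which is more actionable than a failure deep inside a matrix multiplication.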

DrChainsaw commented 3 years ago

I haven't seen the MIP approach in the literature and I must say I am pretty happy with the result.

I guess the drawback is that there is still a large set of operations which can't use any of the formulations which NaiveNASlib ships, and then one must write the constraints for each such operation (e.g. depthwise/grouped convolutions). I don't think this is much easier to handle with any other approach either (except maybe the linearization parts).

darsnack commented 3 years ago

> I haven't seen the MIP approach in the literature and I must say I am pretty happy with the result.

Yes, I always thought this was a really cool showcase of the power of Julia!

darsnack commented 2 years ago

The initial sketch of the library is FluxPrune.jl. I am going to close this issue in favor of opening specific issues over on that repo. I also updated https://github.com/FluxML/Flux.jl/issues/1431 to reflect this.