beacon-biosignals / LegolasFlux.jl

Save Flux model weights in Legolas-powered Arrow tables
MIT License

Workaround Flux#1027 #4

Closed ericphanson closed 3 years ago

ericphanson commented 3 years ago

This example very much does not work:

julia> output
10×1 Matrix{Float32}:
 1.2540959f-6
 5.1405354f-5
 0.00020041084
 3.131654f-5
 5.9973813f-6
 4.5418696f-7
 5.347429f-8
 0.99964094
 3.7541522f-5
 3.067953f-5

julia> output2
10×1 Matrix{Float32}:
 0.086956024
 0.10940457
 0.09150873
 0.098902285
 0.088835925
 0.06954686
 0.06539499
 0.20102207
 0.090939134
 0.0974893

The goal of this PR is to fix things so that we can correctly (de)-serialize this model so the test passes.

codecov[bot] commented 3 years ago

Codecov Report

Merging #4 (c9f5b61) into main (002298b) will increase coverage by 3.32%. The diff coverage is 100.00%.


@@            Coverage Diff             @@
##             main       #4      +/-   ##
==========================================
+ Coverage   90.90%   94.23%   +3.32%     
==========================================
  Files           1        2       +1     
  Lines          33       52      +19     
==========================================
+ Hits           30       49      +19     
  Misses          3        3              
| Impacted Files | Coverage Δ | |
| --- | --- | --- |
| src/LegolasFlux.jl | 91.17% <ø> (+0.26%) | ⬆️ |
| src/functors.jl | 100.00% <100.00%> (ø) | |

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Powered by Codecov. Last update 002298b...c9f5b61.

ericphanson commented 3 years ago

This works around https://github.com/FluxML/Flux.jl/issues/1027 by essentially following the strategy of https://github.com/FluxML/Flux.jl/pull/1509#issuecomment-798947090: here I've named @ToucheSir's actually_all_params simply weights, and added loadweights! as an analog of loadparams! that uses those weights.

This adds a Flux dependency here, which is a little unfortunate: downstream code might want to use the LegolasFlux schema without actually touching weights or Flux at all (e.g. to look at a bunch of losses or something), and such code will now pull in Flux. But once upstream fixes Flux#1027, we can probably remove the flux_workarounds.jl file (and make a breaking release).
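The loadweights! idea can be sketched in plain Julia as follows (this is a minimal illustration, not the actual PR code, and the real function operates on a model rather than a bare vector of arrays):

```julia
# Sketch: copy saved weight values back into a model's weight arrays,
# relying on both collections being in the same fixed traversal order.
function loadweights_sketch!(weight_arrays, saved)
    length(weight_arrays) == length(saved) ||
        error("expected $(length(weight_arrays)) arrays, got $(length(saved))")
    for (w, s) in zip(weight_arrays, saved)
        size(w) == size(s) || error("size mismatch: $(size(w)) vs $(size(s))")
        copyto!(w, s)  # in-place copy, analogous to what loadparams! does
    end
    return weight_arrays
end
```

The key point is that loading is purely positional: whatever order the weights were collected in at save time must be reproduced at load time.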

Thanks to @hannahilea for some pair debugging on this!

ericphanson commented 3 years ago

Thanks for the reviews!

Thanks especially for the pointers to fcollect and Functors, @ToucheSir! Super helpful. In the last two commits, I switched to that approach and dropped the Flux dep in favor of the much lighter Functors dep.

I found that fcollect as-is does not work because of this line: https://github.com/FluxML/Functors.jl/blob/adeb24bc3b2fb3e9959f1157d81f4633a855e207/src/functor.jl#L109. For arrays, x in arr compares by value, so if we end up with several xs holding the same values (say, zeroed-out arrays from a freshly initialized model or something like that), we only keep one of each, and when it comes time to load the weights later we don't have enough of them. (This isn't hypothetical; it happened with the DigitsModel example here and had me confused for a while.)

So in fcollect2 I instead use an IdSet to keep track of the objects we've already added, similarly to Functors.fmap. Since an IdSet is unordered and we definitely care about the ordering, I also had to keep the original Vector{Any} used to collect the weights. Keeping them in two data structures is a little inelegant (what I actually want is an OrderedIdSet, which doesn't exist AFAIK), but it should have a negligible performance cost in this context, since we store only references to the arrays, not copies of them.
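The by-value deduplication problem is easy to reproduce in plain Julia (an illustration of the failure mode, not the package code):

```julia
# Two distinct arrays with equal contents, like two zero-initialized layers.
a = zeros(Float32, 2)
b = zeros(Float32, 2)

arr = Any[a]
@assert b in arr            # `in` on a Vector compares with ==, so b looks "already seen"

cache = Base.IdSet{Any}()
push!(cache, a)
@assert a in cache
@assert !(b in cache)       # IdSet compares with ===, so b is correctly treated as new
```

Tracking visited objects by identity (===) rather than value (==) is exactly what keeps both zeroed arrays in the collected weights.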

Also, since we now have loadweights! available in the package, we might be able to make the API a little easier to use (e.g. allow a user to pass a model directly instead of making them get the weights out first themselves). But I think that should be followup work (I think it requires a little more thought about the best approach).

ToucheSir commented 3 years ago

Nice catch on fcollect. IMO this is also a bug in the upstream implementation, and we should fix it to work like fcollect2. I'll have a look when I next find time to work on Functors.

ericphanson commented 3 years ago

Ok good to hear this is likely a bug! I've filed an issue, https://github.com/FluxML/Functors.jl/issues/16.

ToucheSir commented 3 years ago

> (Too bad params is already taken by Flux!)

Not only that: the name params is misleading, because it only returns trainable parameters, and it conflicts with BenchmarkTools.params as well. I wish we could rename it to something else, but it would be a massive backwards-compatibility break, since params is arguably the most-used function in all of Flux (not counting re-exports from Zygote etc.).

ericphanson commented 3 years ago

Ok, we've got some bikeshedding to do! Some ideas...

| | Getter | Setter | Column name |
| --- | --- | --- | --- |
| Current | weights | load_weights! | weights |
| state | fetch_state | load_state! | state |
| learnings | fetch_learnings | load_learnings! | learnings |

We're leaning towards state but if anyone has any further ideas or comments let me know! I will pick something and merge tomorrow morning if there aren't other objections.

hannahilea commented 3 years ago

...pytorch uses state:

> In PyTorch, the learnable parameters (i.e. weights and biases) of an torch.nn.Module model are contained in the model’s parameters (accessed with model.parameters()). A state_dict is simply a Python dictionary object that maps each layer to its parameter tensor. Note that only layers with learnable parameters (convolutional layers, linear layers, etc.) and registered buffers (batchnorm’s running_mean) have entries in the model’s state_dict. Optimizer objects (torch.optim) also have a state_dict, which contains information about the optimizer’s state, as well as the hyperparameters used.

(https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict)

...tensorflow uses weights (huh):

> A Keras model consists of multiple components:
>
> - The architecture, or configuration, which specifies what layers the model contain, and how they're connected.
> - A set of weights values (the "state of the model").

(https://www.tensorflow.org/guide/keras/save_and_serialize#introduction)

so maybe I sent us down this bikeshed prematurely, and it is okay to stick with weights after all... 🤷

ericphanson commented 3 years ago

Ok cool! Thanks for looking up how others do it, that was a great idea.

I think then, since weights seems acceptable, I'd like to go with that. We only allow (multi-dimensional) arrays that all share the same element type, which is far from arbitrary state. That restriction is on purpose, of course: we are trying to provide a way to serialize the model without capturing arbitrary state from your Julia session, so I think it's good to highlight it a little. Also, it's nice to keep this non-breaking by not adjusting the schema.

But I think I'd like to use fetch_weights instead of weights, so that we have fetch_weights, load_weights!, and the column name weights. I like that better than reusing weights for both the function and the column name, because then you can do

weights = fetch_weights(model)

row = ModelRow(; weights, ...)

I'll add an explanation to the readme to say exactly what we mean by weights here.
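Put together, the intended round-trip would look something like this sketch (the Chain/Dense model is a made-up example, and any other ModelRow columns are elided; exact signatures may differ from what lands in the package):

```julia
using Flux, LegolasFlux

model = Chain(Dense(4, 2), softmax)

weights = fetch_weights(model)        # getter
row = ModelRow(; weights)             # `weights` doubles as the column name

model2 = Chain(Dense(4, 2), softmax)  # same architecture, fresh random init
load_weights!(model2, row.weights)    # setter: copy saved weights into model2
```

After the load, model2 should produce the same outputs as model for any input, which is exactly the property the failing test at the top of this thread checks.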