Merging #4 (c9f5b61) into main (002298b) will increase coverage by 3.32%. The diff coverage is 100.00%.

```diff
@@            Coverage Diff             @@
##             main       #4      +/-   ##
==========================================
+ Coverage   90.90%   94.23%   +3.32%
==========================================
  Files           1        2       +1
  Lines          33       52      +19
==========================================
+ Hits           30       49      +19
  Misses          3        3
==========================================
```

| Impacted Files | Coverage Δ | |
|---|---|---|
| src/LegolasFlux.jl | 91.17% <ø> (+0.26%) | :arrow_up: |
| src/functors.jl | 100.00% <100.00%> (ø) | |
This works around https://github.com/FluxML/Flux.jl/issues/1027 by essentially following the strategy of https://github.com/FluxML/Flux.jl/pull/1509#issuecomment-798947090, where here I've called @ToucheSir's `actually_all_params` just `weights` and added `loadweights!` as an analog to `loadparams!` that uses these weights.
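To make the intended usage concrete, here is a minimal sketch (the model itself is illustrative; `weights` and `loadweights!` are the functions added in this PR):

```julia
using Flux, LegolasFlux

model = Chain(Dense(4, 2), BatchNorm(2))  # any Flux model, including non-trainable state

# Collect *all* arrays in the model (trainable or not), in a stable order:
w = weights(model)

# Later, load those arrays back into a freshly constructed model:
fresh_model = Chain(Dense(4, 2), BatchNorm(2))
loadweights!(fresh_model, w)
```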
This adds a Flux dependency here, which is a little unfortunate because downstream code might want to use the LegolasFlux schema without actually touching weights/Flux stuff (e.g. looking at a bunch of losses or something), and now such code will pull in the Flux dependency. But when upstream fixes #1027, we can probably remove the `flux_workarounds.jl` file (and make a breaking release).
Thanks to @hannahilea for some pair debugging on this!
Thanks for the reviews!
Thanks especially for the pointers to `fcollect` and Functors, @ToucheSir! Super helpful. In the last two commits, I switched to using that approach and dropped the Flux dep (for the very light Functors dep). I found that `fcollect` as-is does not work because of this line: https://github.com/FluxML/Functors.jl/blob/adeb24bc3b2fb3e9959f1157d81f4633a855e207/src/functor.jl#L109. `x in arr` for arrays `arr` compares by value, so if we end up with several `x`s with the same values (say, zeroed-out arrays from a freshly initialized model or something like that), then we only keep one of each, and when it comes time to load the weights later, we don't have enough of them (this isn't hypothetical; it happened with the DigitsModel example here and had me confused for a while). So instead, in `fcollect2` I switched to using an `IdSet` to keep track of the objects we've already added (similarly to `Functors.fmap`). Since an `IdSet` is unordered and we definitely care about the ordering, I had to also keep the original `Vector{Any}` used to collect the weights. Keeping them in two data structures is a little inelegant (I guess what I actually want is an `OrderedIdSet`, which doesn't exist AFAIK), but it should have a negligible performance cost in this context (since we just store references to the arrays, not copies of them).
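For reference, a minimal sketch of the `fcollect2` idea (the actual implementation in this PR may differ in details; this assumes `Base.IdSet`, which is internal but long-available, and `Functors.children`):

```julia
import Functors

# Like Functors.fcollect, but deduplicate by identity (===) rather than by
# value (==), so equal-but-distinct arrays are all collected. The IdSet gives
# identity-based membership tests; the Vector preserves insertion order.
function fcollect2(x; cache = Base.IdSet(), output = Any[])
    x in cache && return output  # `in` on an IdSet compares with ===
    push!(cache, x)
    push!(output, x)
    foreach(child -> fcollect2(child; cache = cache, output = output),
            Functors.children(x))
    return output
end
```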
Also, since we now have `loadweights!` available in the package, we might be able to make the API a little easier to use (e.g. allow a user to pass a model directly instead of making them get the weights out first themselves). But I think that should be follow-up work (I think it requires a little more thought about the best approach).
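Purely as a hypothetical illustration of that follow-up idea (these convenience methods do not exist in the package; the names are made up):

```julia
# Hypothetical wrappers; `ModelRow`, `weights`, and `loadweights!` are from
# this PR, but `save_model_row`/`load_model!` are just a sketch.
save_model_row(model; kwargs...) = ModelRow(; weights = weights(model), kwargs...)

load_model!(model, row::ModelRow) = loadweights!(model, row.weights)
```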
Nice catch for `fcollect`. IMO this is also a bug in the upstream implementation, and we should fix it to work like `fcollect2`. Will have a look at that when I next find time to work on Functors.
Ok, good to hear this is likely a bug! I've filed an issue: https://github.com/FluxML/Functors.jl/issues/16.

(Too bad `params` is already taken by Flux!)
Not only that, the name `params` is misleading because it only returns trainable parameters. Moreover, it conflicts with `BenchmarkTools.params` as well. I wish we could rename it to something else, but it would be a massive backwards-compatibility break, because `params` is arguably the most-used function in all of Flux (not counting re-exports from Zygote etc.).
Ok, we've got some bikeshedding to do! Some ideas...

| | Getter | Setter | Column name |
|---|---|---|---|
| Current | `weights` | `load_weights!` | `weights` |
| state | `fetch_state` | `load_state!` | `state` |
| learnings | `fetch_learnings` | `load_learnings!` | `learnings` |
We're leaning towards `state`, but if anyone has any further ideas or comments, let me know! I will pick something and merge tomorrow morning if there aren't other objections.
...PyTorch uses `state`:

> In PyTorch, the learnable parameters (i.e. weights and biases) of a torch.nn.Module model are contained in the model's parameters (accessed with model.parameters()). A state_dict is simply a Python dictionary object that maps each layer to its parameter tensor. Note that only layers with learnable parameters (convolutional layers, linear layers, etc.) and registered buffers (batchnorm's running_mean) have entries in the model's state_dict. Optimizer objects (torch.optim) also have a state_dict, which contains information about the optimizer's state, as well as the hyperparameters used. (https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict)
...TensorFlow uses `weights` (huh):

> A Keras model consists of multiple components:
> - The architecture, or configuration, which specifies what layers the model contain, and how they're connected.
> - A set of weights values (the "state of the model").
>
> (https://www.tensorflow.org/guide/keras/save_and_serialize#introduction)
So maybe I sent us down this bikeshed prematurely, and it is okay to stick with `weights` after all... 🤷
Ok cool! Thanks for looking up how others do it, that was a great idea.
I think, then, since `weights` seems acceptable, I'd like to go with that. We only allow (multi-dimensional) arrays with all the same element type, which is far from arbitrary state. That restriction is on purpose, of course, since we are trying to provide a different way to serialize the model rather than include arbitrary state from your Julia session, so I think it's good to highlight it a little. Also, it's nice to keep this non-breaking by not adjusting the schema.
But I think I'd like to use `fetch_weights` instead of `weights`, so that we have `fetch_weights`, `load_weights!`, and the column name `weights`. I like that better than reusing `weights` for both the function and the column name, because then you can do

```julia
weights = fetch_weights(model)
row = ModelRow(; weights, ...)
```

I'll add an explanation to the readme to say exactly what we mean by `weights` here.
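Putting the final names together, a hedged sketch of the full round trip (assuming the package's `write_model_row`/`read_model_row` helpers; the model itself is illustrative):

```julia
using Flux, LegolasFlux

model = Chain(Dense(4, 2), BatchNorm(2))  # illustrative model

# Save: pull out the weights, wrap them in a row, and write it out.
weights = fetch_weights(model)
write_model_row("model.arrow", ModelRow(; weights))

# Load: read the row back and push the weights into a fresh model.
fresh_model = Chain(Dense(4, 2), BatchNorm(2))
load_weights!(fresh_model, read_model_row("model.arrow").weights)
```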
This example very much does not work: the goal of this PR is to fix things so that we can correctly (de)serialize this model so that the test passes.
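To illustrate the failure mode (a hedged stand-in for the repo's DigitsModel test: the model and shapes here are made up, but the mechanism is the one Flux#1027 describes):

```julia
using Flux

model = Chain(Dense(10, 10), BatchNorm(10))
x = rand(Float32, 10, 32)

# Run a forward pass in train mode so BatchNorm's running statistics
# (μ and σ²) move off their freshly initialized defaults.
Flux.trainmode!(model)
model(x)
Flux.testmode!(model)
out = model(x)

# Round-tripping only `Flux.params` misses the running statistics,
# since they are not trainable parameters (FluxML/Flux.jl#1027):
model2 = Chain(Dense(10, 10), BatchNorm(10))
Flux.loadparams!(model2, Flux.params(model))
Flux.testmode!(model2)
model2(x) ≈ out  # false: μ and σ² were never copied over
```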