FluxML / Flux.jl

Relax! Flux is the ML library that doesn't make you tensor
https://fluxml.ai/

`Flux.params` for primitive types #1991

Open FelixBenning opened 2 years ago

FelixBenning commented 2 years ago

`Flux.params` does not work for primitive types, e.g.

[Screenshot 2022-06-07 at 10 14 45]

I assume this is because Zygote identifies each array by its pointer and differentiates with respect to every use of that pointer. A field of primitive type stores the value itself in the struct rather than a pointer, so this method breaks.

Ideally, primitive types would be supported (the pointer could be the pointer to the model struct plus an offset).

But it might be more realistic to use the output of Flux.functor(model) to print a warning about any primitive (or otherwise incompatible) types when Flux.params is used.

mcabbott commented 2 years ago

Yes, Zygote's implicit mode does not track scalars. It works by `objectid`, which is a stable identity for things like arrays. The plan is for Flux to stop using this, in favour of explicit mode (things like `gradient(m -> loss(m, x, y), model)`, without `Params` / `Grads`), which has no such restriction and no global variables.
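As a rough illustration of the explicit-mode call mentioned above, here is a minimal sketch. The `Scale` struct, the `loss` function, and the data are hypothetical and made up for this example; it assumes Zygote is installed and that `Zygote.gradient` returns a NamedTuple mirroring the struct's fields.

```julia
using Zygote  # assumed to be installed; Flux re-exports `gradient` from it

# Hypothetical layer whose only parameter is a scalar (primitive) field.
struct Scale
    w::Float64
end
(s::Scale)(x) = s.w .* x

# Hypothetical loss with the loss(m, x, y) shape used in the comment above.
loss(m, x, y) = sum(abs2, m(x) .- y)

model = Scale(2.0)
x = [1.0, 2.0, 3.0]
y = [1.0, 1.0, 1.0]

# Explicit mode: differentiate with respect to the model itself,
# with no Params / Grads objects involved.
grads = Zygote.gradient(m -> loss(m, x, y), model)

# The gradient for the scalar field is looked up by field name,
# not by object identity.
dw = grads[1].w
```

Because the gradient is addressed by field name, the scalar `w` does not need a stable identity of its own.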

The issue about this transition is #1986. At the moment, Optimisers.jl won't update scalars; it only acts on arrays of numbers. This can be widened, but the problem is that Functors.jl also uses `objectid` to decide whether two branches of the tree are identical, so several scalars that all start at 0.0 would be permanently locked together. https://github.com/FluxML/Functors.jl/pull/39 is one attempt to fix this.
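The identity problem can be seen in plain Julia, without Flux or Functors.jl. The sketch below only uses `objectid`, `===`, and `IdDict` from Base; the variable names are made up for illustration, and it is not a reproduction of how Functors.jl caches nodes internally.

```julia
# Two separately allocated arrays with equal contents.
a = [0.0]
b = [0.0]

# Two scalars with equal value.
x = 0.0
y = 0.0

# Arrays are mutable, so each allocation has its own identity.
arrays_share_id = objectid(a) == objectid(b)

# Scalars are immutable bits types: equal bits means identical objects.
scalars_share_id = objectid(x) == objectid(y)

# An identity-keyed cache therefore keeps the arrays apart
# but cannot tell the two scalars from each other.
seen = IdDict{Any,Symbol}()
seen[a] = :first_array
seen[b] = :second_array
seen[x] = :first_scalar
seen[y] = :second_scalar  # lands on the same key as x

n_entries = length(seen)
```

Here `a === b` is false while `x === y` is true, so any scheme that ties parameters together by identity would treat the two zero scalars as one shared parameter.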

> might be more realistic to use the output of Flux.functor(model) to print a warning about any primitive (or otherwise incompatible) types

The reason it doesn't, BTW, is that models do often contain numbers we don't want to regard as parameters -- such as the strides of a convolution.