oschulz opened 3 years ago
How would you feel about depending on ParameterHandling once #42 goes in?
I think even with #42 it would still be too heavy. The idea is to keep ChangesOfVariables super-lightweight, so we can convince packages to depend on it, and the stuff in parameters.jl is really a different thing. Here, I think, ParameterHandling would depend on ChangesOfVariables, not the other way round. And ParameterHandling is a more opinionated package than ChangesOfVariables.
But now that I think about it, since it looks like we need flatten/unflatten functionality in ChangesOfVariables, how would you feel about moving that there? It's not much more than I have in the draft of JuliaMath/ChangesOfVariables.jl#2 anyway, and I think it would fit into the theme of the package. It is a change of variables, after all (with a logabsdet-jacobian of 0, of course).
@willtebbutt and @devmotion, what do you think?
Update: ParameterHandling will likely depend on ChangesOfVariables indirectly in the future due to #42, since we hope to get support for ChangesOfVariables into LogExpFunctions (especially for logistic, logit, etc.). So ChangesOfVariables must not depend on ParameterHandling.
I'm not sure; it's not clear to me that this functionality should be included in https://github.com/JuliaMath/ChangesOfVariables.jl/pull/2. I would prefer if such wrapping/reshaping/flattening etc. were handled by the user.
Yes, the more I think about it, it's really a bit of a different thing. And we could use test_with_logabsdet_jacobian(f, x, torealvec_and_back, getjacobian), which would keep the whole thing out of the dependencies of ChangesOfVariables.
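To show what such a torealvec_and_back argument buys a test utility, here is a minimal sketch. It assumes a hypothetical torealvec_and_back that only handles scalars and real arrays, uses central finite differences (fd_jacobian, also hypothetical) as a dependency-free stand-in for AD, and numeric_ladj is an illustrative helper, not part of ChangesOfVariables:

```julia
using LinearAlgebra

# Hypothetical helper named after the placeholder above: flatten x to a real
# vector and return a function that maps such a vector back to x's shape.
torealvec_and_back(x::Real) = ([float(x)], v -> only(v))
torealvec_and_back(x::AbstractArray{<:Real}) = (vec(float.(x)), v -> reshape(v, size(x)))

# Central finite-difference Jacobian of a vector-to-vector function g at v
# (a stand-in for automatic differentiation).
function fd_jacobian(g, v::AbstractVector; h = 1e-6)
    m = length(g(v))
    J = zeros(m, length(v))
    for j in eachindex(v)
        e = zeros(length(v))
        e[j] = h
        J[:, j] = (g(v + e) - g(v - e)) / (2h)
    end
    return J
end

# Numerical log-abs-det-Jacobian of f at x, evaluated on the flat representation.
function numeric_ladj(f, x)
    v, back = torealvec_and_back(x)
    g = w -> first(torealvec_and_back(f(back(w))))
    return first(logabsdet(fd_jacobian(g, v)))
end
```

The Jacobian is computed entirely on flat vectors, so the test code never needs to know how x is structured; only the flatten-and-back helper does.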
I still think it's a very fundamental functionality that may deserve a central package independent from variable-domain transformations.
I wonder if we should add an interface InverseFunctions.with_inverse(f, x) with a default implementation of InverseFunctions.with_inverse(f, x) = (f(x), inverse(f)), to support use cases like flatten/unflatten that can't do inverse(f) without a value (they would specialize with_inverse)?
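A minimal sketch of the proposed interface. It assumes with_inverse is a new function (it is only being proposed here) and uses a local stand-in for InverseFunctions.inverse so that the block runs on its own; realvec is the hypothetical flattening function mentioned in this discussion and only handles NamedTuples of reals and real arrays:

```julia
# Local stand-in for InverseFunctions.inverse (value-independent inverses).
inverse(::typeof(exp)) = log
inverse(::typeof(log)) = exp

# Proposed default (hypothetical API): the value, then the value-independent inverse.
with_inverse(f, x) = (f(x), inverse(f))

# Hypothetical flattening function: scalars and arrays in a NamedTuple become one Float64 vector.
_tovec(v::Real) = [float(v)]
_tovec(v::AbstractArray{<:Real}) = vec(float.(v))
realvec(x::NamedTuple) = reduce(vcat, map(_tovec, values(x)); init = Float64[])

# realvec has no inverse without knowing the original value, so it specializes with_inverse.
function with_inverse(::typeof(realvec), x::NamedTuple)
    function unflatten(v::AbstractVector{<:Real})
        offset = 0
        vals = Any[]
        for xi in values(x)
            n = length(xi)  # the length of a scalar is 1
            chunk = v[offset+1:offset+n]
            push!(vals, xi isa Real ? only(chunk) : reshape(chunk, size(xi)))
            offset += n
        end
        return NamedTuple{keys(x)}(Tuple(vals))
    end
    return realvec(x), unflatten
end
```

Callers that only need the forward direction can call realvec(x) directly and never allocate the back-function.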
That would form a nice basis for a package that depends on InverseFunctions and defines a function realvec and with_inverse(::typeof(realvec), x). There are definitely use cases where one would only want the forward function, and instantiating the back function would be wasteful (memory allocations). This way, we'd have an API that can do both flatten and flatten-and-back.
In principle I could get on board with this kind of thing.
My main concern is that if we choose to make this functionality more widely available, it'll be difficult to get consensus on some semantics. For example, flatten(1) returns an empty vector and a closure. The rationale is that this is what you want if you're doing AD and 1 is a primal (as opposed to a tangent vector), because integers aren't differentiable. I can well imagine that someone somewhere would (quite reasonably) find this choice objectionable, and want different behaviour.
So I would want to generalise what we have here a little bit to allow different semantics depending on your situation. Maybe utilising a trait-based mechanism like the one used for RuleConfigs in ChainRulesCore?
So you would write something like
struct Flattener{T} end
struct KeepIntegers end
# These have different behaviour.
with_inverse(Flattener{Nothing}(), 1)        # -> (Float64[], back): the integer is dropped
with_inverse(Flattener{KeepIntegers}(), 1)   # -> ([1.0], back): the integer is kept as a float
in case you want to work around this stuff?
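A runnable sketch of that trait idea, assuming the "no option" case is spelled Flattener{Nothing}. Flattener, KeepIntegers and with_inverse are hypothetical names from this discussion, not an existing API, and only scalars are covered:

```julia
# Hypothetical trait-parameterized flattener; these names are only a proposal.
struct Flattener{T} end
struct KeepIntegers end

# Default semantics (T = Nothing): integers are constants (e.g. non-differentiable
# primals under AD), so they contribute nothing to the flat vector.
with_inverse(::Flattener{Nothing}, x::Integer) = (Float64[], _ -> x)

# Opt-in semantics: integers are flattened to Float64 and rounded on the way back.
with_inverse(::Flattener{KeepIntegers}, x::Integer) =
    ([Float64(x)], v -> round(typeof(x), only(v)))

# Floating-point values are flattened the same way under either setting.
with_inverse(::Flattener, x::AbstractFloat) = ([x], v -> convert(typeof(x), only(v)))
```

Because the behaviour is selected by dispatch on the type parameter, a user who disagrees with the default only has to pick a different parameter, not a different package.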
Yes, thought some more about this today and also started to worry about integers. With the autodiff use case, as an extension to excluding integers, people may also want to exclude other parameters from the gradient calculation (if they are not relevant, to increase performance), so one might want to have a way to mark parameters as active/constant. That could quickly go beyond "non-opinionated central API" indeed ...
so one might want to have a way to mark parameters as active/constant.
We actually already have a way to mark parameters as being constant (see fixed), although doing things the other way around (marking things as "active") is something that we don't currently have support for, though we could almost certainly achieve it using a trait-based system.
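A dependency-free sketch of marking values as constant during flattening. Fixed and flatten_skip_fixed are hypothetical names; Fixed only plays the role of ParameterHandling's fixed and is not its implementation, and only NamedTuples of scalars are handled:

```julia
# Hypothetical wrapper marking a value as constant (in the spirit of
# ParameterHandling.fixed, not its implementation).
struct Fixed{T}
    value::T
end

# Flatten a NamedTuple of scalars, skipping Fixed entries; return the flat vector of
# free values and a function that rebuilds the NamedTuple from a new flat vector.
function flatten_skip_fixed(x::NamedTuple)
    free = Float64[float(v) for v in values(x) if !(v isa Fixed)]
    function rebuild(w::AbstractVector{<:Real})
        i = 0
        vals = Any[]
        for v in values(x)
            if v isa Fixed
                push!(vals, v.value)   # constants are restored unchanged
            else
                i += 1
                push!(vals, w[i])      # free values are taken from w in order
            end
        end
        return NamedTuple{keys(x)}(Tuple(vals))
    end
    return free, rebuild
end
```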
I think it would be nice to explore these options in ParameterHandling (especially since it's going to be more lightweight soon). If we add InverseFunctions.with_inverse (JuliaMath/InverseFunctions.jl#6 - @devmotion?), your Flattener concept could be both powerful and user-friendly, I think.
A central "flatten-only" package could come later then, maybe? As soon as #27 is sorted out, people will already be able to use ParameterHandling in ChangesOfVariables.test_with_logabsdet_jacobian (changed now), without ChangesOfVariables depending on ParameterHandling. And if we go for the with_inverse approach, I would simplify the rv_and_back = x -> (x, identity) option to transform = identity - much more elegant.
We actually already have a way to mark parameters as being constant (see fixed)
Neat!
One advantage of the InverseFunctions.with_inverse approach would be that Flattener() itself would just return the value, not a tuple of value and back-function. So it could support ChangesOfVariables.with_logabsdet_jacobian. While Flattener() would have a trivial volume element, it could be interesting for a ValueFlattener() as a successor of value_flatten.
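A minimal sketch of that separation of concerns. TupleFlattener is a hypothetical type that only handles flat tuples of reals, and with_inverse and with_logabsdet_jacobian are local stand-ins (the real with_logabsdet_jacobian lives in ChangesOfVariables and is not imported here; only its call shape, returning (f(x), ladj), is mirrored):

```julia
# Hypothetical callable flattener: calling it returns only the flat vector.
struct TupleFlattener end
(::TupleFlattener)(x::Tuple{Vararg{Real}}) = collect(Float64, x)

# Reconstruction lives in a separate (proposed) with_inverse method, because the
# inverse needs the original value to restore each element's type.
with_inverse(f::TupleFlattener, x::Tuple{Vararg{Real}}) =
    (f(x), v -> ntuple(i -> convert(typeof(x[i]), v[i]), length(x)))

# Stand-in with the same call shape as ChangesOfVariables.with_logabsdet_jacobian:
# flattening only relabels coordinates, so the log-abs-det-Jacobian is zero.
with_logabsdet_jacobian(f::TupleFlattener, x::Tuple{Vararg{Real}}) = (f(x), 0.0)
```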
@devmotion pointed me to https://github.com/JuliaDiff/FiniteDifferences.jl/blob/main/src/to_vec.jl
There are probably equivalents of this in several places in the ecosystem, internally. If we had a nice, lightweight central API for to-real-vec-and-back, people could make their types support it, which would make a lot of AD a bit easier. It's handling integers and constant parameters that makes it less straightforward than I initially thought, but I think it could (and maybe should) still be decoupled from actual value transformations.
CC @mzgubic @oxinabox
Just to add that in the ChainRules ecosystem there is a vague plan for moving away from to_vec.
I think there is another copy of this general idea somewhere in ArrayInterface.jl? cc @chrisrackauckas
@oschulz: There are 2 issues that come to my mind:
1) How do you handle the unflatten part? If AD is in mind, then the output type is determined by the input type, but for pretty much anything else, the return type should be given by the initial container type that was passed as an argument to flatten. We had a quick discussion here: https://github.com/invenia/ParameterHandling.jl/pull/39. I believe the main use case for ParameterHandling.jl is the flatten part, so in this case that would not be much of an issue.
2) How do you handle integers? Many people work with Integers/Vectors of Integers as parameter arguments, but at the moment they are not flattened. I think this is correct, because the unflatten part would cause all sorts of problems (Integers and Floats would probably be flattened to Floats), but I can see arguments against this as well.
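To make the first point concrete, here is a sketch of the two unflatten conventions for a plain array parameter. Both function names are hypothetical, and real implementations would also have to handle nested containers:

```julia
# AD-style: the element type follows the flat vector (so e.g. dual numbers pass through).
unflatten_like_input(x_orig::AbstractArray, v::AbstractVector) = reshape(v, size(x_orig))

# Container-style: the result is converted back to the original array type.
unflatten_like_original(x_orig::AbstractArray, v::AbstractVector) =
    convert(typeof(x_orig), reshape(v, size(x_orig)))
```

For a Float32 matrix and a Float64 flat vector, the first returns a Matrix{Float64} and the second a Matrix{Float32}; neither choice is wrong, they just serve different use cases.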
How do you handle the unflatten part? If AD is in mind, then the output type is determined
Hm, I guess in many cases, AD would just run on the flattened result - it would then (e.g.) be combined with the flat gradient and then reconstructed into the original type. So with ForwardDiff there wouldn't be a problem. Reverse mode though - good question.
How do you handle integers? Many people work with Integers/Vectors of Integers as parameter arguments
Yes, integers and values that are supposed to be constant in general. That's the main challenge, I think: AD use cases will typically want to assume that integers are constant, while other use cases may see this differently. ParameterHandling has fixed for explicit control.
I don't have ready answers, I have to admit. But the fact that we have this "flatten-to-real" in so many places seems like a good motivation to have a lightweight central package that people would be willing to depend on (separate from value domain transformations). Once we have figured out the answers to those questions, that is. :-)
For the interested, here's how Optimisers.jl handles reverse-mode-friendly (un)flattening: https://github.com/FluxML/Optimisers.jl/blob/master/src/destructure.jl. We opted for the conservative approach and only consider non-integer numeric arrays (due in part to Flux's design constraints). Something like fixed would be interesting, but the main challenge with such a wrapper is making it transparent to non-parameter-handling code.
For the interested, here's how Optimisers.jl handles reverse mode friendly (un)flattening
This is built on top of Functors.jl mainly, right?
Yes, Functors.jl is an integral part of it but there's no reason a similar set of functionality couldn't be developed for another parameter handling library :)
While revamping ForwardDiffPullbacks, I had to "invent" yet another flatten/unflatten mechanism. To get full performance and type stability (this needs to be really fast and allocation-free with deeply nested structures, flattening to static vectors), I ended up with:
flatten(x)
unflatten(x_orig, x_flat)
unflatten_tangent(x_orig, dx_flat)
An approach like x_flat, reconstruct_function = flatten(x) didn't work out performance/type-stability-wise; there was trouble when things were nested more deeply (though maybe I did it wrong; I also tried Flatten.jl, but that didn't work out either). Also, the x_flat, re = ... pattern didn't provide tangent-unflatten capability.
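The following sketch only illustrates the call shape of these three functions for nested tuples of reals. It is not the ForwardDiffPullbacks code: it returns a plain Vector instead of a static vector (so it is neither allocation-free nor tuned for type stability), and the shared helper _rebuild is hypothetical:

```julia
# Scalars occupy one slot; tuples (possibly nested) are flattened depth-first, left to right.
flatten(x::Real) = [x]
flatten(x::Tuple) = reduce(vcat, map(flatten, x); init = Float64[])

# Shared recursion: walk x, consume entries of v starting at index i, and return
# the rebuilt value together with the next unused index. `conv` decides leaf types.
_rebuild(conv, x::Real, v, i) = (conv(x, v[i]), i + 1)
_rebuild(conv, ::Tuple{}, v, i) = ((), i)
function _rebuild(conv, x::Tuple, v, i)
    head, i = _rebuild(conv, first(x), v, i)
    rest, i = _rebuild(conv, Base.tail(x), v, i)
    return ((head, rest...), i)
end

# Primal reconstruction: restore each leaf's original numeric type.
unflatten(x_orig, x_flat::AbstractVector) =
    first(_rebuild((x, y) -> convert(typeof(x), y), x_orig, x_flat, 1))

# Tangent reconstruction: keep the tangent's own element type.
unflatten_tangent(x_orig, dx_flat::AbstractVector) =
    first(_rebuild((x, dy) -> dy, x_orig, dx_flat, 1))
```

Since nothing is captured in a closure, each call only needs the original value (or its type information) and the flat vector, and the primal and tangent variants can differ in how leaves are reconstructed.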
The problem is that with these flatten/unflatten capabilities "hidden" in several packages in different ways, it's nearly impossible for users to specify flatten/unflatten for types that need special handling without creating a dependency nightmare. I wonder if we could come up with a generic lightweight thing similar to ChainRulesCore that would satisfy all use cases?
An approach like x_flat, reconstruct_function = flatten(x) didn't work out performance/type-stability wise
I guess you could add methods for the case where reconstruct_function is not needed.
I think I should link to the current discussion regarding ConstructionBase.getfields/ConstructionBase.getproperties (JuliaObjects/ConstructionBase.jl#54) here, as I feel these issues are very connected.
An approach like x_flat, reconstruct_function = flatten(x) didn't work out performance/type-stability wise, there was trouble when things were nested more deeply (though maybe I did it wrong, I also tried via Flatten.jl but that also didn't work out). Also the x_flat, re = ... pattern didn't provide tangent-unflatten capability.
The implementation isn't type-stable because of unrelated factors (mostly caching in Functors.jl), but Optimisers.destructure does basically this. Saving some auxiliary state during flattening pays off when the number of reconstructions exceeds the number of flattenings. What does not work well is returning a plain closure, but I'm preaching to the choir here :)
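A sketch of "saving auxiliary state" in a concrete callable struct instead of a plain closure. destructure_sketch and Restructure are hypothetical names, this is not how Optimisers.destructure is implemented, and only tuples of real arrays are handled:

```julia
# Auxiliary state recorded once at flattening time.
struct Restructure{S<:Tuple}
    sizes::S               # size of each original array
    offsets::Vector{Int}   # 0-based start of each array within the flat vector
end

function destructure_sketch(xs::Tuple{Vararg{AbstractArray{<:Real}}})
    flat = reduce(vcat, map(vec, xs); init = Float64[])
    offsets = Int[]
    total = 0
    for x in xs
        push!(offsets, total)
        total += length(x)
    end
    return flat, Restructure(map(size, xs), offsets)
end

# Each reconstruction reuses the stored sizes and offsets instead of re-walking the input.
function (re::Restructure)(flat::AbstractVector)
    return ntuple(length(re.sizes)) do i
        sz = re.sizes[i]
        off = re.offsets[i]
        reshape(flat[off+1:off+prod(sz)], sz)
    end
end
```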
I wonder if we could come up with a generic lightweight thing similar to ChainRulesCore that would satisfy all use cases?
I've thought about this and it's just genuinely hard because of how broad and nuanced "(un)flattening" is. Some examples to chew on:
ParameterHandling.fixed and Optimisers.trainable. The latter brings up another point: for many applications we cannot assume that all parameter-visiting traversals of an object tree(/DAG) will care about the same set of fields flattening does. This is the motivation behind StructWalk.jl's WalkStyle.
I agree, it's a tricky, multi-faceted problem. But I still think we have too many competing solutions in the ecosystem right now. Maybe one could somehow factorize this into an API for struct developers that allows them to "annotate" their structs so we need fewer heuristics/guesswork, and a set of flatten/unflatten APIs for engine/algorithm developers that make use of those "annotations"? I don't have a concrete proposal, I just feel that some ChainRulesCore-like (in spirit, not functionality) standard in this area is really missing in the ecosystem (or possibly the language itself) at the moment.
AFAIK something like https://github.com/rafaqz/FieldMetadata.jl could be that standard, but the note at the top of the README seems to suggest differently :/
@rafaqz can we pull you in here as well?
FieldMetadata.jl is a cool idea, but it sets metadata by defining new functions on an object. So you have method table state to think about if you ever want to change anything during use. It's also a lot of confusing and fragile macros for people to understand, which don't scale well to organisation/research-group-level use.
ModelParameters.jl is a better solution. You can do most of the same things but state is contained in the object, and has a Tables.jl interface.
What do you need from this package that you can't do with ModelParameters?
Additionally, truly lightweight recursive rebuilding is possible with my PR to Accessors.jl. And that's probably the most generic place to put it. It hasn't been merged because I'm too busy, and we can't get it type-stable because of issues with Base no longer inferring recursive methods.
What do you need from this package that you can't do with ModelParameters?
If you have a type that needs a bit of special handling, and you want it to be compatible with a wide variety of Julia ML, statistics & friends packages, and you want CPU + GPU support, you currently have to depend on and specialize functionality defined in Adapt, ConstructionBase, Functors, ParameterHandling, and possibly a few others. To my knowledge, none of those draw on each other for their default implementations. This doesn't compose well and forces packages to either take on a lot of dependencies or not be compatible with parts of the ecosystem. And it can result in a lot of boilerplate code.
As a type developer, I'll probably just want to implement a simple, closure-free API like
get_raw_contents(x::MyType)::Union{Real,Tuple,AbstractArray,NamedTuple}
get_semantic_contents(x::MyType)::Union{Real,Tuple,AbstractArray,NamedTuple}
reconstruct_from(::Type{<:MyType}, ::ResultOfGetRawContents)
reconstruct_from(::Type{<:MyType}, ::ResultOfGetSemanticContents)
The contract here would be
reconstruct_from(typeof(x), get_raw_contents(x)) == x
reconstruct_from(typeof(x), get_semantic_contents(x)) == x
reconstruct_from should accept contents with a different numeric precision and array type than the original x, of course, if supported by the type.
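A sketch of how a type developer might satisfy that contract. get_raw_contents, get_semantic_contents and reconstruct_from are the proposed functions (they do not exist in any package), and Scaled is a hypothetical type whose stored field (log_scale) differs from the property it exposes (scale):

```julia
# Hypothetical type: stores log(scale) internally but exposes `scale` as a property.
struct Scaled{T<:Real}
    value::T
    log_scale::T
end
Base.propertynames(::Scaled) = (:value, :scale)
Base.getproperty(x::Scaled, name::Symbol) =
    name === :scale ? exp(getfield(x, :log_scale)) : getfield(x, name)

# Raw contents: the stored fields. Semantic contents: the public properties.
get_raw_contents(x::Scaled) = (value = getfield(x, :value), log_scale = getfield(x, :log_scale))
get_semantic_contents(x::Scaled) = (value = x.value, scale = x.scale)

# Reconstruction from either kind of contents; `promote` lets the numeric precision
# differ from the original (e.g. Float32 contents give a Scaled{Float32}).
reconstruct_from(::Type{<:Scaled}, c::NamedTuple{(:value, :log_scale)}) =
    Scaled(promote(c.value, c.log_scale)...)
reconstruct_from(::Type{<:Scaled}, c::NamedTuple{(:value, :scale)}) =
    Scaled(promote(c.value, log(c.scale))...)
```

The raw round trip is exact; the semantic one goes through exp and log, so it only holds up to floating-point rounding.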
A package like ConstructionBase would seem a natural place to host such an API.
get_raw_contents and get_semantic_contents could return tuples, NamedTuples, Reals (even all Numbers?) and Arrays as-is, and return a NamedTuple for structs - get_raw_contents could use fieldcount/fieldname/getfield and get_semantic_contents could use propertynames/getproperty.
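Generic fallbacks along those lines could look like the following sketch. The function names are the proposed ones (not an existing ConstructionBase API), fieldnames is used as a shortcut for the fieldcount/fieldname loop, and Interval is a hypothetical example type whose properties differ from its fields:

```julia
# Containers and numbers that are returned as-is.
const PassThrough = Union{Number,Tuple,NamedTuple,AbstractArray}
get_raw_contents(x::PassThrough) = x
get_semantic_contents(x::PassThrough) = x

# Structs: raw contents come from the fields...
function get_raw_contents(x::T) where {T}
    names = fieldnames(T)
    return NamedTuple{names}(map(n -> getfield(x, n), names))
end

# ...and semantic contents from the public properties.
function get_semantic_contents(x)
    names = Tuple(propertynames(x))
    return NamedTuple{names}(map(n -> getproperty(x, n), names))
end

# Example type whose properties differ from its fields.
struct Interval
    lo::Float64
    hi::Float64
end
Base.propertynames(::Interval) = (:lo, :width)
Base.getproperty(i::Interval, n::Symbol) =
    n === :width ? getfield(i, :hi) - getfield(i, :lo) : getfield(i, n)
```

For a plain struct without custom properties, both functions return the same NamedTuple, since Base's default propertynames and getproperty fall back to the fields.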
A package like Adapt would probably want to use get_raw_contents for its default implementation, whereas packages like Functors and ParameterHandling would want get_semantic_contents, I guess.
Yes ConstructionBase.jl is the natural place for this, and sharing these base methods across more packages was the original reason for it to be written. Flatten.jl, Setfield.jl and I think BangBang.jl all needed it. Fixing any inadequacies it has to make it useful in these other packages is surely within scope.
In your schema I guess get_raw_contents == ConstructionBase.getfields (in the current PR by @jw3126) and get_semantic_contents == ConstructionBase.getproperties?
Also, Adapt.jl is already kind of redundant. You can replace it with Flatten.jl, or Accessors.jl with that PR. I do this in some GPU-based packages like DynamicGrids.jl, but the current type-stability problem in Base makes this less of an option than it used to be.
Also, Adapt.jl is already kind of redundant.
I don't disagree, but many packages use/support it, and AFAIK Adapt doesn't use Flatten, Accessors or any of the others in its default implementation. :-(
In your schema I guess get_raw_contents == ConstructionBase.getfields (in the current PR by @jw3126) and get_semantic_contents == ConstructionBase.getproperties?
Yes, that was my idea. If there's room for such a lightweight, closure-free API in ConstructionBase I'd be very happy to pitch in!
What do you need from this package that you can't do with ModelParameters?
My understanding of ModelParameters is that model struct types must be able to accept Params as fields. For philosophical and practical reasons, Flux cannot require model structs to use framework-defined types for the purpose of parameter tracking. Instead, this information is currently kept out-of-band via functions like functor or trainable.
Now this isn't necessarily set in stone, but I've yet to see a satisfactory solution to the rewrite-the-world-to-work-with-[param wrapper type(s)] problem.
I don't disagree, but many packages use/support it, and AFAIK Adapt doesn't use Flatten, Accessors or any of the others in its default implementation. :-(
In defense of Adapt, I think the type stability is more important than the potentially lost flexibility. There's also the deeper question of whether (un)flattening is the right implementation strategy for something like adapt_structure. In https://github.com/FluxML/Functors.jl/pull/27, I tried exploring what would happen if we used a more FP-inspired structural map (think Flatten.modify) as the core primitive. This turned out to have its own set of challenges, but the point is that both the problem and solution space are quite broad!
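A tiny sketch of a structural map as the core primitive, used for an adapt_structure-like job (converting every array to Float32). structmap is a hypothetical name in the spirit of Functors.fmap and Flatten.modify, not their code; shared nodes and caching are ignored, and custom structs are treated as opaque leaves:

```julia
# Apply f to every numeric array leaf and rebuild the surrounding containers unchanged.
structmap(f, x::AbstractArray{<:Number}) = f(x)
structmap(f, x::Union{Tuple,NamedTuple}) = map(xi -> structmap(f, xi), x)
structmap(f, x) = x  # other leaves (integers, strings, functions, structs, ...) pass through
```

Note that nothing is flattened here: the map rebuilds the structure directly, which is the alternative to an (un)flattening round trip that the comment above describes.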
@oschulz the closure-free API proposal above looks very interesting, one question for now: how would you foresee handling multiple different versions of get_semantic_contents for the same type?
A concrete motivating example may be found in a NN module system. Getting all parameters of a layer may be handled by get_raw_contents, but we'd also want get_trainable_contents for optimization and get_trivially_serializable_contents for saving layer state.
how would you foresee handling multiple different versions of get_semantic_contents for the same type?
IMO, multiple get_semantic_contents on the same type are not the job of ConstructionBase, analogous to Base not providing multiple variants of getproperty on a single type. Instead, I think lenses are a great abstraction for this, so you could use Accessors.jl or Setfield.jl.
@oschulz the closure-free API proposal above looks very interesting, one question for now: how would you foresee handling multiple different versions of get_semantic_contents for the same type?
Thanks!
Multiple different versions of get_semantic_contents for the same type in what respect? Could you give a quick example?
I put a relevant example in the comment above, but the gist is that not all traversals we make over a nested object DAG tree (I'm being cheeky here, but handling "shared" nodes is a big part of the Functors codebase) will want to access the same set of fields. Thus defining a single get_semantic_contents function has limited value for our use cases unless the definition of that function is sufficiently general. But then, if it's too general, there doesn't seem to be much of a difference between get_semantic_contents and get_raw_contents, AIUI.
This is why I mentioned StructWalk earlier: by parameterizing the (un)flatten function with the type of traversal being performed (WalkStyle), the traversal code can be decoupled from the code that extracts the contents of each node. Generic fallbacks mean that no convenience is lost for types that return the same contents regardless of the purpose of the traversal. Is it possible to do something similar with lenses?
but handling "shared" nodes is a big part of the Functors codebase
That would be part of the nested walk scheme relevant to the application though, right, and orthogonal to the "get content of this type" API?
by parameterizing the (un)flatten function with the type of traversal being performed
Ah, I think I get it, @ToucheSir. So you mean instead of having get_raw_contents(x) and get_semantic_contents(x), we'll need something like get_contents(x, decomposition_style), with a matching reconstruct_from(T, content, decomposition_style)?
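A sketch of that shape, assuming an abstract style hierarchy so that a generic fallback covers every style a type does not specialize. get_contents, the style types and the Dense-like layer are all hypothetical, and reconstruction is left out:

```julia
abstract type DecompositionStyle end
struct SemanticStyle <: DecompositionStyle end
struct TrainableStyle <: DecompositionStyle end

# Generic fallback: every style sees the semantic contents (property names and values).
function get_contents(x, ::DecompositionStyle)
    names = Tuple(propertynames(x))
    return NamedTuple{names}(map(n -> getproperty(x, n), names))
end

# Hypothetical Dense-like layer.
struct Dense{M,V,F}
    weight::M
    bias::V
    activation::F
end

# Only the trainable view needs a specialization: it leaves out the activation function.
get_contents(d::Dense, ::TrainableStyle) = (weight = d.weight, bias = d.bias)
```

A type developer then only specializes the styles where their type differs from the default. Reconstructing from the trainable view alone would still need the original value (for the activation function), which is where this meets the unflatten(x_orig, x_flat) signature discussed earlier.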
Do you think there would be a finite number of such "decomposition styles" and/or a kind of hierarchy between them so a type developer won't need to know (and specialize for) all of them?
With such a proposed closure-free API in ConstructionBase, would ParameterHandling adopt it? And what would be needed? (See the discussion at the end of ConstructionBase.jl#54.)
Do you think there would be a finite number of such "decomposition styles" and/or a kind of hierarchy between them so a type developer won't need to know (and specialize for) all of them?
Yes. The default could be identical to get_semantic_contents(x). I'll pick up the rest on the ConstructionBase thread.
In the spirit of creating lightweight interface-defining packages (see TuringLang/Bijectors.jl#199 which resulted in InverseFunctions.jl and ChangesOfVariables.jl):
While adding an interface-test utility to ChangesOfVariables.jl today, I needed a helper function _to_realvec_and_back, which does exactly what ParameterHandling.flatten does. And I needed it for exactly what #27 is requesting. :-)
It would be nice not to reinvent this, but ParameterHandling would of course be way too heavy a dependency for ChangesOfVariables. And even if it weren't, we'd end up with a circular dependency, since Bijectors will use ChangesOfVariables soon, so ParameterHandling will too.
I think the flatten/unflatten functionality of ParameterHandling does something very fundamental, and is orthogonal to its variable transformation capabilities. Would you consider splitting it out into a lightweight package (basically the contents of flatten.jl)?
A truly lightweight recursive flatten/unflatten interface package could IMHO find use in many places in the ecosystem. flatten may be a bit too generic a name for the function (we probably have several flattens in the ecosystem), but how about flatten_and_back or so?
CC @willtebbutt, @paschermayr, @devmotion
Update: ChangesOfVariables.test_with_logabsdet_jacobian now has an additional optional argument to pass a transformation, which solves the dependency problem. Users will be able to use ParameterHandling for variable transformations during the test without a direct dependency between the two packages (after #27 is solved).