JuliaGaussianProcesses / ParameterHandling.jl

Foundational tooling for handling collections of parameters in models
MIT License

Extracting flatten/unflatten into lightweight package? #43

Open oschulz opened 2 years ago

oschulz commented 2 years ago

In the spirit of creating lightweight interface-defining packages (see TuringLang/Bijectors.jl#199 which resulted in InverseFunctions.jl and ChangesOfVariables.jl):

While adding an interface-test utility to ChangesOfVariables.jl today, I needed a helper function _to_realvec_and_back - which does exactly what ParameterHandling.flatten does. And I needed it for exactly what #27 is requesting. :-)

It would be nice not to reinvent this, but ParameterHandling would of course be way too heavy a dependency for ChangesOfVariables - and even if it wasn't, we'd end up with a circular dependency since Bijectors will use ChangesOfVariables soon, so ParameterHandling will too.

I think the flatten/unflatten functionality of ParameterHandling does something very fundamental, and is orthogonal to its variable transformation capabilities. Would you consider splitting it out into a lightweight package (basically the contents of flatten.jl)?

A truly lightweight recursive flatten/unflatten interface package could IMHO find use in many places in the ecosystem. flatten may be a bit too generic a name for the function (we probably have several flattens in the ecosystem), but how about flatten_and_back or so?
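To make the idea concrete, here is a minimal sketch of what a recursive flatten-and-back could look like. It is not the contents of flatten.jl: the name flatten_and_back is only the one floated above, the sketch covers floats, float arrays, tuples and NamedTuples only, and it always flattens to a Vector{Float64}.

```julia
# Hypothetical sketch: `flatten_and_back` is only the name floated above.

# A float becomes a one-element vector.
function flatten_and_back(x::AbstractFloat)
    back(v) = convert(typeof(x), only(v))
    return Float64[x], back
end

# An array of floats is flattened in column-major order; `back` restores the shape.
function flatten_and_back(x::AbstractArray{<:AbstractFloat})
    dims = size(x)
    back(v) = reshape(collect(v), dims)
    return Float64.(vec(x)), back
end

# A tuple is flattened element by element and the pieces are concatenated.
function flatten_and_back(x::Tuple)
    parts = map(flatten_and_back, x)              # tuple of (vector, back) pairs
    backs = map(last, parts)
    lens = Int[length(first(p)) for p in parts]
    stops = cumsum(lens)
    starts = stops .- lens .+ 1
    flat = reduce(vcat, map(first, parts); init=Float64[])
    back(v) = ntuple(i -> backs[i](v[starts[i]:stops[i]]), length(parts))
    return flat, back
end

# A NamedTuple reuses the tuple method and restores the field names afterwards.
function flatten_and_back(x::NamedTuple{names}) where {names}
    flat, back_tuple = flatten_and_back(values(x))
    back(v) = NamedTuple{names}(back_tuple(v))
    return flat, back
end
```

Calling flatten_and_back((a = 1.0, b = [2.0, 3.0])) returns a flat vector together with a function that rebuilds a NamedTuple of the same shape from any flat vector of matching length.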

CC @willtebbutt, @paschermayr, @devmotion

Update: ChangesOfVariables.test_with_logabsdet_jacobian now has an additional optional argument to pass a transformation, which solves the dependency problem. Users will be able to use ParameterHandling for variable transformations during the test without a direct dependency between the two packages (after #27 is solved).

willtebbutt commented 2 years ago

How would you feel about depending on ParameterHandling once #42 goes in?

oschulz commented 2 years ago

How would you feel about depending on ParameterHandling once #42 goes in?

I think even with #42 it would still be too heavy - the idea is to keep ChangesOfVariables super-lightweight, so we can convince packages to depend on it, and the stuff in parameters.jl is really a different thing - here, I think, ParameterHandling would depend on ChangesOfVariables, not the other way round. And ParameterHandling is a more opinionated package than ChangesOfVariables.

But now that I think about it, since it looks like we need flatten/unflatten functionality in ChangesOfVariables, how would you like moving that there? It's not much more than I have in the draft of JuliaMath/ChangesOfVariables.jl#2 anyway, and I think it would fit into the theme of the package - it is a change of variables, after all (with a logabsdet-jacobian of 0, of course).

@willtebbutt and @devmotion, what do you think?

Update: ParameterHandling will likely depend on ChangesOfVariables indirectly in the future due to #42, since we hope to get support for ChangesOfVariables into LogExpFunctions (esp for logistic, logit etc.). So ChangesOfVariables must not depend on ParameterHandling.

devmotion commented 2 years ago

I'm not sure, it's not clear to me that this functionality should be included in https://github.com/JuliaMath/ChangesOfVariables.jl/pull/2. I would prefer if such wrapping/reshaping/flattening etc. is handled by the user.

oschulz commented 2 years ago

Yes, the more I think about it, it's really a bit of a different thing. And we could use test_with_logabsdet_jacobian(f, x, torealvec_and_back, getjacobian), that would keep the whole thing out of the dependencies of ChangesOfVariables.

oschulz commented 2 years ago

I still think it's a very fundamental functionality that may deserve a central package independent from variable-domain transformations.

I wonder if we should add an interface InverseFunctions.with_inverse(f, x) with a default implementation of InverseFunctions.with_inverse(f, x) = (f(x), inverse(f)), to support use cases like flatten/unflatten that can't do inverse(f) without a value (they would specialize with_inverse)?

That would form a nice basis for a package that depends on InverseFunctions and defines a function realvec and with_inverse(::typeof(realvec), x). There are definitely use cases where one would only want the forward function and instantiating the back function would be wasteful (mem allocs). This way, we'd have an API that can do flatten and flatten-and-back.
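A rough sketch of how the with_inverse idea might fit together, under stated assumptions: inverse below is a local stand-in for InverseFunctions.inverse so the snippet runs without packages, and with_inverse and realvec are only the names proposed in this comment, not existing functions.

```julia
# Local stand-in for InverseFunctions.inverse, defined for two functions only.
inverse(::typeof(exp)) = log
inverse(::typeof(log)) = exp

# Proposed default: fine whenever inverse(f) is known without looking at a value.
with_inverse(f, x) = (f(x), inverse(f))

# Forward-only flattening of an array; no back function is created.
realvec(x::AbstractArray{<:Real}) = vec(float.(x))

# The inverse of realvec needs the shape of x, so with_inverse is specialized.
function with_inverse(::typeof(realvec), x::AbstractArray{<:Real})
    dims = size(x)
    return realvec(x), v -> reshape(v, dims)
end
```

Code that only needs the flat vector calls realvec(x) directly, while code that needs the round trip calls with_inverse(realvec, x).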

willtebbutt commented 2 years ago

In principle I could get on board with this kind of thing.

My main concern is that if we choose to make this functionality more widely available it'll be difficult to get consensus on some semantics. For example, flatten(1) returns an empty vector and a closure. The rationale being that this is what you want if you're doing AD and 1 is a primal (as opposed to a tangent vector), because integers aren't differentiable. I can well imagine that someone somewhere would (quite reasonably) find this choice objectionable, and want different behaviour.

So I would want to generalise what we have here a little bit to allow different semantics depending on your situation. Maybe utilising a trait-based mechanism like the one used for RuleConfigs in ChainRulesCore?

So you would write something like

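struct Flattener{T} end

struct DropIntegers end  # placeholder name for the current default semantics
struct KeepIntegers end

# These have different behaviour.
with_inverse(Flattener{DropIntegers}(), 1)  # flat vector is empty, as with flatten(1) today
with_inverse(Flattener{KeepIntegers}(), 1)  # flat vector is [1.0]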

in case you want to work around this stuff?
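One way the trait idea above could be made runnable is sketched here. This is only an illustration: DropIntegers is a hypothetical name for the current default semantics, with_inverse is the function proposed earlier in the thread, and rounding on the way back under KeepIntegers is an assumption.

```julia
struct DropIntegers end   # hypothetical trait: integers are not parameters
struct KeepIntegers end   # trait named in the comment above

struct Flattener{T} end

# Integers contribute nothing; the back function returns the original value.
with_inverse(::Flattener{DropIntegers}, x::Integer) = (Float64[], _ -> x)

# Integers are converted to floats; rounding on the way back is an assumption.
with_inverse(::Flattener{KeepIntegers}, x::Integer) =
    ([float(x)], v -> round(typeof(x), only(v)))

# Floats are handled the same way under either trait.
with_inverse(::Flattener, x::AbstractFloat) = ([x], v -> oftype(x, only(v)))
```

Only the integer methods differ between the two traits, so a package that is happy with the default never has to mention KeepIntegers.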

oschulz commented 2 years ago

Yes, thought some more about this today and also started to worry about integers. With the autodiff use case, as an extension to excluding integers, people may also want to exclude other parameters from the gradient calculation (if they are not relevant, to increase performance), so one might want to have a way to mark parameters as active/constant. That could quickly go beyond "non-opinionated central API" indeed ...

willtebbutt commented 2 years ago

so one might want to have a way to mark parameters as active/constant.

We actually already have a way to mark parameters as being constant (see fixed), although doing things the other way around (marking things as "active") is something that we don't currently have support for -- although we could almost certainly achieve using a trait-based system.
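For readers who have not seen fixed, the idea can be illustrated with a stand-in wrapper. This is not the package's implementation: Fixed and flatten below are local definitions written for this sketch, restricted to Vector{Float64} leaves inside a NamedTuple.

```julia
# Stand-in wrapper for this sketch; ParameterHandling's own `fixed` is not used.
struct Fixed{T}
    value::T
end

# Active parameters contribute their entries to the flat vector.
flatten(x::Vector{Float64}) = (copy(x), v -> collect(v))

# Fixed parameters contribute nothing and are restored unchanged.
flatten(x::Fixed) = (Float64[], _ -> x)

function flatten(x::NamedTuple{names}) where {names}
    parts = map(flatten, values(x))               # tuple of (vector, back) pairs
    lens = Int[length(first(p)) for p in parts]
    stops = cumsum(lens)
    flat = reduce(vcat, map(first, parts); init=Float64[])
    function back(v)
        pieces = ntuple(i -> last(parts[i])(v[stops[i]-lens[i]+1:stops[i]]), length(parts))
        return NamedTuple{names}(pieces)
    end
    return flat, back
end
```

An optimiser or AD tool that works on the flat vector never sees the wrapped entry, yet it reappears untouched after unflattening.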

oschulz commented 2 years ago

I think it would be nice to explore these options in ParameterHandling (especially since it's going to be more lightweight soon). If we add InverseFunctions.with_inverse (JuliaMath/InverseFunctions.jl#6 - @devmotion?), your Flattener concept could be both powerful and user-friendly, I think.

A central "flatten-only" package could come later then, maybe? As soon as #27 is sorted out, people will already be able to use ParameterHandling in ChangesOfVariables.test_with_logabsdet_jacobian (changed now), without ChangesOfVariables depending on ParameterHandling. And if we go for the with_inverse approach, I would simplify the rv_and_back = x -> (x, identity) option to transform = identity - much more elegant.

oschulz commented 2 years ago

We actually already have a way to mark parameters as being constant (see fixed)

Neat!

One advantage with the InverseFunctions.with_inverse approach would be that Flattener() itself would just return the value, not a tuple of value and back-function. So it could support ChangesOfVariables.with_logabsdet_jacobian. While Flattener() would have a trivial volume-element, it could be interesting for ValueFlattener() as a successor of value_flatten.

oschulz commented 2 years ago

@devmotion pointed me to https://github.com/JuliaDiff/FiniteDifferences.jl/blob/main/src/to_vec.jl

There are probably equivalents of this in several places in the ecosystem, internally. If we had a nice, lightweight central API for a to-real-vec-and-back, people could make their types support it, and that would make a lot of AD a bit easier. It's handling integers and constant parameters though that makes it less straightforward than I thought initially - but I think it could (and maybe should) still be decoupled from actual value transformations.

CC @mzgubic @oxinabox

mzgubic commented 2 years ago

Just to add that in the ChainRules ecosystem there is a vague plan for moving away from to_vec.

oxinabox commented 2 years ago

I think there is another copy of this general idea somewhere in ArrayInterface.jl? cc @chrisrackauckas

paschermayr commented 2 years ago

@oschulz : There are 2 issues that come to my mind:

1) How do you handle the unflatten part? If AD is in mind, then the output type is determined by the input type, but for pretty much anything else, the return type should be given by the initial container type that was an argument in flatten. We had a quick discussion here: https://github.com/invenia/ParameterHandling.jl/pull/39 . I believe the main use case for ParameterHandling.jl is the flatten part, so in this case that would be not much of an issue.

2) How do you handle integers? Many people work with Integers/Vectors of Integers as parameter arguments, but at the moment they are not flattened. I think this is correct, because the unflatten part would cause all sorts of problems (Integers and Floats would probably be flattened to Floats), but I can see arguments against this case as well.

oschulz commented 2 years ago

How do you handle the unflatten part? If AD is in mind, then the output type is determined

Hm, I guess in many cases, AD would just run on the flattened result - it would then (e.g.) be combined with the flat-gradient and then reconstructed into the original type. So with ForwardDiff there wouldn't be a problem. Reverse mode though - good question.

How do you handle integers? Many people work with Integers/Vectors of Integers as parameter arguments

Yes, integers and values that are supposed to be constant in general. That's the main challenge, I think - AD use-cases will typically want to assume that integers be constant, other use cases may see this differently. ParameterHandling has fixed for explicit control.

I don't have ready answers, I have to admit. But the fact that we have this "flatten-to-real" in so many places seems like a good motivation to have a lightweight central package that people would be willing to depend on (separate from value domain transformations). Once we have figured out the answers to those questions, that is. :-)

ToucheSir commented 2 years ago

For the interested, here's how Optimisers.jl handles reverse mode friendly (un)flattening: https://github.com/FluxML/Optimisers.jl/blob/master/src/destructure.jl. We opted for the conservative approach and only consider non-integer numeric arrays (due in part to Flux's design constraints). Something like fixed would be interesting, but the main challenge with such a wrapper is making it transparent to non-parameter handling code.

oschulz commented 2 years ago

For the interested, here's how Optimisers.jl handles reverse mode friendly (un)flattening

This is built on top of Functors.jl mainly, right?

ToucheSir commented 2 years ago

Yes, Functors.jl is an integral part of it but there's no reason a similar set of functionality couldn't be developed for another parameter handling library :)

oschulz commented 2 years ago

While revamping ForwardDiffPullbacks, I had to "invent" yet another flatten/unflatten mechanism. To get full performance and type stability (this needs to be really fast and allocation-free with deeply nested structures, flattening to static vectors), I ended up with

An approach like x_flat, reconstruct_function = flatten(x) didn't work out performance/type-stability wise, there was trouble when things were nested more deeply (though maybe I did it wrong, I also tried via Flatten.jl but that also didn't work out). Also the x_flat, re = ... pattern didn't provide tangent-unflatten capability.

The problem is that with these flatten/unflatten capabilities "hidden" in several packages in different ways, it's near impossible for users to specify flatten/unflatten for types that need special handling without creating a dependency nightmare. I wonder if we could come up with a generic lightweight thing similar to ChainRulesCore that would satisfy all use cases?

paschermayr commented 2 years ago

An approach like x_flat, reconstruct_function = flatten(x) didn't work out performance/type-stability wise

I guess you could make methods in case reconstruct_function is not needed.

oschulz commented 2 years ago

I think I should link to the current discussion regarding ConstructionBase.getfields/ConstructionBase.getproperties (JuliaObjects/ConstructionBase.jl#54) here, I feel these issues are very connected.

ToucheSir commented 2 years ago

An approach like x_flat, reconstruct_function = flatten(x) didn't work out performance/type-stability wise, there was trouble when things were nested more deeply (though maybe I did it wrong, I also tried via Flatten.jl but that also didn't work out). Also the x_flat, re = ... pattern didn't provide tangent-unflatten capability.

The implementation isn't type stable because of unrelated factors (mostly caching in Functors.jl), but Optimisers.destructure does basically this. Saving some auxiliary state during flattening pays off when # reconstructions > # of flattenings. What does not work well is returning a plain closure, but I'm preaching to the choir here :)
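As an illustration of "saving auxiliary state" instead of returning a plain closure, here is a small sketch. It is not how Optimisers.destructure is implemented: Restructure and flatten_arrays are hypothetical names, and the sketch only handles a flat tuple of Float64 arrays.

```julia
# Hypothetical type: holds the state recorded at flattening time.
struct Restructure{S<:Tuple}
    sizes::S   # one size tuple per array
end

# Flatten a tuple of arrays and return a reusable reconstruction object.
function flatten_arrays(xs::Tuple{Vararg{AbstractArray{Float64}}})
    flat = reduce(vcat, map(vec, xs); init=Float64[])
    return flat, Restructure(map(size, xs))
end

# Calling the object rebuilds arrays of the recorded shapes from any flat vector
# of matching length, for example a parameter vector or a flat gradient.
function (re::Restructure)(v::AbstractVector)
    lens = Int[prod(s) for s in re.sizes]
    stops = cumsum(lens)
    pieces = ntuple(i -> reshape(v[stops[i]-lens[i]+1:stops[i]], re.sizes[i]), length(lens))
    return pieces
end
```

Because the state lives in a concrete struct rather than in captured variables, other code can dispatch on it and reuse it for many reconstructions after one flattening.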

I wonder if we could come up with a generic lightweight thing similar to ChainRulesCore that would satisfy all use cases?

I've thought about this and it's just genuinely hard because of how broad and nuanced "(un)flattening" is. Some examples to chew on:

oschulz commented 2 years ago

I agree, it's a tricky, multi-faceted problem. But I still think we have too many competing solutions in the ecosystem right now. Maybe one could somehow factorize this into an API for struct-developers that allows them to "annotate" their structs so we need fewer heuristics and less guesswork, and a set of flatten/unflatten APIs for engine/algorithms-developers that make use of those "annotations"? I don't have a concrete proposal, I just feel that some ChainRulesCore-like (in spirit, not functionality) standard in this area is really missing in the ecosystem (or possibly the language itself) at the moment.

ToucheSir commented 2 years ago

AFAIK something like https://github.com/rafaqz/FieldMetadata.jl could be that standard, but the note at the top of the README seems to suggest differently :/

oschulz commented 2 years ago

@rafaqz can we pull you in here as well?

rafaqz commented 2 years ago

FieldMetadata.jl is a cool idea, but it sets metadata by defining new functions on an object. So you have method table state to think about if you ever want to change anything during use. It's also a lot of confusing and fragile macros for people to understand, which don't scale to organisation/research group level use very well.

ModelParameters.jl is a better solution. You can do most of the same things but state is contained in the object, and has a Tables.jl interface.

What do you need from this package that you can't do with ModelParameters?

rafaqz commented 2 years ago

Additionally, truly lightweight recursive rebuilding is possible with my PR to Accessors.jl. And that's probably the most generic place to put it. It hasn't been merged because I'm too busy and we can't get it type stable because of issues with Base no longer inferring recursive methods.

oschulz commented 2 years ago

What do you need from this package that you can't do with ModelParameters?

If you have a type that needs a bit of special handling, and you want it to be compatible with a wide variety of Julia ML, statistics & friends packages, and you want CPU + GPU support, you currently have to depend on and specialize functionality defined in: Adapt, ConstructionBase, Functors, ParameterHandling, and possibly a few others. To my knowledge, none of those draw on each other for their default implementations. This doesn't compose well and forces packages to either take on a lot of dependencies or not be compatible with parts of the ecosystem. And it can result in a lot of boilerplate code.

As a type developer, I'll probably just want to implement a simple, closure-free API like

get_raw_contents(x::MyType)::Union{Real,Tuple,AbstractArray,NamedTuple}
get_semantic_contents(x::MyType)::Union{Real,Tuple,AbstractArray,NamedTuple}
reconstruct_from(::Type{<:MyType}, ::ResultOfGetRawContents)
reconstruct_from(::Type{<:MyType}, ::ResultOfGetSemanticContents)

The contract here would be

reconstruct_from(typeof(x), get_raw_contents(x)) == x
reconstruct_from(typeof(x), get_semantic_contents(x)) == x

reconstruct_from should accept contents with a different numeric precision and array types as the original x, of course, if supported by the type.

A package like ConstructionBase would seem a natural place to host such an API.

get_raw_contents and get_semantic_contents could return tuples, NamedTuples, Reals (even all Numbers?) and Arrays as-is, and return a NamedTuple for structs - get_raw_contents could use fieldcount/fieldname/getfield and get_semantic_contents could use propertynames/getproperty.

A package like Adapt would probably want to use get_raw_contents for its default implementation, whereas packages like Functors and ParameterHandling would want get_semantic_contents I guess.
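A minimal sketch of how the raw half of this proposal might look for one struct. The function names are the ones proposed above and do not exist in ConstructionBase; Affine is a made-up example type, and the field-based fallback is just one possible default.

```julia
# Made-up example type.
struct Affine{T<:Real}
    slope::T
    offset::Vector{T}
end

# Possible default for structs: a NamedTuple of the fields, in declaration order.
function get_raw_contents(x::T) where {T}
    names = fieldnames(T)
    return NamedTuple{names}(map(n -> getfield(x, n), names))
end

# Written by the type developer. Dispatching on Type{<:Affine} instead of the
# concrete type lets contents with another numeric precision through.
reconstruct_from(::Type{<:Affine}, c::NamedTuple) = Affine(c.slope, c.offset)
```

No closure is involved: the contents are plain data, and reconstruction is an ordinary method that a package such as Adapt could call after converting the arrays.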

rafaqz commented 2 years ago

Yes ConstructionBase.jl is the natural place for this, and sharing these base methods across more packages was the original reason for it to be written. Flatten.jl, Setfield.jl and I think BangBang.jl all needed it. Fixing any inadequacies it has to make it useful in these other packages is surely within scope.

In your schema I guess get_raw_contents == ConstructionBase.getfields (in the current PR by @jw3126) and get_semantic_contents == ConstructionBase.getproperties?

Also, Adapt.jl is already kind of redundant. You can replace it with Flatten.jl or Accessors.jl with that PR. I do this in some GPU based packages like DynamicGrids.jl, but the current problem with type stability in Base makes this less of an option than it used to be.

oschulz commented 2 years ago

Also, Adapt.jl is already kind of redundant.

I don't disagree, but many packages use/support it and AFAIK Adapt doesn't use Flatten, Accessors or any other in its default implementation. :-(

oschulz commented 2 years ago

In your schema I guess get_raw_contents == ConstructionBase.getfields (in the current PR by @jw3126) and get_semantic_contents == ConstructionBase.getproperties?

Yes, that was my idea. If there's room for such a lightweight, closure-free API in ConstructionBase I'd be very happy to pitch in!

ToucheSir commented 2 years ago

What do you need from this package that you can't do with ModelParameters?

My understanding of ModelParameters is that model struct types must be able to accept Params as fields. For philosophical and practical reasons, Flux can not require model structs to use framework-defined types for the purpose of parameter tracking. Instead, this information is currently kept out-of-band via functions like functor or trainable.

Now this isn't necessarily set in stone, but I've yet to see a satisfactory solution to the rewrite-the-world-to-work-with-[param wrapper type(s)] problem.


I don't disagree, but many packages use/support it and AFAIK Adapt doesn't use Flatten, Accessors or any other in its default implementation. :-(

In defense of Adapt, I think the type stability is more important than the potentially lost flexibility. There's also the deeper question of whether (un)flattening is the right implementation strategy for something like adapt_structure. In https://github.com/FluxML/Functors.jl/pull/27, I tried exploring what would happen if we used a more FP-inspired structural map (think Flatten.modify) as the core primitive. This turned out to have its own set of challenges, but the point is that both the problem and solution space are quite broad!


@oschulz the closure-free API proposal above looks very interesting, one question for now: how would you foresee handling multiple different versions of get_semantic_contents for the same type?

A concrete motivating example may be found in a NN module system. Getting all parameters of a layer may be handled by get_raw_contents, but we'd also want get_trainable_contents for optimization and get_trivially_serializable_contents for saving layer state.

jw3126 commented 2 years ago

how would you foresee handling multiple different versions of get_semantic_contents for the same type?

IMO multiple get_semantic_contents on the same type are not the job of ConstructionBase. Analogous to Base not providing multiple variants of getproperty on a single type. Instead I think lenses are a great abstraction for this, so you could use Accessors.jl or Setfield.jl

oschulz commented 2 years ago

@oschulz the closure-free API proposal above looks very interesting, one question for now: how would you foresee handling multiple different versions of get_semantic_contents for the same type?

Thanks!

Multiple different versions of get_semantic_contents for the same type in what respect? Could you give a quick example?

ToucheSir commented 2 years ago

I put a relevant example in the comment above, but the gist is that not all traversals we make over a nested object DAG tree (I'm being cheeky here, but handling "shared" nodes is a big part of the Functors codebase) will want to access the same set of fields. Thus defining a single get_semantic_contents function has limited value for our use cases unless the definition of that function is sufficiently general. But then if it's too general, there doesn't seem to be much of a difference between get_semantic_contents and get_raw_contents, AIUI.

This is why I mentioned StructWalk earlier: by parameterizing the (un)flatten function with the type of traversal being performed (WalkStyle), the traversal code can be decoupled from the code that extracts the contents of each node. Generic fallbacks mean that no convenience is lost for types that return the same contents regardless of the purpose of the traversal. Is it possible to do something similar with lenses?

oschulz commented 2 years ago

but handling "shared" nodes is a big part of the Functors codebase

That would be part of the nested walk scheme relevant to the application though, right, and orthogonal to the "get content of this type" API?

by parameterizing the (un)flatten function with the type of traversal being performed

Ah, I think I get it @ToucheSir . So you mean instead of having get_raw_contents(x) and get_semantic_contents(x) we'll need something like get_contents(x, decomposition_style), with a matching reconstruct_from(T, content, decomposition_style)?

Do you think there would be a finite number of such "decomposition styles" and/or a kind of hierarchy between them so a type developer won't need to know (and specialize for) all of them?
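To show what a style-parameterized get_contents with a generic fallback might look like, here is a sketch under stated assumptions: all style names and the Dense type are hypothetical, and only the content-extraction side is covered.

```julia
# Hypothetical styles; none of these exist in ConstructionBase or Functors.
abstract type DecompositionStyle end
struct RawStyle <: DecompositionStyle end
struct SemanticStyle <: DecompositionStyle end
struct TrainableStyle <: DecompositionStyle end

# Made-up layer type with one non-numeric field.
struct Dense
    weight::Matrix{Float64}
    bias::Vector{Float64}
    activation::Symbol
end

# Fallback: every style sees all fields unless a type says otherwise.
function get_contents(x::T, ::DecompositionStyle) where {T}
    names = fieldnames(T)
    return NamedTuple{names}(map(n -> getfield(x, n), names))
end

# The type developer specializes only the style that needs different contents.
get_contents(x::Dense, ::TrainableStyle) = (weight = x.weight, bias = x.bias)
```

With this layout a type developer who does nothing still works with every style. One open point the sketch does not address is that a matching reconstruct_from(T, content, style) for a partial style such as TrainableStyle would need the omitted fields from somewhere.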

oschulz commented 2 years ago

With such a proposed closure-free API in ConstructionBase, would ParameterHandling adopt it? And what would be needed (see discussion at end of ConstructionBase.jl#54)?

ToucheSir commented 2 years ago

Do you think there would be a finite number of such "decomposition styles" and/or a kind of hierarchy between them so a type developer won't need to know (and specialize for) all of them?

Yes. The default could be identical to get_semantic_contents(x). I'll pick up the rest on the ConstructionBase thread.