JuliaGaussianProcesses / ParameterHandling.jl

Foundational tooling for handling collections of parameters in models
MIT License

Overloading-AD-Friendly Unflatten #39

Open willtebbutt opened 3 years ago

willtebbutt commented 3 years ago

Addresses #27. @paschermayr could you confirm that it resolves your problem? I've run your example locally, but want to make sure that it does what you expect.

This is breaking. @rofinn can you see any problem with doing this? It specifically changes code that you wrote. The type constraints are now only applied in the flatten step -- in unflatten it's assumed that you could reasonably want to use numbers whose type differs from the exact type requested in flatten, e.g. so that you can propagate a Vector{<:Dual} through unflatten when using ForwardDiff.jl.
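To illustrate the intent with a toy sketch (this is not the actual diff -- `toy_flatten` and the NamedTuple layout here are made up for illustration):

```julia
# Toy sketch of the relaxed contract: the eltype constraint lives in the
# flatten step only; the returned closure accepts any Real eltype, so a
# Vector{<:Dual} produced by ForwardDiff.jl can pass through unflatten.
function toy_flatten(x::NamedTuple{(:a, :b)})
    v = Float64[x.a; x.b]  # types constrained here, at flatten time
    unflatten(vnew::AbstractVector{<:Real}) = (a = vnew[1], b = vnew[2:end])
    return v, unflatten
end

v, unflatten = toy_flatten((a = 1.0, b = [2.0, 3.0]))
nt = unflatten(Float32[1, 2, 3])  # a looser eltype is accepted in unflatten
```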

codecov[bot] commented 3 years ago

Codecov Report

Merging #39 (b923d81) into master (21e6ff7) will increase coverage by 0.08%. The diff coverage is 100.00%.


@@            Coverage Diff             @@
##           master      #39      +/-   ##
==========================================
+ Coverage   96.49%   96.57%   +0.08%     
==========================================
  Files           4        4              
  Lines         171      175       +4     
==========================================
+ Hits          165      169       +4     
  Misses          6        6              
| Impacted Files | Coverage Δ | |
|---|---|---|
| src/flatten.jl | 98.14% <100.00%> | (+0.07%) :arrow_up: |
| src/parameters.jl | 97.50% <100.00%> | (+0.06%) :arrow_up: |
| src/test_utils.jl | 92.50% <100.00%> | (ø) |

Continue to review full report at Codecov.


paschermayr commented 3 years ago

Worked for me - I just added a few comments but none of them should influence the merge. Thank you for your work!

rofinn commented 3 years ago

My 2 cents:

  1. I agree that these type constraints, which were added to work around hard-to-resolve type instabilities in Zygote, were perhaps a bit too restrictive.
  2. I haven't looked at this code in almost a year, but the general idea of loosening the types to allow things like duals to work with unflatten seems reasonable.
  3. Would it be possible to introduce a strict=false keyword that certain types of performance-sensitive code could use to enforce symmetry/stability in both flatten and unflatten? I worry that the performance benefits we saw depended on that, though I really don't remember anymore. I'd also be fine with it not being the default, since reducing the precision is a bit of a niche use-case.
  4. Would you mind including a few benchmarks in the case where you want to use reduced precision?
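One possible shape for that keyword, purely as a sketch -- `toy_flatten`, the field layout, and the exact `strict` semantics here are all hypothetical, not part of the PR:

```julia
# Hypothetical sketch: `strict=true` returns an unflatten that demands the
# exact flattened eltype (so views can be used with no copies), while the
# default `strict=false` is the loose, AD-friendly version.
function toy_flatten(x::NamedTuple{(:a, :b)}; strict::Bool=false)
    v = Float64[x.a; x.b]
    # strict: exact eltype in and out, enabling a view into the input vector
    unflatten_strict(vnew::Vector{Float64}) =
        (a = vnew[1], b = view(vnew, 2:lastindex(vnew)))
    # loose: any Real eltype, so e.g. ForwardDiff Duals can flow through
    unflatten_loose(vnew::AbstractVector{<:Real}) = (a = vnew[1], b = vnew[2:end])
    return v, strict ? unflatten_strict : unflatten_loose
end

v, unflat = toy_flatten((a = 1.0, b = [2.0, 3.0]); strict = true)
nt = unflat(v)  # nt.b is a view into v: no fresh vector allocated
```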
paschermayr commented 3 years ago

> Would it be possible to introduce a strict=false keyword that certain types of performance-sensitive code could use to enforce symmetry/stability in both flatten and unflatten? I worry that the performance benefits we saw depended on that, though I really don't remember anymore. I'd also be fine with it not being the default, since reducing the precision is a bit of a niche use-case.

Having strict=true for unflatten in the non-AD case might be a good idea; we could use views in that case. One would have to replace map(flatten, x) with a map(x) do x; flatten(x, strict); end block to get the same performance as before.
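Concretely, the do-block adjustment would look roughly like this (a sketch: `toy_flatten` is a made-up stand-in, since ParameterHandling's flatten does not currently take a strict argument):

```julia
# Hypothetical flatten taking a strict flag; strict picks an unflatten
# that insists on the exact eltype instead of any Real.
function toy_flatten(x::Float64; strict::Bool=false)
    unflat_strict(v::Vector{Float64}) = v[1]
    unflat_loose(v::AbstractVector{<:Real}) = v[1]
    return [x], (strict ? unflat_strict : unflat_loose)
end

xs = [1.0, 2.0, 3.0]
strict = true
# the do-block form lets the flag be threaded through,
# unlike the plain map(toy_flatten, xs)
results = map(xs) do x
    toy_flatten(x; strict = strict)
end
```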

paschermayr commented 3 years ago

@willtebbutt : I think I managed to implement a method that allows us to keep initial types and lets us work with AD. I uploaded a version here: https://github.com/paschermayr/Shared-Code/blob/master/parameterhandling.jl

I am not sure whether this is ideal for ParameterHandling.jl, as it is optimized for the unflatten part (buffers for unflatten are created while flattening), but I think you could adjust this easily. I haven't tested it for all the different parameter types in ParameterHandling.jl. Example:

using BenchmarkTools
nt = (a = 1, b = [2, 3], c = Float32(4.), d = 5.)
typeof(nt.c) #Float32
nt_vec, unflat = flatten(Float16, true, nt) #Vector{Float16} with 2 elements, unflatten_to_NamedTuple
nt2 = unflat(nt_vec)
typeof(nt2.c) #Float32
@btime $unflat($nt_vec) #20.942 ns (0 allocations: 0 bytes)

#For AD no type conversion:
nt_vec, unflat = flatten(Float64, false, nt) #Vector{Float64} with 2 elements, unflatten_to_NamedTuple
nt2 = unflat(nt_vec)
typeof(nt2.c) #Float64
willtebbutt commented 3 years ago

Note: I've not forgotten about this PR. I'm currently swamped with PhD work, and will return to it when I get some time. @paschermayr how urgent is this for you? Are you happy to work on this branch / with your work-around for now, or do you need a release?

paschermayr commented 3 years ago

@willtebbutt Not urgent at all, happy to work with what I have. Thank you in any case!

theogf commented 2 years ago

Is there a lot left to do? I somehow need this feature :stuck_out_tongue_closed_eyes:

willtebbutt commented 2 years ago

To be honest, I'm not entirely sure. It's dropped off my radar somewhat, and I'm not going to have time to look at it properly until I've submitted my thesis.

simsurace commented 2 years ago

This would be very useful if it were merged, as using ForwardDiff or ReverseDiff instead of Zygote can lead to a massive improvement in gradient evaluation time. A quick benchmark with a sparse variational GP produced these numbers:

@btime loss($θ_flat) # 188.774 μs (123 allocations: 79.52 KiB)
@btime ForwardDiff.gradient($loss, $θ_flat) # 2.389 ms (1231 allocations: 5.80 MiB)
@btime ReverseDiff.gradient($loss, $θ_flat) # 8.220 ms (308089 allocations: 13.17 MiB)
@btime Zygote.gradient($loss, $θ_flat) # 36.134 ms (421748 allocations: 15.25 MiB)
paschermayr commented 2 years ago

Since this PR started, I have created another package, https://github.com/paschermayr/ModelWrappers.jl, because my needs were slightly different from the ParameterHandling.jl case. I managed to incorporate all the possible cases (performant flatten/unflatten, AD compatibility, taking integers into account) by passing a separate configuration struct as an argument to the flatten function. The exact specifications can be seen here: https://github.com/paschermayr/ModelWrappers.jl/blob/main/src/Core/constraints/flatten/flatten.jl

Note that this package is optimized for the case where unflatten is performed more often (it can often run with 0 allocations by creating buffers while flattening), and it is quite dependency-heavy, as I integrated some AD use cases. But maybe a similar solution could be implemented in ParameterHandling to take care of most corner cases without performance loss.
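The buffer idea can be sketched in a few lines (hypothetical code, not ModelWrappers' actual implementation; note that writing into a pre-allocated Float64 buffer trades away Dual compatibility in unflatten):

```julia
# Hypothetical sketch: allocate the output container once, at flatten time;
# unflatten! then only copies into it, so repeated calls allocate no fresh
# vector for the `b` field.
function buffered_flatten(nt::NamedTuple{(:a, :b)})
    v = Float64[nt.a; nt.b]
    buf_b = similar(nt.b)                  # buffer created while flattening
    function unflatten!(vnew::AbstractVector{<:Real})
        copyto!(buf_b, view(vnew, 2:lastindex(vnew)))
        return (a = vnew[1], b = buf_b)
    end
    return v, unflatten!
end

v, unflat! = buffered_flatten((a = 1.0, b = [2.0, 3.0]))
nt = unflat!(v)   # every call reuses buf_b
```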

st-- commented 2 years ago

@paschermayr that looks great! Thanks for sharing. I was thinking about how to assign priors to parameters ... glad someone already started working it out :)

simsurace commented 2 years ago

ModelWrappers.jl looks great, but does it provide the same functionality? E.g. ParameterHandling.positive_definite is something that I use a lot. IMHO this PR should still be merged. I wonder what is missing, since all tests are passing. Is there an important test case missing?

paschermayr commented 2 years ago

> ModelWrappers.jl looks great, but does it provide the same functionality? E.g. ParameterHandling.positive_definite is something that I use a lot. IMHO this PR should still be merged. I wonder what is missing, since all tests are passing. Is there an important test case missing?

A similar functionality, but the focus is different, and my package is much more dependency-heavy at the moment. I just linked it to show one possible solution that lets autodiff be applied to both the flatten and the unflatten step (this was the reason I created ModelWrappers in the first place). I would also like this to be merged if possible, ideally so that autodiff works in both directions, but if it only works for flatten for now, that would be fine too.

As for the other question: any Bijector for a matrix distribution that satisfies your constraints should work here - I also implemented separate CorrelationMatrix and CovarianceMatrix transformers.

using ModelWrappers

mat = [1.0 0.2; 0.2 3.0]
constraint = CovarianceMatrix()
model = ModelWrapper((Σ = Param(mat, constraint), ))

mat_flat = flatten(model) #Vector{Float64} with 3 elements 1.00, 0.200, 3.00
mat_unflat = unflatten(model, mat_flat) #(Σ = [1.0 0.2; 0.2 3.0],)

θᵤ = unconstrain_flatten(model) #Vector{Float64} with 3 elements 0.00 0.200 0.543…
unflatten_constrain(model, θᵤ) #(Σ = [1.0 0.2; 0.2 3.0],)
jariji commented 8 months ago

This PR does what I'm looking for. A small example that fails before the PR and works after:

using ParameterHandling, Optim, Test
let
    f((;x)) = x^2
    θ₀ = (;x = 4.0)
    flat_θ, unflatten = ParameterHandling.value_flatten(θ₀)
    opt = optimize(f∘unflatten, flat_θ, LBFGS(); autodiff=:forward)
    @test only(Optim.minimizer(opt)) ≈ 0.0
end