TuringLang / Bijectors.jl

Implementation of normalising flows and constrained random variable transformations
https://turinglang.org/Bijectors.jl/
MIT License

`PlanarLayer` broken in 0.9.9 #204

Closed Red-Portal closed 2 years ago

Red-Portal commented 2 years ago

Hi, the new update for PlanarLayer related to ReverseDiff seems to have broken Zygote. Here's the error message.

ERROR: ArgumentError: The interval [a,b] is not a bracketing interval.
You need f(a) and f(b) to have different signs (f(a) * f(b) < 0).
Consider a different bracket or try fzero(f, c) with an initial guess c.

Stacktrace:
  [1] assert_bracket
    @ ~/.julia/packages/Roots/H7pXT/src/bracketing.jl:339 [inlined]
  [2] #init_state#22
    @ ~/.julia/packages/Roots/H7pXT/src/bracketing.jl:87 [inlined]
  [3] init_state(::Roots.BisectionExact, F::Roots.Callable_Function{Val{1}, Val{false}, Bijectors.var"#102#103"{Float64, Float64, Float64}, Nothing}, x₀::Float64, x₁::Float64, fx₀::Float64, fx₁::Float64)
    @ Roots ~/.julia/packages/Roots/H7pXT/src/bracketing.jl:85
  [4] init_state(M::Roots.BisectionExact, F::Roots.Callable_Function{Val{1}, Val{false}, Bijectors.var"#102#103"{Float64, Float64, Float64}, Nothing}, x::Tuple{Float64, Float64})
    @ Roots ~/.julia/packages/Roots/H7pXT/src/bracketing.jl:81
  [5] #init#18
    @ ~/.julia/packages/Roots/H7pXT/src/find_zero.jl:776 [inlined]
  [6] solve(::Roots.ZeroProblem{Bijectors.var"#102#103"{Float64, Float64, Float64}, Tuple{Float64, Float64}}, ::Roots.BisectionExact, ::Vararg{Any, N} where N; verbose::Bool, kwargs::Base.Iterators.Pairs{Symbol, Roots.NullTracks, Tuple{Symbol}, NamedTuple{(:tracks,), Tuple{Roots.NullTracks}}})
    @ Roots ~/.julia/packages/Roots/H7pXT/src/find_zero.jl:974
  [7] find_zero(fs::Function, x0::Tuple{Float64, Float64}, method::Roots.Bisection; p::Nothing, tracks::Roots.NullTracks, verbose::Bool, kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Roots ~/.julia/packages/Roots/H7pXT/src/bracketing.jl:271
  [8] find_zero
    @ ~/.julia/packages/Roots/H7pXT/src/bracketing.jl:258 [inlined]
  [9] #find_zero#17
    @ ~/.julia/packages/Roots/H7pXT/src/find_zero.jl:705 [inlined]
 [10] find_zero
    @ ~/.julia/packages/Roots/H7pXT/src/find_zero.jl:705 [inlined]
 [11] find_alpha
    @ ~/.julia/packages/Bijectors/du7oP/src/bijectors/planar_layer.jl:161 [inlined]
 [12] rrule(#unused#::typeof(Bijectors.find_alpha), wt_y::Float64, wt_u_hat::Float64, b::Float64)
    @ Bijectors ~/.julia/packages/Bijectors/du7oP/src/chainrules.jl:2
 [13] rrule
    @ ~/.julia/packages/ChainRulesCore/bxKCw/src/rules.jl:134 [inlined]
 [14] chain_rrule
    @ ~/.julia/packages/Zygote/fDJjj/src/compiler/chainrules.jl:182 [inlined]
 [15] macro expansion
    @ ~/.julia/packages/Zygote/fDJjj/src/compiler/interface2.jl:0 [inlined]
 [16] _pullback
    @ ~/.julia/packages/Zygote/fDJjj/src/compiler/interface2.jl:9 [inlined]
 [17] #1082
    @ ~/.julia/packages/Zygote/fDJjj/src/lib/broadcast.jl:198 [inlined]
 [18] _broadcast_getindex_evalf
    @ ./broadcast.jl:648 [inlined]
 [19] _broadcast_getindex
    @ ./broadcast.jl:621 [inlined]
 [20] getindex
    @ ./broadcast.jl:575 [inlined]
 [21] copy
    @ ./broadcast.jl:898 [inlined]
 [22] materialize
    @ ./broadcast.jl:883 [inlined]
 [23] _broadcast
    @ ~/.julia/packages/Zygote/fDJjj/src/lib/broadcast.jl:162 [inlined]
 [24] adjoint
    @ ~/.julia/packages/Zygote/fDJjj/src/lib/broadcast.jl:198 [inlined]
 [25] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [26] _apply
    @ ./boot.jl:804 [inlined]
 [27] adjoint
    @ ~/.julia/packages/Zygote/fDJjj/src/lib/lib.jl:200 [inlined]
 [28] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [29] _pullback
    @ ./broadcast.jl:1315 [inlined]
 [30] _pullback
    @ ~/.julia/packages/Bijectors/du7oP/src/bijectors/planar_layer.jl:116 [inlined]
 [31] _pullback(ctx::Zygote.Context, f::Inverse{PlanarLayer{Vector{Float64}, Vector{Float64}}, 1}, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/fDJjj/src/compiler/interface2.jl:0
 [32] _pullback
    @ ~/.julia/packages/Bijectors/du7oP/src/interface.jl:102 [inlined]
 [33] _pullback(::Zygote.Context, ::typeof(forward), ::Inverse{PlanarLayer{Vector{Float64}, Vector{Float64}}, 1}, ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/fDJjj/src/compiler/interface2.jl:0
 [34] macro expansion
    @ ~/.julia/packages/Bijectors/du7oP/src/bijectors/composed.jl:0 [inlined]
 [35] _pullback
    @ ~/.julia/packages/Bijectors/du7oP/src/bijectors/composed.jl:222 [inlined]
 [36] _pullback(::Zygote.Context, ::typeof(forward), ::Composed{NTuple{30, Inverse{PlanarLayer{Vector{Float64}, Vector{Float64}}, 1}}, 1}, ::SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true})
    @ Zygote ~/.julia/packages/Zygote/fDJjj/src/compiler/interface2.jl:0
 [37] _pullback
    @ ~/.julia/packages/Bijectors/du7oP/src/transformed_distribution.jl:108 [inlined]
 [38] _pullback(::Zygote.Context, ::typeof(Distributions._logpdf), ::MultivariateTransformed{DiagNormal, Composed{NTuple{30, PlanarLayer{Vector{Float64}, Vector{Float64}}}, 1}}, ::SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true})
    @ Zygote ~/.julia/packages/Zygote/fDJjj/src/compiler/interface2.jl:0
 [39] _pullback
    @ ~/.julia/packages/Distributions/1313k/src/multivariates.jl:201 [inlined]
 [40] #1072
    @ ~/.julia/packages/Zygote/fDJjj/src/lib/broadcast.jl:185 [inlined]
 [41] _broadcast_getindex_evalf
    @ ./broadcast.jl:648 [inlined]
 [42] _broadcast_getindex
    @ ./broadcast.jl:621 [inlined]
 [43] getindex
    @ ./broadcast.jl:575 [inlined]
 [44] copy
    @ ./broadcast.jl:922 [inlined]
 [45] materialize(bc::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, Zygote.var"#1072#1076"{Zygote.Context, typeof(logpdf)}, Tuple{Base.RefValue{MultivariateTransformed{DiagNormal, Composed{NTuple{30, PlanarLayer{Vector{Float64}, Vector{Float64}}}, 1}}}, Vector{SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}}}})
    @ Base.Broadcast ./broadcast.jl:883
 [46] _broadcast(::Zygote.var"#1072#1076"{Zygote.Context, typeof(logpdf)}, ::Base.RefValue{MultivariateTransformed{DiagNormal, Composed{NTuple{30, PlanarLayer{Vector{Float64}, Vector{Float64}}}, 1}}}, ::Vararg{Any, N} where N)
    @ Zygote ~/.julia/packages/Zygote/fDJjj/src/lib/broadcast.jl:162
 [47] adjoint
    @ ~/.julia/packages/Zygote/fDJjj/src/lib/broadcast.jl:185 [inlined]

Works perfectly fine on 0.9.8

devmotion commented 2 years ago

It seems you hit a special case where the initial bracket in Bijectors.find_alpha is empty, i.e., where the lower and upper bounds are equal. In this case the root-finding algorithm in Roots throws an error since it expects a proper bracket. Apparently NonlinearSolve, which was used in Bijectors 0.9.8, did not care (the initial bracket is the same in all more or less recent versions).

devmotion commented 2 years ago

I know there exists the same problem with Roots in HypothesisTests, so I guess it might be useful to fix the problem upstream in Roots and handle empty brackets gracefully.

devmotion commented 2 years ago

I'm more and more convinced that the problem should be handled in Bijectors. Empty brackets are really not an appropriate input for bracketing methods (which is why scalars, for example, are disallowed as well). Maybe even more importantly, it is much harder to assess in Roots whether an empty bracket is appropriate. In Bijectors we know that the empty bracket is also correct, modulo some floating-point issues, so the lower (or, equivalently, upper) bound can be returned. In Roots we would have to evaluate the function at the given point, but due to numerical inaccuracies the value probably won't be exactly zero. One would then have to check that it is approximately zero, which depends on the (implicit or explicit) tolerances and could also hide problems with the empty bracket.

I'll make a PR that fixes the issue in Bijectors.

Red-Portal commented 2 years ago

@devmotion Thanks for the good work. Cheers.

Red-Portal commented 2 years ago

Hi @devmotion, this bug still pops up, but now it triggers even when the domain (lower, upper) does not contain a solution. I wonder how this was handled in NonlinearSolve. I recommend reopening this issue.

devmotion commented 2 years ago

What do you mean by "when the domain (lower, upper) does not contain a solution"? The bracket should either be empty (lower == upper) or contain a solution. In the latter case it is guaranteed mathematically that the solution lies between lower and upper and that f(lower) and f(upper) have different signs. Maybe there are floating-point issues in some edge cases? Do you get exactly the same error message as stated above? If so, we probably have to check f(lower) and f(upper) and ensure that their signs are different.

Red-Portal commented 2 years ago

Here's what I saw. lower=19.82 and upper=22.00.

ERROR: ArgumentError: The interval [a,b] is not a bracketing interval.
You need f(a) and f(b) to have different signs (f(a) * f(b) < 0).
Consider a different bracket or try fzero(f, c) with an initial guess c.

Stacktrace:
  [1] assert_bracket
    @ ~/.julia/packages/Roots/H7pXT/src/bracketing.jl:339 [inlined]
  [2] #init_state#22
    @ ~/.julia/packages/Roots/H7pXT/src/bracketing.jl:87 [inlined]
  [3] init_state(::Roots.BisectionExact, F::Roots.Callable_Function{Val{1}, Val{false}, Bijectors.var"#305#306"{Float64, Float64, Float64}, Nothing}, x₀::Float64, x₁::Float64, fx₀::Float64, fx₁::Float64)
    @ Roots ~/.julia/packages/Roots/H7pXT/src/bracketing.jl:85
  [4] init_state(M::Roots.BisectionExact, F::Roots.Callable_Function{Val{1}, Val{false}, Bijectors.var"#305#306"{Float64, Float64, Float64}, Nothing}, x::Tuple{Float64, Float64})
    @ Roots ~/.julia/packages/Roots/H7pXT/src/bracketing.jl:81
  [5] #init#18
    @ ~/.julia/packages/Roots/H7pXT/src/find_zero.jl:776 [inlined]
  [6] solve(::Roots.ZeroProblem{Bijectors.var"#305#306"{Float64, Float64, Float64}, Tuple{Float64, Float64}}, ::Roots.BisectionExact, ::Vararg{Any, N} where N; verbose::Bool, kwargs::Base.Iterators.Pairs{Symbol, Roots.NullTracks, Tuple{Symbol}, NamedTuple{(:tracks,), Tuple{Roots.NullTracks}}})
    @ Roots ~/.julia/packages/Roots/H7pXT/src/find_zero.jl:974
  [7] find_zero(fs::Function, x0::Tuple{Float64, Float64}, method::Roots.Bisection; p::Nothing, tracks::Roots.NullTracks, verbose::Bool, kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Roots ~/.julia/packages/Roots/H7pXT/src/bracketing.jl:271
  [8] find_zero
    @ ~/.julia/packages/Roots/H7pXT/src/bracketing.jl:258 [inlined]
  [9] #find_zero#17
    @ ~/.julia/packages/Roots/H7pXT/src/find_zero.jl:705 [inlined]
 [10] find_zero
    @ ~/.julia/packages/Roots/H7pXT/src/find_zero.jl:705 [inlined]
 [11] find_alpha
    @ ./REPL[2]:15 [inlined]
 [12] rrule(#unused#::typeof(Bijectors.find_alpha), wt_y::Float64, wt_u_hat::Float64, b::Float64)
    @ Bijectors ~/.julia/packages/Bijectors/EELoe/src/chainrules.jl:2
 [13] rrule
    @ ~/.julia/packages/ChainRulesCore/Y1Mee/src/rules.jl:134 [inlined]
 [14] chain_rrule
    @ ~/.julia/packages/Zygote/rv6db/src/compiler/chainrules.jl:191 [I

As far as I know, this error is also thrown when the interval itself is valid but does not bracket a solution, that is, when f(a)*f(b) is positive. So I presume there is no solution? I didn't see a NaN anywhere.

Red-Portal commented 2 years ago

@devmotion just found the reason.

I executed

@eval Bijectors begin
function find_alpha(wt_y::T, wt_u_hat::T, b::T) where {T<:Real}
    # Compute the initial bracket (see above).
    abs_wt_u_hat = abs(wt_u_hat)
    lower = float(wt_y - abs_wt_u_hat)
    upper = float(wt_y + abs_wt_u_hat)

    # Handle empty brackets (https://github.com/TuringLang/Bijectors.jl/issues/204)
    if lower == upper
        return lower
    end
    f(α) = α + wt_u_hat * tanh(α + b) - wt_y
    @info "" upper=upper lower=lower fa=f(upper) fb=f(lower)

    # Solve the root-finding problem
    α0 = Roots.find_zero(f, (lower, upper))
    return α0
end
end

and just got

┌ Info: 
│   upper = 32.51258720807707
│   lower = 30.853197121333704
│   fa = -3.552713678800501e-15
└   fb = -1.6593900867433717

So it was a numerical stability issue. I'm not sure how to fix this without breaking anything, though.
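
For illustration, here is a minimal sketch of one way rounding alone can produce a log like the one above. It is not taken from the reporter's run: the inputs are hypothetical and deliberately extreme, chosen so that `tanh` saturates to exactly `-1.0` and the endpoint sums round off. Only the residual formula is copied from the thread.

```julia
# Residual whose root find_alpha looks for (formula copied from the thread).
residual(α, wt_y, wt_u_hat, b) = α + wt_u_hat * tanh(α + b) - wt_y

# Hypothetical inputs: b is very negative so tanh(α + b) is exactly -1.0,
# and wt_u_hat is below the rounding granularity of numbers near wt_y.
wt_y, wt_u_hat, b = 1.0, ldexp(3.0, -55), -100.0

# Original bracket: wt_y ± |wt_u_hat|.
lower = wt_y - abs(wt_u_hat)
upper = wt_y + abs(wt_u_hat)
f_lower = residual(lower, wt_y, wt_u_hat, b)
f_upper = residual(upper, wt_y, wt_u_hat, b)
println("original bracket: f(lower) = ", f_lower, ", f(upper) = ", f_upper)

# Widened bracket: wt_y ± 2|wt_u_hat|.
lower2 = wt_y - 2 * abs(wt_u_hat)
upper2 = wt_y + 2 * abs(wt_u_hat)
f_lower2 = residual(lower2, wt_y, wt_u_hat, b)
f_upper2 = residual(upper2, wt_y, wt_u_hat, b)
println("widened bracket:  f(lower) = ", f_lower2, ", f(upper) = ", f_upper2)
```

With these inputs the bracket is not empty, yet both endpoint residuals of the original bracket come out negative, so a bracketing method rejects it. The widened bracket has endpoint residuals of opposite sign. Whether the reported case failed through exactly this mechanism is an assumption; the logged value of about -3.55e-15 at `upper` is at least consistent with a rounding error of one unit in the last place for numbers of that magnitude.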

devmotion commented 2 years ago

Mathematically there has to exist a solution. And the bracketing interval is mathematically guaranteed to contain it, so my only hypothesis right now is that there's some floating point issue.
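
The guarantee can be spelled out using the function from the snippet above. The symbols $s$ and $c$ are shorthand introduced here for `wt_y` and `wt_u_hat`; they do not appear in the thread.

$$
f(\alpha) = \alpha + c \tanh(\alpha + b) - s, \qquad
f(s \pm |c|) = \pm |c| + c \tanh(s \pm |c| + b).
$$

Since $|\tanh(x)| < 1$ for every real $x$, we have $|c \tanh(\cdot)| < |c|$ whenever $c \neq 0$, and therefore

$$
f(s - |c|) < 0 < f(s + |c|).
$$

Because $f$ is continuous, the intermediate value theorem gives a root in $(s - |c|,\, s + |c|)$. In floating point the strict inequality can be lost: once $\tanh$ rounds to exactly $\pm 1$, the term $c \tanh(\cdot)$ equals $\mp |c|$, the endpoint value is zero in exact arithmetic, and rounding in the remaining additions can push it to either side of zero.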

Red-Portal commented 2 years ago

@devmotion Did you see my latest response? As a side note, could you take a look at #203? I checked that sparse arrays don't have an issue. I think there's something wrong with NoTangent and sparse arrays.

devmotion commented 2 years ago

Oh, just saw your new comment now. So it seems it's actually a numerical problem here - I wonder if the upper bound is just a tiny bit off due to floating point errors or if fa is slightly incorrect.

Red-Portal commented 2 years ago

I guess 1e-15 is small enough. Could we get away by performing a boundary solution check with epsilons?

devmotion commented 2 years ago

BTW I don't think this is a Roots-specific issue but maybe just bad luck? At least NonlinearSolve performs the same check: https://github.com/SciML/NonlinearSolve.jl/blob/fc46e99774e207e64ed8e0ea43d9d4fa3ad8a699/src/solve.jl#L84-L86

Red-Portal commented 2 years ago

Oh, that's surprising.

devmotion commented 2 years ago

> Could we get away by performing a boundary solution check with epsilons?

I would prefer a cleaner solution that does not have to rely on some implicitly set tolerances and guarantees that the problem can't reappear (I am worried that such a check might still miss some cases).

Can you rerun your code with abs_wt_u_hat = 2 * abs(wt_u_hat) (ignore the misleading name :stuck_out_tongue:) and check if you still observe the problem? This doubles the length of the initial interval and hence causes one additional step in the root-finding algorithm in the worst case, but hopefully it fixes the problem.
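
A minimal sketch of the suggestion, assuming nothing beyond what the thread shows. The names `find_alpha_sketch` and `bisect` are hypothetical, and `bisect` is a plain bisection used as a dependency-free stand-in for `Roots.find_zero(f, (lower, upper))`. This is not the code released in Bijectors.

```julia
# Plain bisection on [lo, hi]; hypothetical stand-in for Roots.find_zero.
function bisect(f, lo::Float64, hi::Float64)
    flo, fhi = f(lo), f(hi)
    flo == 0 && return lo
    fhi == 0 && return hi
    sign(flo) != sign(fhi) || throw(ArgumentError("not a bracketing interval"))
    while true
        mid = lo + (hi - lo) / 2
        # Stop once no representable number lies strictly between lo and hi.
        (mid <= lo || mid >= hi) && return mid
        fmid = f(mid)
        fmid == 0 && return mid
        if sign(fmid) == sign(flo)
            lo = mid   # root is in the upper half
        else
            hi = mid   # root is in the lower half
        end
    end
end

function find_alpha_sketch(wt_y::Float64, wt_u_hat::Float64, b::Float64)
    # Doubled half-width, as suggested in the thread.
    half_width = 2 * abs(wt_u_hat)
    lower = wt_y - half_width
    upper = wt_y + half_width

    # Empty bracket: the bound itself is the solution (see earlier comments).
    lower == upper && return lower

    f(α) = α + wt_u_hat * tanh(α + b) - wt_y
    return bisect(f, lower, upper)
end

println(find_alpha_sketch(0.3, 0.5, 0.1))
```

The wider bracket leaves a margin of `abs(wt_u_hat)` on each side of the mathematically guaranteed interval, so an endpoint residual that is off by a rounding error no longer lands on the wrong side of zero in the cases discussed here. The cost is at most one extra bisection step.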

Red-Portal commented 2 years ago

@devmotion Yup that did the trick. Works well on my test case.

devmotion commented 2 years ago

@Red-Portal I just released the fix in Bijectors 0.9.11.