TuringLang / Bijectors.jl

Implementation of normalising flows and constrained random variable transformations
https://turinglang.org/Bijectors.jl/
MIT License

Zygote is broken for `Stacked` bijectors #252

Closed: Red-Portal closed this issue 1 month ago

Red-Portal commented 1 year ago

Here's an MWE:

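using Distributions, Bijectors, Zygote

dists = [Beta(), InverseGamma()]
# The first bijector acts on x[1:2], the second on x[3:3]
ranges = [1:2, 3:3]
bs = Bijectors.bijector.(dists)
# Inverse bijectors map unconstrained values into each distribution's support
binvs = inverse.(bs)
stacked = Bijectors.Stacked(binvs, ranges)
display(stacked(randn(3)))

# Differentiate the log-abs-det-Jacobian of the stacked transform
Zygote.gradient(x -> Bijectors.with_logabsdet_jacobian(stacked, x)[2], randn(3))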

This results in

ERROR: MethodError: no method matching iterate(::Nothing)
Closest candidates are:
  iterate(::Union{LinRange, StepRangeLen}) at range.jl:872
  iterate(::Union{LinRange, StepRangeLen}, ::Integer) at range.jl:872
  iterate(::T) where T<:Union{Base.KeySet{<:Any, <:Dict}, Base.ValueIterator{<:Dict}} at dict.jl:712
  ...
Stacktrace:
  [1] indexed_iterate(I::Nothing, i::Int64)
    @ Base ./tuple.jl:91
  [2] chain_rrule_kw
    @ ~/.julia/packages/Zygote/TSj5C/src/compiler/chainrules.jl:235 [inlined]
  [3] macro expansion
    @ ~/.julia/packages/Zygote/TSj5C/src/compiler/interface2.jl:0 [inlined]
  [4] _pullback(::Zygote.Context{false}, ::Base.var"#reduce##kw", ::NamedTuple{(:init,), Tuple{Vector{Float64}}}, ::typeof(reduce), ::typeof(vcat), ::Vector{Vector{Float64}})
    @ Zygote ~/.julia/packages/Zygote/TSj5C/src/compiler/interface2.jl:9
  [5] _pullback
    @ ./reducedim.jl:359 [inlined]
  [6] _pullback(::Zygote.Context{false}, ::Base.var"##mapreduce#766", ::Base.Pairs{Symbol, Vector{Float64}, Tuple{Symbol}, NamedTuple{(:init,), Tuple{Vector{Float64}}}}, ::typeof(mapreduce), ::Bijectors.var"#28#29"{Vector{Float64}}, ::typeof(vcat), ::Vector{Any}, ::Vector{UnitRange{Int64}})
    @ Zygote ~/.julia/packages/Zygote/TSj5C/src/compiler/interface2.jl:0
  [7] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:816
  [8] adjoint
    @ ~/.julia/packages/Zygote/TSj5C/src/lib/lib.jl:203 [inlined]
  [9] _pullback
    @ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:66 [inlined]
 [10] _pullback
    @ ./reducedim.jl:359 [inlined]
 [11] _pullback(::Zygote.Context{false}, ::Base.var"#mapreduce##kw", ::NamedTuple{(:init,), Tuple{Vector{Float64}}}, ::typeof(mapreduce), ::Bijectors.var"#28#29"{Vector{Float64}}, ::typeof(vcat), ::Vector{Any}, ::Vector{UnitRange{Int64}})
    @ Zygote ~/.julia/packages/Zygote/TSj5C/src/compiler/interface2.jl:0
 [12] _pullback
    @ ~/.julia/packages/Bijectors/vKGbw/src/bijectors/stacked.jl:158 [inlined]
 [13] _pullback(::Zygote.Context{false}, ::typeof(with_logabsdet_jacobian), ::Stacked{Vector{Any}, Vector{UnitRange{Int64}}}, ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/TSj5C/src/compiler/interface2.jl:0
 [14] _pullback
    @ ./REPL[179]:1 [inlined]
 [15] _pullback(ctx::Zygote.Context{false}, f::var"#1337#1338", args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/TSj5C/src/compiler/interface2.jl:0
 [16] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/TSj5C/src/compiler/interface.jl:44
 [17] pullback
    @ ~/.julia/packages/Zygote/TSj5C/src/compiler/interface.jl:42 [inlined]
 [18] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/TSj5C/src/compiler/interface.jl:96
 [19] top-level scope
    @ REPL[179]:1

Note that if `ranges` is set as

ranges = [1, 2]

everything works. I'm not sure whether this is intended behavior.
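
For reference, the quantity differentiated in the MWE can be written out by hand. The sketch below assumes the two inverse bijectors are the logistic function (for the `Beta` block `x[1:2]`) and `exp` (for the `InverseGamma` block `x[3:3]`); `inv_logit`, `stacked_logabsdet` and the other names are made up here and are not Bijectors functions. Its analytic gradient, `[1 - 2σ(x₁), 1 - 2σ(x₂), 1]`, is checked against central finite differences, so under those assumptions it gives a value that a working `Zygote.gradient` call should match.

```julia
# Assumed inverse bijectors (hypothetical names): logistic for the Beta block
# x[1:2], exp for the InverseGamma block x[3:3]
inv_logit(z) = 1 / (1 + exp(-z))

# log|d inv_logit(z)/dz| = log σ(z) + log(1 - σ(z))
logabsdet_inv_logit(z) = log(inv_logit(z)) + log(1 - inv_logit(z))

# log|d exp(z)/dz| = z
logabsdet_exp(z) = z

# Log-abs-det-Jacobian of the stacked transform: a sum over both blocks
stacked_logabsdet(x) = sum(logabsdet_inv_logit.(x[1:2])) + sum(logabsdet_exp.(x[3:3]))

# Analytic gradient: d/dz [log σ(z) + log(1 - σ(z))] = 1 - 2σ(z), and d/dz z = 1
stacked_logabsdet_grad(x) = [1 - 2 * inv_logit(x[1]), 1 - 2 * inv_logit(x[2]), 1.0]

# Central finite differences as an independent check
function fd_grad(f, x; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        e = zero(x)
        e[i] = h
        g[i] = (f(x + e) - f(x - e)) / (2h)
    end
    return g
end

x = [0.3, -1.2, 0.7]
println(stacked_logabsdet_grad(x))
println(fd_grad(stacked_logabsdet, x))
```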

Red-Portal commented 1 year ago

Hmm... it seems like a `mapreduce` diff rule is broken... again?

@devmotion I think you're probably the expert on differentiating `mapreduce`. Do you have any idea what's going on?

torfjelde commented 1 year ago

A hotfix is just:

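# Patch for Bijectors' `Stacked` (src/bijectors/stacked.jl); to try it from the REPL,
# wrap it in `@eval Bijectors begin ... end`.
function with_logabsdet_jacobian(sb::Stacked, x::AbstractVector)
    N = length(sb.bs)
    # Handle the first component separately so its output can seed `init`
    yinit, linit = with_logabsdet_jacobian(sb.bs[1], x[sb.ranges[1]])
    logjac = sum(linit)
    # HACK: Return early to avoid `mapreduce` over an empty collection, which Zygote.jl doesn't like.
    N == 1 && return (yinit, logjac)

    # Concatenate the remaining outputs onto `yinit`, accumulating the log-Jacobian terms
    ys = mapreduce(vcat, sb.bs[2:end], sb.ranges[2:end]; init=yinit) do b, r
        y, l = with_logabsdet_jacobian(b, x[r])
        logjac += sum(l)
        y
    end
    return (ys, logjac)
end
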
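
The bookkeeping in the hotfix can be checked without Bijectors or Zygote. In this sketch, `logistic_lj` and `exp_lj` are hypothetical plain functions returning `(y, logjac)` in place of `with_logabsdet_jacobian(b, x)` (again assuming logistic and `exp` as the inverse bijectors), and `stacked_lj` follows the hotfix's structure: the first component seeds `init`, the rest are folded in with `mapreduce`, and a single component returns early. It only checks forward values; it says nothing about whether Zygote can differentiate this form.

```julia
# Hypothetical stand-ins for `with_logabsdet_jacobian(b, z)`: plain functions
# returning `(y, logjac)`, assuming logistic and exp as the inverse bijectors
function logistic_lj(z)
    y = @. 1 / (1 + exp(-z))
    return (y, sum(@. log(y) + log(1 - y)))
end
exp_lj(z) = (exp.(z), sum(z))

# Same structure as the hotfix, with `fs` in place of `sb.bs`
function stacked_lj(fs, ranges, x)
    N = length(fs)
    # The first component seeds both the output vector and the log-Jacobian
    yinit, linit = fs[1](x[ranges[1]])
    logjac = sum(linit)
    # Early return for a single component, as in the hotfix
    N == 1 && return (yinit, logjac)
    # Concatenate the remaining outputs onto `yinit`; the closure adds each
    # component's log-abs-det-Jacobian to `logjac` as a side effect
    ys = mapreduce(vcat, fs[2:end], ranges[2:end]; init = yinit) do f, r
        y, l = f(x[r])
        logjac += sum(l)
        y
    end
    return (ys, logjac)
end

x = [0.3, -1.2, 0.7]
ys, logjac = stacked_lj([logistic_lj, exp_lj], [1:2, 3:3], x)
println(ys)
println(logjac)
```

Note that `init = yinit` carries real data rather than a neutral element for `vcat`, so the result relies on `init` being used exactly once; the Base documentation for `reduce` leaves that unspecified for non-empty collections.
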
Red-Portal commented 1 year ago

@torfjelde Hmm... I don't think the hotfix fixes the MWE. The MWE has N = 2, so it never hits the early return.

Red-Portal commented 1 year ago

For anyone who hits this and needs a fix ASAP, here's a quick-and-dirty workaround (at the cost of evaluating every bijector twice):

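@eval Bijectors begin
function with_logabsdet_jacobian(sb::Stacked, x::AbstractVector)
    # First pass: transform each block and concatenate the outputs (no `init`)
    ys = mapreduce(vcat, sb.bs, sb.ranges) do b, r
        y, _ = with_logabsdet_jacobian(b, x[r])
        y
    end
    # Second pass: recompute each block only to sum its log-abs-det-Jacobian;
    # `sum` also covers the case where `l` holds per-element terms
    logjac = mapreduce(+, sb.bs, sb.ranges) do b, r
        _, l = with_logabsdet_jacobian(b, x[r])
        sum(l)
    end
    return (ys, logjac)
end
end
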
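
The workaround avoids the `init` keyword altogether and pays for it by evaluating each bijector twice. The sketch below compares that two-pass shape with a one-pass variant that collects the `(y, logjac)` pairs with `map` and then combines them with `reduce(vcat, ...)` and `sum`, also without `init`. As before, `logistic_lj`, `exp_lj` and the `stacked_lj_*` functions are hypothetical plain-Julia stand-ins (assuming logistic and `exp` inverses); the check only covers forward values, and whether Zygote handles the one-pass version has not been tested here.

```julia
# Hypothetical stand-ins for `with_logabsdet_jacobian(b, z)` (assumed logistic/exp)
function logistic_lj(z)
    y = @. 1 / (1 + exp(-z))
    return (y, sum(@. log(y) + log(1 - y)))
end
exp_lj(z) = (exp.(z), sum(z))

# Shape of the workaround: two passes, no `init` keyword
function stacked_lj_twopass(fs, ranges, x)
    ys = mapreduce(vcat, fs, ranges) do f, r
        first(f(x[r]))
    end
    logjac = mapreduce(+, fs, ranges) do f, r
        last(f(x[r]))
    end
    return (ys, logjac)
end

# One pass: keep each component's `(y, logjac)` pair, then combine them,
# still without an `init` keyword
function stacked_lj_onepass(fs, ranges, x)
    parts = map((f, r) -> f(x[r]), fs, ranges)
    ys = reduce(vcat, map(first, parts))
    logjac = sum(map(last, parts))
    return (ys, logjac)
end

x = [0.3, -1.2, 0.7]
fs, ranges = [logistic_lj, exp_lj], [1:2, 3:3]
y2, l2 = stacked_lj_twopass(fs, ranges, x)
y1, l1 = stacked_lj_onepass(fs, ranges, x)
println(y1 ≈ y2 && l1 ≈ l2)
```
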
torfjelde commented 1 year ago

Ah, sorry, I saw the `[1, 2]` and immediately thought it was a different issue.

It seems the line https://github.com/FluxML/Zygote.jl/blob/dc16b2e35cdfa9dda62f636e95271490548cc574/src/compiler/chainrules.jl#L235 returns `nothing`, which in turn means we're hitting https://github.com/JuliaDiff/ChainRulesCore.jl/blob/79ba4ef03afdf5715b6fa0294e5accfe4e95c79b/src/rules.jl#L138 for some reason.
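
The failure mode described above can be reproduced in isolation: destructuring `nothing` into two variables raises exactly the `MethodError: no method matching iterate(::Nothing)` from the report, with `indexed_iterate` in the trace. In the sketch below, `lookup_rule` and `call_rule` are hypothetical names that only mimic the two steps (a rule lookup that returns `nothing`, as ChainRulesCore's generic `rrule` fallback does, and the destructuring in frame [2], `chain_rrule_kw`); they are not Zygote's or ChainRulesCore's actual code.

```julia
# Hypothetical stand-in for rule lookup: like ChainRulesCore's generic `rrule`
# fallback, it returns `nothing` when no rule applies
lookup_rule(f, args...; kwargs...) = nothing

# Hypothetical stand-in for the step in Zygote's `chain_rrule_kw` that
# destructures the rule's result into `(y, back)`
function call_rule(f, args...; kwargs...)
    y, back = lookup_rule(f, args...; kwargs...)
    return y, back
end

err = try
    call_rule(reduce, vcat, [[1.0], [2.0]]; init = [0.0])
    nothing
catch e
    e
end
showerror(stdout, err)
println()
```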

In particular, the traceback

ERROR: MethodError: no method matching iterate(::Nothing)
Closest candidates are:
  iterate(::Union{LinRange, StepRangeLen}) at range.jl:872
  iterate(::Union{LinRange, StepRangeLen}, ::Integer) at range.jl:872
  iterate(::T) where T<:Union{Base.KeySet{<:Any, <:Dict}, Base.ValueIterator{<:Dict}} at dict.jl:712
  ...
Stacktrace:
  [1] indexed_iterate(I::Nothing, i::Int64)
    @ Base ./tuple.jl:91
  [2] chain_rrule_kw(::Zygote.ZygoteRuleConfig{Zygote.Context{false}}, ::Function, ::NamedTuple{(:init,), Tuple{Vector{Float64}}}, ::Function, ::Function, ::Vararg{Any})
    @ Main ./REPL[23]:3
  [3] macro expansion
    @ ./compiler/interface2.jl:0 [inlined]
  [4] _pullback(::Zygote.Context{false}, ::Base.var"#reduce##kw", ::NamedTuple{(:init,), Tuple{Vector{Float64}}}, ::typeof(reduce), ::typeof(vcat), ::Vector{Vector{Float64}})
    @ Zygote ./compiler/interface2.jl:9
  [5] _pullback
    @ ./reducedim.jl:359 [inlined]
  [6] _pullback(::Zygote.Context{false}, ::Base.var"##mapreduce#766", ::Base.Pairs{Symbol, Vector{Float64}, Tuple{Symbol}, NamedTuple{(:init,), Tuple{Vector{Float64}}}}, ::typeof(mapreduce), ::var"#3#4"{Vector{Float64}}, ::typeof(vcat), ::Vector{Any}, ::Vector{UnitRange{Int64}})
    @ Zygote ./compiler/interface2.jl:0

seems to indicate that we're somehow triggering the rule for `reduce`.
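
Frames [4] to [6] of the stacktrace (`reduce##kw` called from `reducedim.jl:359`) are consistent with Base implementing multi-argument `mapreduce(f, op, A, B; kw...)` as `reduce(op, map(f, A, B); kw...)`. If so, the `init` keyword passed by `Stacked` ends up on `reduce`, which would explain why Zygote tries the keyword rule for `reduce`. The sketch below only checks that the two calls agree in value on made-up data.

```julia
# Made-up data: A holds vectors, B holds scalars
f(a, b) = a .* b
A = [[1.0, 2.0], [3.0]]
B = [10.0, 100.0]
init = Float64[]  # empty vector: a neutral element for `vcat`

# Multi-argument mapreduce with a keyword argument...
via_mapreduce = mapreduce(f, vcat, A, B; init = init)
# ...gives the same value as materializing the map and reducing with the same keyword
via_reduce = reduce(vcat, map(f, A, B); init = init)

println(via_mapreduce)
println(via_mapreduce == via_reduce)
```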

Currently trying to figure out what's going on here.