TuringLang / Bijectors.jl

Implementation of normalising flows and constrained random variable transformations
https://turinglang.org/Bijectors.jl/
MIT License
200 stars 33 forks source link

Zygote error differentiating `Coupling` #203

Open Red-Portal opened 2 years ago

Red-Portal commented 2 years ago

Hi, Coupling currently has an issue with differentiation. Here's a reproducible example.

using Bijectors
using Flux
using ProgressMeter
using StatsBase
using StatsPlots
using Turing
using Zygote

# Reproduction script for the reported differentiation failure: builds a
# transformed distribution whose bijector is a `Bijectors.Coupling` layer, then
# runs a Flux/Zygote optimisation loop over `Flux.params(flow)`.
# Takes no arguments; the normal-completion return value is that of the `for`
# loop (`nothing`).
function main()
    # Hyper-parameters: optimisation steps, ADAM step size, number of data
    # columns, and number of columns drawn per step.
    n_iter         = 3000
    lr                = 1e-3
    n_samples = 10
    n_batch     = 4
    # Synthetic data: a 2 × n_samples matrix of standard-normal draws.
    data          = randn(2, n_samples)

    # Base distribution constructed from a length-2 zero vector and `ones(2)`.
    base_dist = MvNormal(zeros(2), ones(2))
    # Coupling layer whose conditioner maps θ to `Shift(θ) ∘ Scale(θ)`.
    # NOTE(review): the trailing `2` is presumably the input dimension — confirm
    # against the `Bijectors.Coupling` constructor.
    layers    = Bijectors.Coupling(θ -> Bijectors.Shift(θ) ∘ Bijectors.Scale(θ), 2)

    # Flow = base distribution pushed through the coupling layer.
    flow = transformed(base_dist, layers)
    # Parameters collected implicitly from `flow`; these are what `pullback`
    # differentiates with respect to and what the optimiser updates.
    pars = Flux.params(flow)
    prog = ProgressMeter.Progress(n_iter)
    opt  = ADAM(lr)

    # The loop index `i` is not used in the body.
    for i = 1:n_iter
        # Draw `n_batch` distinct column indices (`replace=false`) and take a
        # non-copying view of those columns.
        batch_idx = sample(1:n_samples, n_batch, replace=false)
        batch     = view(data, :, batch_idx)
        # Loss: negative mean of `logpdf(flow, x)` over the columns of `batch`.
        # `Ref(flow)` keeps `flow` from being iterated by the broadcast.
        loss, back = Zygote.pullback(pars) do
            -mean(logpdf.(Ref(flow), eachcol(batch)))
        end
        # Seed the pullback with `one(loss)` to obtain the gradient.
        # NOTE(review): judging by the `Pullback` frames in the stack trace
        # reported alongside this script, the MethodError appears to be raised
        # here — confirm by running.
        grad       = back(one(loss))

        Flux.Optimise.update!(opt, pars, grad)
        # Advance the progress bar and display the current loss.
        ProgressMeter.next!(prog; showvalues=[(:loss, loss),]) 
    end
end
julia> main()
ERROR: MethodError: Cannot `convert` an object of type ChainRulesCore.ZeroTangent to an object of type ChainRulesCore.NoTangent
Closest candidates are:
  convert(::Type{T}, ::T) where T at essentials.jl:205
Stacktrace:
  [1] fill!(dest::Vector{ChainRulesCore.NoTangent}, x::ChainRulesCore.ZeroTangent)
    @ Base ./array.jl:333
  [2] _map_notzeropres!(f::typeof(Zygote.accum), fillvalue::ChainRulesCore.ZeroTangent, C::SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, A::SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, B::SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64})
    @ SparseArrays.HigherOrderFns /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/SparseArrays/src/higherorderfns.jl:345
  [3] _noshapecheck_map(f::typeof(Zygote.accum), A::SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, Bs::SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64})
    @ SparseArrays.HigherOrderFns /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/SparseArrays/src/higherorderfns.jl:166
  [4] _shapecheckbc(::Function, ::SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, ::Vararg{SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, N} where N)
    @ SparseArrays.HigherOrderFns /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/SparseArrays/src/higherorderfns.jl:1026
  [5] _copy(::Function, ::SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, ::SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64})
    @ SparseArrays.HigherOrderFns /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/SparseArrays/src/higherorderfns.jl:1016
  [6] copy
    @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/SparseArrays/src/higherorderfns.jl:1012 [inlined]
  [7] materialize
    @ ./broadcast.jl:883 [inlined]
  [8] accum(x::SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, ys::SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64})
    @ Zygote ~/.julia/packages/Zygote/fDJjj/src/lib/lib.jl:25
  [9] macro expansion
    @ ~/.julia/packages/Zygote/fDJjj/src/lib/lib.jl:27 [inlined]
 [10] accum(x::NamedTuple{(:A_1, :A_2, :A_3), Tuple{SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}}}, y::NamedTuple{(:A_1, :A_2, :A_3), Tuple{SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}}})
    @ Zygote ~/.julia/packages/Zygote/fDJjj/src/lib/lib.jl:27
 [11] macro expansion
    @ ~/.julia/packages/Zygote/fDJjj/src/lib/lib.jl:27 [inlined]
 [12] accum(x::NamedTuple{(:θ, :mask), Tuple{Nothing, NamedTuple{(:A_1, :A_2, :A_3), Tuple{SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}}}}}, y::NamedTuple{(:θ, :mask), Tuple{Nothing, NamedTuple{(:A_1, :A_2, :A_3), Tuple{SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}}}}})
    @ Zygote ~/.julia/packages/Zygote/fDJjj/src/lib/lib.jl:27
 [13] macro expansion
    @ ~/.julia/packages/Zygote/fDJjj/src/lib/lib.jl:27 [inlined]
 [14] accum(x::NamedTuple{(:orig,), Tuple{NamedTuple{(:θ, :mask), Tuple{Nothing, NamedTuple{(:A_1, :A_2, :A_3), Tuple{SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}}}}}}}, y::NamedTuple{(:orig,), Tuple{NamedTuple{(:θ, :mask), Tuple{Nothing, NamedTuple{(:A_1, :A_2, :A_3), Tuple{SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}}}}}}})
    @ Zygote ~/.julia/packages/Zygote/fDJjj/src/lib/lib.jl:27
 [15] getindex
    @ ./tuple.jl:29 [inlined]
 [16] gradindex
    @ ~/.julia/packages/Zygote/fDJjj/src/compiler/reverse.jl:12 [inlined]
 [17] Pullback
    @ ~/.julia/packages/Bijectors/du7oP/src/interface.jl:102 [inlined]
 [18] (::typeof(∂(forward)))(Δ::NamedTuple{(:rv, :logabsdetjac), Tuple{Vector{Float64}, Float64}})
    @ Zygote ~/.julia/packages/Zygote/fDJjj/src/compiler/interface2.jl:0
 [19] macro expansion
    @ ~/.julia/packages/Bijectors/du7oP/src/bijectors/composed.jl:0 [inlined]
 [20] Pullback
    @ ~/.julia/packages/Bijectors/du7oP/src/bijectors/composed.jl:222 [inlined]
 [21] (::typeof(∂(forward)))(Δ::NamedTuple{(:rv, :logabsdetjac), Tuple{Vector{Float64}, Float64}})
    @ Zygote ~/.julia/packages/Zygote/fDJjj/src/compiler/interface2.jl:0
 [22] Pullback
    @ ~/.julia/packages/Bijectors/du7oP/src/transformed_distribution.jl:108 [inlined]
 [23] (::typeof(∂(_logpdf)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/fDJjj/src/compiler/interface2.jl:0
 [24] Pullback
    @ ~/.julia/packages/Distributions/1313k/src/multivariates.jl:201 [inlined]
 [25] (::typeof(∂(logpdf)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/fDJjj/src/compiler/interface2.jl:0
 [26] #1073
    @ ~/.julia/packages/Zygote/fDJjj/src/lib/broadcast.jl:188 [inlined]
(PartialSMC) pkg> status
      Status `~/Projects/PartialSMC/Project.toml`
  [76274a88] Bijectors v0.9.9
  [31c24e10] Distributions v0.25.20
  [ced4e74d] DistributionsAD v0.6.31
  [634d3b9d] DrWatson v2.6.0
  [587475ba] Flux v0.12.7
  [5ab0869b] KernelDensity v0.6.3
  [872c559c] NNlib v0.7.29
  [90014a1f] PDMats v0.11.1
  [91a5bcdd] Plots v1.22.6
  [92933f4c] ProgressMeter v1.7.1
  [d330b81b] PyPlot v2.10.0
  [74087812] Random123 v1.4.2
  [e6cf234a] RandomNumbers v1.5.3
  [276daf66] SpecialFunctions v1.7.0
  [2913bbd2] StatsBase v0.33.11
  [4c63d2b9] StatsFuns v0.9.12
  [f3b207a7] StatsPlots v0.14.28
  [fce5fe82] Turing v0.18.0
  [e88e6eb3] Zygote v0.6.28
  [9a3f8284] Random
  [10745b16] Statistics

Seems like an issue with the chain rule of sparse arrays?

devmotion commented 2 years ago

Can you try different older Zygote versions and check if it was introduced in some version? Unfortunately, in my experience the Zygote-ChainRules integration is quite unstable and it is not uncommon that new releases break previously working code while fixing something else.

Red-Portal commented 2 years ago

@devmotion Unfortunately, I can't roll back to older versions of Zygote: Flux needs OneElement, which does not seem to exist in those versions. Any idea what might have caused the problem? Or should I knock on the door of ChainRules? It does seem to be a Bijectors-specific problem, though.

devmotion commented 2 years ago

No clue, I saw so many mysterious bugs that I became tired of hunting them down and gave up on trying to understand all of them :smile: Based on the stacktrace, I assume it is a more general issue with sparse arrays (it seems we end up with a ZeroTangent where we need a NoTangent), but of course it is triggered by some code in Bijectors, and it would be good to have a minimal example without Bijectors if one wants to demonstrate that it is a general problem.

Red-Portal commented 2 years ago

I guess I'll have to try ReverseDiff in the meantime.