compintell / Tapir.jl

https://compintell.github.io/Tapir.jl/
MIT License

Broken Turing model #241

Closed. yebai closed this issue 6 days ago.

yebai commented 1 week ago
julia> using Turing, Tapir
Precompiling Turing
  68 dependencies successfully precompiled in 48 seconds. 248 already precompiled.

julia> @model demo_lkj() = x ~ LKJCholesky(2, 1.0)
demo_lkj (generic function with 2 methods)

julia> sample(demo_lkj(), NUTS(;adtype=AutoTapir(false)), 2000)

┌ Warning: Unable to put rule in rule field. Rule should error.
└ @ Tapir ~/.julia/packages/Tapir/hBdbd/src/interpreter/s2s_reverse_mode_ad.jl:1363
MethodInstance is
MethodInstance for Base._string(::Char)

with signature
Tuple{typeof(Base._string), Char}

derived_rule is of type
Tapir.DerivedRule{Tuple{typeof(Base._string), Tuple{Char}}, MistyClosures.MistyClosure{Core.OpaqueClosure{Tuple{Tapir.CoDual{typeof(Base._string), NoFData}, Tapir.CoDual{Tuple{Char}, NoFData}}, Union{}}}, MistyClosures.MistyClosure{Core.OpaqueClosure{Tuple{NoRData}, Tuple{NoRData, NoRData}}}, Val{true}, Val{2}}

Expected type is
Tapir.DerivedRule{Tuple{typeof(Base._string), Tuple{Char}}, MistyClosures.MistyClosure{Core.OpaqueClosure{Tuple{Tapir.CoDual{typeof(Base._string), NoFData}, Tapir.CoDual{Tuple{Char}, NoFData}}, Tapir.CoDual{String, NoFData}}}, MistyClosures.MistyClosure{Core.OpaqueClosure{Tuple{NoRData}, Tuple{NoRData, NoRData}}}, Val{true}, Val{2}}

Environment:

(@plato) pkg> st
Status `~/.julia/environments/plato/Project.toml`
  [07d77754] Tapir v0.2.44
  [fce5fe82] Turing v0.34.0 `https://github.com/TuringLang/Turing.jl.git#master`

willtebbutt commented 1 week ago

The full error is:

┌ Warning: Unable to put rule in rule field. Rule should error.
└ @ Tapir ~/.julia/packages/Tapir/hBdbd/src/interpreter/s2s_reverse_mode_ad.jl:1363
MethodInstance is
MethodInstance for Base._string(::Char)

with signature
Tuple{typeof(Base._string), Char}

derived_rule is of type
Tapir.DerivedRule{Tuple{typeof(Base._string), Tuple{Char}}, MistyClosures.MistyClosure{Core.OpaqueClosure{Tuple{Tapir.CoDual{typeof(Base._string), NoFData}, Tapir.CoDual{Tuple{Char}, NoFData}}, Union{}}}, MistyClosures.MistyClosure{Core.OpaqueClosure{Tuple{NoRData}, Tuple{NoRData, NoRData}}}, Val{true}, Val{2}}

Expected type is
Tapir.DerivedRule{Tuple{typeof(Base._string), Tuple{Char}}, MistyClosures.MistyClosure{Core.OpaqueClosure{Tuple{Tapir.CoDual{typeof(Base._string), NoFData}, Tapir.CoDual{Tuple{Char}, NoFData}}, Tapir.CoDual{String, NoFData}}}, MistyClosures.MistyClosure{Core.OpaqueClosure{Tuple{NoRData}, Tuple{NoRData, NoRData}}}, Val{true}, Val{2}}

Sampling 100%|████████████████████████████████████████████████████████████████████████| Time: 0:00:45
ERROR: No rrule!! available for foreigncall with primal argument types Tuple{Val{:jl_alloc_string}, Val{Ref{String}}, Tuple{Val{UInt64}}, Val{1}, Val{(:ccall, 0x0e)}, UInt64}. This problem has most likely arisen because there is a ccall somewhere in the function you are trying to differentiate, for which an rrule!! has not been explicitly written.You have three options: write an rrule!! for this foreigncall, write an rrule!! for a Julia function that calls this foreigncall, or re-write your code to avoid this foreigncall entirely. If you believe that this error has arisen for some other reason than the above, or the above does not help you to workaround this problem, please open an issue.
Stacktrace:
  [1] rrule!!(::Tapir.CoDual{…}, ::Tapir.CoDual{…}, ::Tapir.CoDual{…}, ::Tapir.CoDual{…}, ::Tapir.CoDual{…}, ::Tapir.CoDual{…}, ::Tapir.CoDual{…})
    @ Tapir ~/.julia/packages/Tapir/hBdbd/src/rrules/foreigncall.jl:12
  [2] RRuleZeroWrapper
    @ ~/.julia/packages/Tapir/hBdbd/src/interpreter/s2s_reverse_mode_ad.jl:247 [inlined]
  [3] _string
    @ ./strings/substring.jl:230 [inlined]
  [4] (::Tuple{…})(none::Tapir.CoDual{…}, none::Tapir.CoDual{…})
    @ Base.Experimental ./<missing>:0
  [5] (::MistyClosures.MistyClosure{…})(::Tapir.CoDual{…}, ::Tapir.CoDual{…})
    @ MistyClosures ~/.julia/packages/MistyClosures/rzVHC/src/MistyClosures.jl:15
  [6] (::Tapir.DerivedRule{…})(::Tapir.CoDual{…}, ::Tapir.CoDual{…})
    @ Tapir ~/.julia/packages/Tapir/hBdbd/src/interpreter/s2s_reverse_mode_ad.jl:720
  [7] (::Tapir.LazyDerivedRule{…})(::Tapir.CoDual{…}, ::Tapir.CoDual{…})
    @ Tapir ~/.julia/packages/Tapir/hBdbd/src/interpreter/s2s_reverse_mode_ad.jl:1376
  [8] RRuleZeroWrapper
    @ ~/.julia/packages/Tapir/hBdbd/src/interpreter/s2s_reverse_mode_ad.jl:247 [inlined]
  [9] invlink_with_logpdf
    @ ~/.julia/packages/DynamicPPL/DvdZw/src/abstract_varinfo.jl:854 [inlined]
 [10] (::Tuple{…})(none::Tapir.CoDual{…}, none::Tapir.CoDual{…}, none::Tapir.CoDual{…}, none::Tapir.CoDual{…}, none::Tapir.CoDual{…})
    @ Base.Experimental ./<missing>:0
 [11] (::MistyClosures.MistyClosure{…})(::Tapir.CoDual{…}, ::Tapir.CoDual{…}, ::Tapir.CoDual{…}, ::Tapir.CoDual{…}, ::Tapir.CoDual{…})
    @ MistyClosures ~/.julia/packages/MistyClosures/rzVHC/src/MistyClosures.jl:15
 [12] DerivedRule
    @ ~/.julia/packages/Tapir/hBdbd/src/interpreter/s2s_reverse_mode_ad.jl:720 [inlined]
 [13] (::Tapir.LazyDerivedRule{…})(::Tapir.CoDual{…}, ::Tapir.CoDual{…}, ::Tapir.CoDual{…}, ::Tapir.CoDual{…}, ::Tapir.CoDual{…})
    @ Tapir ~/.julia/packages/Tapir/hBdbd/src/interpreter/s2s_reverse_mode_ad.jl:1380
 [14] RRuleZeroWrapper
    @ ~/.julia/packages/Tapir/hBdbd/src/interpreter/s2s_reverse_mode_ad.jl:247 [inlined]
 [15] invlink_with_logpdf
    @ ~/.julia/packages/DynamicPPL/DvdZw/src/abstract_varinfo.jl:850 [inlined]
 [16] (::Tuple{…})(none::Tapir.CoDual{…}, none::Tapir.CoDual{…}, none::Tapir.CoDual{…}, none::Tapir.CoDual{…})
    @ Base.Experimental ./<missing>:0
 [17] (::MistyClosures.MistyClosure{…})(::Tapir.CoDual{…}, ::Tapir.CoDual{…}, ::Tapir.CoDual{…}, ::Tapir.CoDual{…})
    @ MistyClosures ~/.julia/packages/MistyClosures/rzVHC/src/MistyClosures.jl:15
 [18] DerivedRule
    @ ~/.julia/packages/Tapir/hBdbd/src/interpreter/s2s_reverse_mode_ad.jl:720 [inlined]
 [19] (::Tapir.LazyDerivedRule{…})(::Tapir.CoDual{…}, ::Tapir.CoDual{…}, ::Tapir.CoDual{…}, ::Tapir.CoDual{…})
    @ Tapir ~/.julia/packages/Tapir/hBdbd/src/interpreter/s2s_reverse_mode_ad.jl:1380
 [20] RRuleZeroWrapper
    @ ~/.julia/packages/Tapir/hBdbd/src/interpreter/s2s_reverse_mode_ad.jl:247 [inlined]
 [21] demo_lkj
    @ ./REPL[4]:1 [inlined]
 [22] (::Tuple{…})(none::Tapir.CoDual{…}, none::Tapir.CoDual{…}, none::Tapir.CoDual{…}, none::Tapir.CoDual{…})
    @ Base.Experimental ./<missing>:0
 [23] (::MistyClosures.MistyClosure{…})(::Tapir.CoDual{…}, ::Tapir.CoDual{…}, ::Tapir.CoDual{…}, ::Tapir.CoDual{…})
    @ MistyClosures ~/.julia/packages/MistyClosures/rzVHC/src/MistyClosures.jl:15
 [24] DerivedRule
    @ ~/.julia/packages/Tapir/hBdbd/src/interpreter/s2s_reverse_mode_ad.jl:720 [inlined]
 [25] (::Tapir.LazyDerivedRule{…})(::Tapir.CoDual{…}, ::Tapir.CoDual{…}, ::Tapir.CoDual{…}, ::Tapir.CoDual{…})
    @ Tapir ~/.julia/packages/Tapir/hBdbd/src/interpreter/s2s_reverse_mode_ad.jl:1380
 [26] RRuleZeroWrapper
    @ ~/.julia/packages/Tapir/hBdbd/src/interpreter/s2s_reverse_mode_ad.jl:247 [inlined]
 [27] evaluate_threadsafe!!
    @ ~/.julia/packages/DynamicPPL/DvdZw/src/model.jl:961 [inlined]
 [28] (::Tuple{…})(none::Tapir.CoDual{…}, none::Tapir.CoDual{…}, none::Tapir.CoDual{…}, none::Tapir.CoDual{…})
    @ Base.Experimental ./<missing>:0
 [29] (::MistyClosures.MistyClosure{…})(::Tapir.CoDual{…}, ::Tapir.CoDual{…}, ::Tapir.CoDual{…}, ::Tapir.CoDual{…})
    @ MistyClosures ~/.julia/packages/MistyClosures/rzVHC/src/MistyClosures.jl:15
 [30] DerivedRule
    @ ~/.julia/packages/Tapir/hBdbd/src/interpreter/s2s_reverse_mode_ad.jl:720 [inlined]
 [31] (::Tapir.LazyDerivedRule{…})(::Tapir.CoDual{…}, ::Tapir.CoDual{…}, ::Tapir.CoDual{…}, ::Tapir.CoDual{…})
    @ Tapir ~/.julia/packages/Tapir/hBdbd/src/interpreter/s2s_reverse_mode_ad.jl:1380
 [32] RRuleZeroWrapper
    @ ~/.julia/packages/Tapir/hBdbd/src/interpreter/s2s_reverse_mode_ad.jl:247 [inlined]
 [33] logdensity
    @ ~/.julia/packages/DynamicPPL/DvdZw/src/logdensityfunction.jl:138 [inlined]
 [34] (::Tuple{…})(none::Tapir.CoDual{…}, none::Tapir.CoDual{…}, none::Tapir.CoDual{…})
    @ Base.Experimental ./<missing>:0
 [35] (::MistyClosures.MistyClosure{…})(::Tapir.CoDual{…}, ::Tapir.CoDual{…}, ::Tapir.CoDual{…})
    @ MistyClosures ~/.julia/packages/MistyClosures/rzVHC/src/MistyClosures.jl:15
 [36] DerivedRule
    @ ~/.julia/packages/Tapir/hBdbd/src/interpreter/s2s_reverse_mode_ad.jl:720 [inlined]
 [37] logdensity_and_gradient(∇l::TapirLogDensityProblemsADExt.TapirGradientLogDensity{…}, x::Vector{…})
    @ TapirLogDensityProblemsADExt ~/.julia/packages/Tapir/hBdbd/ext/TapirLogDensityProblemsADExt.jl:55
 [38] ∂logπ∂θ
    @ ~/.julia/packages/Turing/Bv30p/src/mcmc/hmc.jl:180 [inlined]
 [39] ∂H∂θ
    @ ~/.julia/packages/AdvancedHMC/AlvV4/src/hamiltonian.jl:38 [inlined]
 [40] phasepoint
    @ ~/.julia/packages/AdvancedHMC/AlvV4/src/hamiltonian.jl:74 [inlined]
 [41] phasepoint(rng::Random.TaskLocalRNG, θ::Vector{…}, h::AdvancedHMC.Hamiltonian{…})
    @ AdvancedHMC ~/.julia/packages/AdvancedHMC/AlvV4/src/hamiltonian.jl:155
 [42] initialstep(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, spl::DynamicPPL.Sampler{…}, vi_original::DynamicPPL.TypedVarInfo{…}; initial_params::Nothing, nadapts::Int64, kwargs::@Kwargs{})
    @ Turing.Inference ~/.julia/packages/Turing/Bv30p/src/mcmc/hmc.jl:184
 [43] step(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, spl::DynamicPPL.Sampler{…}; initial_params::Nothing, kwargs::@Kwargs{…})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/DvdZw/src/sampler.jl:116
 [44] step
    @ ~/.julia/packages/DynamicPPL/DvdZw/src/sampler.jl:99 [inlined]
 [45] macro expansion
    @ ~/.julia/packages/AbstractMCMC/YrmkI/src/sample.jl:130 [inlined]
 [46] macro expansion
    @ ~/.julia/packages/ProgressLogging/6KXlp/src/ProgressLogging.jl:328 [inlined]
 [47] (::AbstractMCMC.var"#22#23"{…})()
    @ AbstractMCMC ~/.julia/packages/AbstractMCMC/YrmkI/src/logging.jl:12
 [48] with_logstate(f::Function, logstate::Any)
    @ Base.CoreLogging ./logging.jl:515
 [49] with_logger(f::Function, logger::LoggingExtras.TeeLogger{Tuple{…}})
    @ Base.CoreLogging ./logging.jl:627
 [50] with_progresslogger(f::Function, _module::Module, logger::Logging.ConsoleLogger)
    @ AbstractMCMC ~/.julia/packages/AbstractMCMC/YrmkI/src/logging.jl:36
 [51] macro expansion
    @ ~/.julia/packages/AbstractMCMC/YrmkI/src/logging.jl:11 [inlined]
 [52] mcmcsample(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, sampler::DynamicPPL.Sampler{…}, N::Int64; progress::Bool, progressname::String, callback::Nothing, discard_initial::Int64, thinning::Int64, chain_type::Type, initial_state::Nothing, kwargs::@Kwargs{…})
    @ AbstractMCMC ~/.julia/packages/AbstractMCMC/YrmkI/src/sample.jl:120
 [53] sample(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, sampler::DynamicPPL.Sampler{…}, N::Int64; chain_type::Type, resume_from::Nothing, initial_state::Nothing, progress::Bool, nadapts::Int64, discard_adapt::Bool, discard_initial::Int64, kwargs::@Kwargs{})
    @ Turing.Inference ~/.julia/packages/Turing/Bv30p/src/mcmc/hmc.jl:123
 [54] sample
    @ ~/.julia/packages/Turing/Bv30p/src/mcmc/hmc.jl:92 [inlined]
 [55] #sample#4
    @ ~/.julia/packages/Turing/Bv30p/src/mcmc/Inference.jl:276 [inlined]
 [56] sample
    @ ~/.julia/packages/Turing/Bv30p/src/mcmc/Inference.jl:267 [inlined]
 [57] #sample#3
    @ ~/.julia/packages/Turing/Bv30p/src/mcmc/Inference.jl:264 [inlined]
 [58] sample(model::DynamicPPL.Model{…}, alg::NUTS{…}, N::Int64)
    @ Turing.Inference ~/.julia/packages/Turing/Bv30p/src/mcmc/Inference.jl:258
Some type information was truncated. Use `show(err)` to see complete types.

Plainly, some string printing isn't quite working as intended. Separately, when I tried differentiating through the logpdf computation of LKJCholesky directly, I ran into problems there too. I suspect that the LKJCholesky logpdf problem is the main thing going wrong here, but we should also fix whatever is going wrong with the string printing.
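One possible reading of the warning above (an interpretation of the log, not something confirmed in this thread): the derived and expected rule types differ only in the return-type parameter of the forward-pass Core.OpaqueClosure, which is Union{} in the derived rule and Tapir.CoDual{String, NoFData} in the expected type. Inference assigns Union{} to code that can never return normally, which fits a forward pass that hits the throwing rrule!! for the jl_alloc_string foreigncall. Because Julia's parametric types are invariant, a type parameterised by Union{} is not a subtype of the same type parameterised by anything else, even though Union{} <: String, so the rule cannot be stored in a field of the expected type. The stdlib-only sketch below illustrates both facts; always_throws, Slot and RuleField are hypothetical stand-ins, not Tapir types.

```julia
# Hypothetical stand-in for a forward pass that always errors, in the way the
# rrule!! for the jl_alloc_string foreigncall does.
always_throws(c::Char) = error("No rrule!! available for foreigncall")

# Inference concludes this function never returns normally: its return type is Union{}.
inferred = Base.return_types(always_throws, (Char,))
println(inferred)

# Hypothetical parametric type standing in for MistyClosure{OpaqueClosure{Args, R}}.
struct Slot{R} end

# Union{} (the bottom type) is a subtype of every type...
println(Union{} <: String)
# ...but parametric types are invariant, so Slot{Union{}} is not a Slot{String}.
println(Slot{Union{}} <: Slot{String})

# Hypothetical container whose field has a concrete parametric type, like the rule field.
mutable struct RuleField
    rule::Slot{String}
end

rf = RuleField(Slot{String}())

# Storing a value with a different type parameter fails: setproperty! tries to
# convert it to Slot{String}, and no such conversion exists.
caught = try
    rf.rule = Slot{Int}()
    nothing
catch err
    err
end
println(caught isa MethodError)
```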

willtebbutt commented 1 week ago

I've minimised this to the following:

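using Distributions, Tapir

# Minimised from the logpdf computation of LKJCholesky.
function my_logkernel(d::LKJCholesky, x::Vector{Float64})
    d.d == 1 && return nothing                 # early return based on the dimension field
    return sum(x -> x, Iterators.drop(x, 1))   # sum every element except the first
end

# Building the reverse-mode rule for this signature reproduces the problem.
sig = Tuple{typeof(my_logkernel), LKJCholesky{Float64}, Vector{Float64}}
Tapir.build_rrule(Tapir.TapirInterpreter(), sig)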

Interestingly, if I don't apply the standard inlining pass to the generated reverse-pass code, the problem goes away. This suggests that Tapir.jl emits reasonable code during codegen, and that something only goes wrong once the inlining pass is applied to it. My reading, then, is that this isn't a problem on my end but a problem in the compiler, although I'm not 100% certain about this.
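As a stdlib-only way to see what inlining does to ordinary Julia IR (this does not touch Tapir's internals and does not reproduce the failure, which lives in the reverse-pass code Tapir generates): code_typed with optimize=false returns the inferred IR before the optimiser, including inlining, has run, and optimize=true returns it afterwards. In this sketch FakeLKJ is a hypothetical stand-in for LKJCholesky with only its d field, and count_calls is a hypothetical helper. Before optimisation the field access d.d shows up as a getproperty call; after inlining it shows up as a getfield call.

```julia
# Hypothetical stand-in for LKJCholesky: only the dimension field matters here.
struct FakeLKJ
    d::Int
end

# Same body as the minimised example, with FakeLKJ in place of LKJCholesky.
function kernel(d::FakeLKJ, x::Vector{Float64})
    d.d == 1 && return nothing
    return sum(x -> x, Iterators.drop(x, 1))
end

# Hypothetical helper: count the top-level calls to a function named `name`
# in the typed IR of `f` for the given argument types.
function count_calls(f, argtypes, name::Symbol; optimize::Bool)
    ci = first(code_typed(f, argtypes; optimize=optimize)).first  # CodeInfo
    return count(ci.code) do stmt
        stmt isa Expr && stmt.head === :call &&
            stmt.args[1] isa GlobalRef && stmt.args[1].name === name
    end
end

argtypes = (FakeLKJ, Vector{Float64})

# Before optimisation, d.d is still a call to getproperty.
println(count_calls(kernel, argtypes, :getproperty; optimize=false))
# After optimisation, getproperty has been inlined down to getfield.
println(count_calls(kernel, argtypes, :getfield; optimize=true))
```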

Small tweaks on this seem to remove the problem. I've tried:

  1. replacing x -> x with identity,
  2. replacing the argument d with an Int and using that Int directly in the comparison,
  3. removing Iterators.drop.

If you modify this example in any of these ways, the problem goes away. The second tweak is probably the closest I've found to something that might give us some insight into what's going on: the optimised IR for the minimised function differs by a single line between the two versions (the original contains an extra getfield call, to read d.d), yet only the original falls over. It's very odd.
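To make the three tweaks concrete, here is a sketch of each variant next to the original. FakeLKJ is a hypothetical stand-in for LKJCholesky so that the snippet runs with only Base; with Distributions and Tapir loaded, one would use LKJCholesky{Float64} instead and pass each signature (for example Tuple{typeof(tweak_int), Int, Vector{Float64}}) to Tapir.build_rrule(Tapir.TapirInterpreter(), sig) as in the minimal example. The third variant is an assumed reading of "removing Iterators.drop" and changes the result, since it also sums the first element.

```julia
# Hypothetical stand-in for LKJCholesky with only its dimension field.
struct FakeLKJ
    d::Int
end

# Original minimised body.
function original(d::FakeLKJ, x::Vector{Float64})
    d.d == 1 && return nothing
    return sum(x -> x, Iterators.drop(x, 1))
end

# Tweak 1: replace the anonymous function x -> x with identity.
function tweak_identity(d::FakeLKJ, x::Vector{Float64})
    d.d == 1 && return nothing
    return sum(identity, Iterators.drop(x, 1))
end

# Tweak 2: pass the dimension as an Int and compare it directly (no getfield on d).
function tweak_int(n::Int, x::Vector{Float64})
    n == 1 && return nothing
    return sum(x -> x, Iterators.drop(x, 1))
end

# Tweak 3 (assumed reading): remove the Iterators.drop wrapper entirely.
function tweak_no_drop(d::FakeLKJ, x::Vector{Float64})
    d.d == 1 && return nothing
    return sum(x -> x, x)
end

v = [1.0, 2.0, 3.0]
println((original(FakeLKJ(2), v), tweak_identity(FakeLKJ(2), v),
         tweak_int(2, v), tweak_no_drop(FakeLKJ(2), v)))
```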

It's tricky to minimise this further, because the problem occurs in code that is generated automatically for the reverse pass.

willtebbutt commented 1 week ago

Update: I've figured out what's going on here. PR incoming, probably on Monday now rather than today.

willtebbutt commented 6 days ago

Resolved by #242, which also adds regression tests, so I'm going to close this.