EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

Enzyme segfaults on Turing model #650

Closed sethaxen closed 1 year ago

sethaxen commented 1 year ago

I just rechecked the model in https://github.com/TuringLang/Turing.jl/pull/1887#issue-1389132949 on that branch, and after emitting warnings it once again segfaults (this had been fixed in #457). Below is the complete code sample:

using Turing
using Enzyme

@model function model()
    m ~ Normal(0, 1)
    s ~ InverseGamma()
    x ~ Normal(m, s)
end

sample(model() | (; x=0.5), NUTS{Turing.EnzymeAD}(), 10)

The full stacktrace can be found at https://gist.github.com/sethaxen/5666e1c6c9d8194e0370c60eb70de49e#file-log-txt

julia> versioninfo()
Julia Version 1.8.5
Commit 17cfb8e65ea (2023-01-08 06:45 UTC)
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 8 × 11th Gen Intel(R) Core(TM) i5-1135G7 @ 2.40GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-13.0.1 (ORCJIT, tigerlake)
  Threads: 8 on 8 virtual cores
Environment:
  JULIA_NUM_THREADS = auto

julia> using Pkg; Pkg.status()
Status `~/Downloads/enzyme_turing_test/Project.toml`
  [7da242da] Enzyme v0.11.0-dev `https://github.com/EnzymeAD/Enzyme.jl.git#main`
  [f151be2c] EnzymeCore v0.2.1 `https://github.com/EnzymeAD/Enzyme.jl.git:lib/EnzymeCore#main`
  [fce5fe82] Turing v0.24.1 `https://github.com/TuringLang/Turing.jl.git#dw/enzyme`

It also fails on the latest release of Enzyme.

wsmoses commented 1 year ago

@sethaxen can you extract this out so we can see the function being passed to autodiff?

As it stands, it's hard to tell which function is being differentiated, which we need to know in order to debug.

wsmoses commented 1 year ago

@sethaxen this runs correctly for me on Enzyme#main and Julia 1.9.

Here is the first run (which includes compilation), followed by the second:

┌ Info: Found initial step size
└   ϵ = 0.4
Sampling 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 0:01:05
Chains MCMC chain (10×14×1 Array{Float64, 3}):

Iterations        = 6:1:15
Number of chains  = 1
Samples per chain = 10
Wall duration     = 66.48 seconds
Compute duration  = 66.48 seconds
parameters        = m, s
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec 
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Float64 

           m    0.3281    0.9148    0.2893    10.0000        NaN    1.0216        0.1504
           s    3.3405    4.6931    1.4841     9.3304     6.6667    1.3839        0.1403

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

           m   -1.1945   -0.1198    0.5916    1.1116    1.1116
           s    0.3560    0.3560    2.4453    3.5200   13.3314

julia> sample(model() | (; x=0.5), NUTS{Turing.EnzymeAD}(), 10)
┌ Info: Found initial step size
└   ϵ = 1.6
Sampling 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 0:00:00
Chains MCMC chain (10×14×1 Array{Float64, 3}):

Iterations        = 6:1:15
Number of chains  = 1
Samples per chain = 10
Wall duration     = 0.01 seconds
Compute duration  = 0.01 seconds
parameters        = m, s
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec 
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Float64 

           m    0.1278    0.4991    0.1578    10.0000    10.0000    1.0588     1666.6667
           s    1.8968    1.7990    0.5689    10.0000    10.0000    1.1069     1666.6667

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

           m   -0.7156   -0.1037    0.0939    0.3510    0.8085
           s    0.8492    0.9252    0.9904    1.7534    5.6334

sethaxen commented 1 year ago

@sethaxen can you extract this out so we can see the function being passed to autodiff?

As it stands, it's hard to tell which function is being differentiated, which we need to know in order to debug.

Sure, here's the version that contains the call to autodiff:

using Turing, Enzyme
using Turing.LogDensityProblems

@model function model()
    m ~ Normal(0, 1)
    s ~ InverseGamma()
    x ~ Normal(m, s)
end

mod = model()
sampler = DynamicPPL.Sampler(NUTS())
vi = DynamicPPL.VarInfo(mod)
vi = DynamicPPL.link!!(vi, sampler, mod)
ℓ = Turing.LogDensityFunction(vi, mod, sampler, DynamicPPL.DefaultContext())
x = vi[sampler]  # Vector{Float64}
∂ℓ_∂x = zero(x)
Enzyme.autodiff(
    Reverse,
    LogDensityProblems.logdensity,
    Enzyme.Active,
    Enzyme.Const(ℓ),
    Enzyme.Duplicated(x, ∂ℓ_∂x),
)

@sethaxen this runs correctly for me on Enzyme#main and Julia 1.9.

Strange, because this also segfaults for me on Julia 1.9.

wsmoses commented 1 year ago

Odd. Would you be able to simplify the above, e.g. by reducing and/or inlining as much as possible?

sethaxen commented 1 year ago

No, sorry, I am not familiar with the inner workings of this code and have no time right now.

ViralBShah commented 1 year ago

@yebai Is this the same as the issue you mentioned in #658?

yebai commented 1 year ago

No -- this issue is already fixed by using an immutable internal data structure (SimpleVarInfo if you are interested). This issue can be closed now.

PS. Here is a working example of Turing using Enzyme:

julia> using Distributions, DynamicPPL, LogDensityProblems, LogDensityProblemsAD, Enzyme

julia> @model demo() = x ~ Normal()
demo (generic function with 2 methods)

julia> model = demo()
Model{typeof(demo), (), (), (), Tuple{}, Tuple{}, DefaultContext}(demo, NamedTuple(), NamedTuple(), DefaultContext())

julia> f = DynamicPPL.LogDensityFunction(model, SimpleVarInfo((x = 1.0, )));

julia> fwithgrad = ADgradient(:Enzyme, f);

julia> LogDensityProblems.logdensity_and_gradient(fwithgrad, [1.0])
(-1.4189385332046727, [-1.0])

sethaxen commented 1 year ago

No -- this issue is already fixed by using an immutable internal data structure (SimpleVarInfo if you are interested). This issue can be closed now.

No, this is not fixed by using SimpleVarInfo. We found that with s ~ InverseGamma(), s ~ Gamma(), or s ~ truncated(Normal(); lower=0) we got the same segfaults. If we remove s and x entirely, everything is fine, even with the usual VarInfo.

devmotion commented 1 year ago

As another data point, the example in https://github.com/EnzymeAD/Enzyme.jl/issues/650#issuecomment-1455228871 segfaults for me with Enzyme#main, EnzymeCore#main and Enzyme_jll#main on Julia 1.9 rc1.

Using DynamicPPL.SimpleVarInfo, i.e., replacing the line vi = DynamicPPL.VarInfo(mod) with vi = DynamicPPL.SimpleVarInfo(mod), fixes the segfault but yields an error:

julia> Enzyme.autodiff(
           Reverse,
           LogDensityProblems.logdensity,
           Enzyme.Active,
           Enzyme.Const(ℓ),
           Enzyme.Duplicated(x, ∂ℓ_∂x),
       )
ERROR: Return type inferred to be Union{}. Giving up.
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] #s479#163
   @ ~/.julia/packages/Enzyme/SUstD/src/compiler.jl:8160 [inlined]
 [3] var"#s479#163"(F::Any, Fn::Any, DF::Any, A::Any, TT::Any, Mode::Any, ModifiedBetween::Any, width::Any, specid::Any, ReturnPrimal::Any, ShadowInit::Any, ::Any, #unused#::Type, f::Any, df::Any, #unused#::Type, tt::Any, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Any)
   @ Enzyme.Compiler ./none:0
 [4] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
   @ Core ./boot.jl:602
 [5] thunk(f::typeof(LogDensityProblems.logdensity), df::Nothing, ::Type{Duplicated{Union{}}}, tt::Type{Tuple{Const{LogDensityFunction{DynamicPPL.SimpleVarInfo{NamedTuple{(:m, :s, :x), Tuple{Float64, Float64, Float64}}, Float64, DynamicPPL.DynamicTransformation}, DynamicPPL.Model{typeof(model), (), (), (), Tuple{}, Tuple{}, DynamicPPL.DefaultContext}, DynamicPPL.SamplingContext{DynamicPPL.Sampler{NUTS{Turing.Essential.ForwardDiffAD{0}, (), AdvancedHMC.DiagEuclideanMetric}}, DynamicPPL.DefaultContext, Random._GLOBAL_RNG}}}, Duplicated{Vector{Float64}}}}, ::Val{Enzyme.API.DEM_ReverseModeGradient}, ::Val{1}, ::Val{(false, false, false)}, ::Val{false}, ::Val{true})
   @ Enzyme.Compiler ~/.julia/packages/Enzyme/SUstD/src/compiler.jl:8218
 [6] autodiff(::EnzymeCore.ReverseMode{false, false}, ::typeof(LogDensityProblems.logdensity), ::Type{Active}, ::Const{LogDensityFunction{DynamicPPL.SimpleVarInfo{NamedTuple{(:m, :s, :x), Tuple{Float64, Float64, Float64}}, Float64, DynamicPPL.DynamicTransformation}, DynamicPPL.Model{typeof(model), (), (), (), Tuple{}, Tuple{}, DynamicPPL.DefaultContext}, DynamicPPL.SamplingContext{DynamicPPL.Sampler{NUTS{Turing.Essential.ForwardDiffAD{0}, (), AdvancedHMC.DiagEuclideanMetric}}, DynamicPPL.DefaultContext, Random._GLOBAL_RNG}}}, ::Vararg{Any})
   @ Enzyme ~/.julia/packages/Enzyme/SUstD/src/Enzyme.jl:185
 [7] top-level scope
   @ REPL[10]:1

Surprisingly, it seems the return type of the logdensity function can't be inferred, even though we work with a simple NamedTuple here:

julia> @code_warntype LogDensityProblems.logdensity(ℓ, x)
MethodInstance for LogDensityProblems.logdensity(::LogDensityFunction{DynamicPPL.SimpleVarInfo{NamedTuple{(:m, :s, :x), Tuple{Float64, Float64, Float64}}, Float64, DynamicPPL.DynamicTransformation}, DynamicPPL.Model{typeof(model), (), (), (), Tuple{}, Tuple{}, DynamicPPL.DefaultContext}, DynamicPPL.SamplingContext{DynamicPPL.Sampler{NUTS{Turing.Essential.ForwardDiffAD{0}, (), AdvancedHMC.DiagEuclideanMetric}}, DynamicPPL.DefaultContext, Random._GLOBAL_RNG}}, ::Vector{Float64})
  from logdensity(f::LogDensityFunction, θ::AbstractVector) @ DynamicPPL ~/.julia/packages/DynamicPPL/UFajj/src/logdensityfunction.jl:92
Arguments
  #self#::Core.Const(LogDensityProblems.logdensity)
  f::LogDensityFunction{DynamicPPL.SimpleVarInfo{NamedTuple{(:m, :s, :x), Tuple{Float64, Float64, Float64}}, Float64, DynamicPPL.DynamicTransformation}, DynamicPPL.Model{typeof(model), (), (), (), Tuple{}, Tuple{}, DynamicPPL.DefaultContext}, DynamicPPL.SamplingContext{DynamicPPL.Sampler{NUTS{Turing.Essential.ForwardDiffAD{0}, (), AdvancedHMC.DiagEuclideanMetric}}, DynamicPPL.DefaultContext, Random._GLOBAL_RNG}}
  θ::Vector{Float64}
Locals
  vi_new::DynamicPPL.SimpleVarInfo{NamedTuple{(:m, :s, :x), Tuple{Float64, Float64, Float64}}, Float64, DynamicPPL.DynamicTransformation}
Body::Union{}
1 ─ %1 = Base.getproperty(f, :varinfo)::DynamicPPL.SimpleVarInfo{NamedTuple{(:m, :s, :x), Tuple{Float64, Float64, Float64}}, Float64, DynamicPPL.DynamicTransformation}
│   %2 = Base.getproperty(f, :context)::DynamicPPL.SamplingContext{DynamicPPL.Sampler{NUTS{Turing.Essential.ForwardDiffAD{0}, (), AdvancedHMC.DiagEuclideanMetric}}, DynamicPPL.DefaultContext, Random._GLOBAL_RNG}
│        (vi_new = DynamicPPL.unflatten(%1, %2, θ))
│   %4 = Base.getproperty(f, :model)::Core.Const(DynamicPPL.Model{typeof(model), (), (), (), Tuple{}, Tuple{}, DynamicPPL.DefaultContext}(model, NamedTuple(), NamedTuple(), DynamicPPL.DefaultContext()))
│   %5 = vi_new::DynamicPPL.SimpleVarInfo{NamedTuple{(:m, :s, :x), Tuple{Float64, Float64, Float64}}, Float64, DynamicPPL.DynamicTransformation}
│   %6 = Base.getproperty(f, :context)::DynamicPPL.SamplingContext{DynamicPPL.Sampler{NUTS{Turing.Essential.ForwardDiffAD{0}, (), AdvancedHMC.DiagEuclideanMetric}}, DynamicPPL.DefaultContext, Random._GLOBAL_RNG}
│        DynamicPPL.evaluate!!(%4, %5, %6)
│        Core.Const(:(DynamicPPL.last(%7)))
│        Core.Const(:(DynamicPPL.getlogp(%8)))
└──      Core.Const(:(return %9))

vchuravy commented 1 year ago

Union{} means "guaranteed to error."

What does executing the function normally yield?
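A minimal sketch of what that means, using only Base Julia (always_throws and maybe_throws are hypothetical example functions, not part of Enzyme or DynamicPPL): Base.return_types runs type inference for a given argument signature. A function whose every path throws is inferred to return Union{}, whereas a throwing branch inside an otherwise normal function simply drops out of the inferred type.

```julia
# Hypothetical example functions, not part of Enzyme or DynamicPPL.
# Every path of always_throws ends in an error, so it can never return normally.
always_throws(x::Float64) = error("this call always throws")

# Only one branch throws; the other returns a Float64.
maybe_throws(x::Float64) = x > 0 ? x : error("negative input")

# Base.return_types runs type inference for the given argument types
# and returns a vector with one inferred type per matching method.
rt_always = only(Base.return_types(always_throws, (Float64,)))
rt_maybe = only(Base.return_types(maybe_throws, (Float64,)))

println(rt_always)  # Union{}: inference proved the call never returns normally
println(rt_maybe)   # Float64: the throwing branch contributes nothing to the type
```

Judging by the error message above ("Return type inferred to be Union{}. Giving up."), Enzyme appears to stop as soon as inference reaches this result, so checking the primal call with @code_warntype or Base.return_types first is a cheap way to rule this case out.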

devmotion commented 1 year ago

It should return a Float64, but after getting some sleep it's clear to me why the SimpleVarInfo version fails. The reduced example in https://github.com/EnzymeAD/Enzyme.jl/issues/650#issuecomment-1455228871 is not correct: it introduces a mix of ForwardDiff.Dual and sampling into the log density evaluation (note the SamplingContext with NUTS{Turing.Essential.ForwardDiffAD{0}, ...} in the types above). The log density function should be constructed by

ℓ = Turing.LogDensityFunction(mod, vi, DynamicPPL.DefaultContext())

instead. Then the logdensity can be evaluated correctly and Enzyme can compute the gradient (when using SimpleVarInfo; tested with Enzyme#main, EnzymeCore#main, Enzyme_jll#main on Julia 1.9 rc1):

julia> autodiff(
           ReverseWithPrimal,
           LogDensityProblems.logdensity,
           ℓ,
           Duplicated(x, ∂ℓ_∂x),
       )
((nothing, nothing), -4.2993310577423145)

julia> ∂ℓ_∂x
3-element Vector{Float64}:
  0.4264673357364165
 -0.6593350799670791
  0.5109041418560486

julia> LogDensityProblems.logdensity(ℓ, x)
-4.2993310577423145

julia> ForwardDiff.gradient(x -> logjoint(mod, DynamicPPL.SimpleVarInfo((m = x[1], s = x[2], x = x[3]), zero(eltype(x)), DynamicPPL.DynamicTransformation())), x)
3-element Vector{Float64}:
  0.4264673357364165
 -0.6593350799670789
  0.5109041418560486
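As an independent cross-check that needs neither Enzyme, ForwardDiff, nor DynamicPPL, the sketch below writes the log joint of this model out by hand in unconstrained space and compares a hand-derived gradient against central finite differences. It assumes, without confirming this against DynamicPPL's internals, that InverseGamma() has shape = scale = 1, that s is linked via s = exp(y) so the log-Jacobian term is y, and that the parameter vector is ordered [m, log s, x]. logjoint_linked, grad_logjoint_linked and fd_gradient are hypothetical names.

```julia
# Assumed model: m ~ Normal(0, 1), s ~ InverseGamma(1, 1), x ~ Normal(m, s),
# evaluated in unconstrained space θ = [m, y, x] with s = exp(y) (assumed log link).
const LOG2π = log(2π)

function logjoint_linked(θ::AbstractVector{<:Real})
    m, y, x = θ
    s = exp(y)
    lp_m = -0.5 * m^2 - 0.5 * LOG2π                        # Normal(0, 1) at m
    lp_s = -2 * log(s) - 1 / s                              # InverseGamma(1, 1) at s
    logjac = y                                              # log |ds/dy| for s = exp(y)
    lp_x = -0.5 * ((x - m) / s)^2 - log(s) - 0.5 * LOG2π    # Normal(m, s) at x
    return lp_m + lp_s + logjac + lp_x
end

# Hand-derived gradient of logjoint_linked with respect to θ = [m, y, x].
function grad_logjoint_linked(θ::AbstractVector{<:Real})
    m, y, x = θ
    s2 = exp(2y)   # s^2
    r = x - m
    return [
        -m + r / s2,                 # ∂/∂m
        exp(-y) + r^2 / s2 - 2,      # ∂/∂y
        -r / s2,                     # ∂/∂x
    ]
end

# Central finite differences as an independent check.
function fd_gradient(f, θ::AbstractVector{Float64}; h=1e-6)
    g = similar(θ)
    for i in eachindex(θ)
        e = zeros(length(θ))
        e[i] = h
        g[i] = (f(θ .+ e) - f(θ .- e)) / (2h)
    end
    return g
end

θ = [-0.9, 0.5, 0.5]   # arbitrary test point, not the x from the session above
g_exact = grad_logjoint_linked(θ)
g_fd = fd_gradient(logjoint_linked, θ)
println(maximum(abs.(g_exact .- g_fd)))   # should be close to zero
```

If those assumptions match DynamicPPL's linking, evaluating grad_logjoint_linked at the same x used in the session above should reproduce ∂ℓ_∂x.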

wsmoses commented 1 year ago

@devmotion unfortunately, since I cannot reproduce the segfault on main, you're going to have to minimize it (hopefully allowing me to reproduce it) before any investigation and/or fix can start.

devmotion commented 1 year ago

I don't understand how it's possible that you can successfully run the example in the OP. For me, the same happens as @sethaxen described above: When I run

using Turing, Enzyme

@model function model()
    m ~ Normal(0, 1)
    s ~ InverseGamma()
    x ~ Normal(m, s)
end

sample(model() | (; x=0.5), NUTS{Turing.EnzymeAD}(), 10)

I get a lot of warnings and then Julia segfaults. I used the latest versions of Enzyme and Julia, and the Turing branch with Enzyme support:

julia> versioninfo()
Julia Version 1.9.0-rc1
Commit 3b2e0d8fbc1 (2023-03-07 07:51 UTC)
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 8 × 11th Gen Intel(R) Core(TM) i7-1165G7 @ 2.80GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-14.0.6 (ORCJIT, tigerlake)
  Threads: 3 on 8 virtual cores
Environment:
  JULIA_NUM_THREADS = 3
  JULIA_PKG_USE_CLI_GIT = true
  JULIA_EDITOR = code
  JULIA_PKG_SERVER = https://pumasai.juliahub.com

(enzyme) pkg> st
Status `~/sources/enzyme/Project.toml`
  [7da242da] Enzyme v0.11.0-dev `https://github.com/EnzymeAD/Enzyme.jl.git#main`
  [f151be2c] EnzymeCore v0.2.1 `https://github.com/EnzymeAD/Enzyme.jl.git:lib/EnzymeCore#main`
  [fce5fe82] Turing v0.24.1 `https://github.com/TuringLang/Turing.jl.git#dw/enzyme`
  [7cc45869] Enzyme_jll v0.0.51+0 `https://github.com/JuliaBinaryWrappers/Enzyme_jll.jl.git#main`

Does your setup in some way differ from ours?

devmotion commented 1 year ago

I found it :tada: I had one final idea, based on the differences I had observed in #659 between sampling/computing derivatives with a single thread and with multiple threads (or the respective methods for these cases). And indeed, when I unset the JULIA_NUM_THREADS environment variable and start Julia single-threaded, sampling succeeds. However, it still emits all these warnings, which possibly require some fixes or, I think, at least should not show up outside debug mode. Interestingly, the warnings/stacktraces point to GPUCompiler even though I don't run anything on the GPU?! For instance,

┌ Warning: TypeAnalysisDepthLimit
│ {[]:Pointer, [0]:Pointer, [0,0]:Pointer, [0,0,0]:Integer, [0,8]:Integer, [0,9]:Integer, [0,10]:Integer, [0,11]:Integer, [0,12]:Integer, [0,13]:Integer, [0,14]:Integer, [0,15]:Integer, [0,16]:Integer, [0,17]:Integer, [0,18]:Integer, [0,19]:Integer, [0,20]:Integer, [0,21]:Integer, [0,22]:Integer, [0,23]:Integer, [0,24]:Integer, [0,25]:Integer, [0,26]:Integer, [0,27]:Integer, [0,28]:Integer, [0,29]:Integer, [0,30]:Integer, [0,31]:Integer, [0,32]:Integer, [0,33]:Integer, [0,34]:Integer, [0,35]:Integer, [0,36]:Integer, [0,37]:Integer, [0,38]:Integer, [0,39]:Integer, [0,40]:Integer, [8]:Pointer, [8,0]:Pointer, [8,0,0]:Pointer, [8,8]:Integer, [8,9]:Integer, [8,10]:Integer, [8,11]:Integer, [8,12]:Integer, [8,13]:Integer, [8,14]:Integer, [8,15]:Integer, [8,16]:Integer, [8,17]:Integer, [8,18]:Integer, [8,19]:Integer, [8,20]:Integer, [8,21]:Integer, [8,22]:Integer, [8,23]:Integer, [8,24]:Integer, [8,25]:Integer, [8,26]:Integer, [8,27]:Integer, [8,28]:Integer, [8,29]:Integer, [8,30]:Integer, [8,31]:Integer, [8,32]:Integer, [8,33]:Integer, [8,34]:Integer, [8,35]:Integer, [8,36]:Integer, [8,37]:Integer, [8,38]:Integer, [8,39]:Integer, [8,40]:Integer, [16]:Pointer, [16,0]:Pointer, [16,0,0]:Pointer, [16,0,0,0]:Pointer, [16,0,0,0,0]:Pointer, [16,0,0,0,0,0]:Integer, [16,0,0,0,0,1]:Integer, [16,0,0,0,0,2]:Integer, [16,0,0,0,0,3]:Integer, [16,0,0,0,0,4]:Integer, [16,0,0,0,0,5]:Integer, [16,0,0,0,0,6]:Integer, [16,0,0,0,0,7]:Integer, [16,0,0,0,8]:Integer, [16,0,0,0,9]:Integer, [16,0,0,0,10]:Integer, [16,0,0,0,11]:Integer, [16,0,0,0,12]:Integer, [16,0,0,0,13]:Integer, [16,0,0,0,14]:Integer, [16,0,0,0,15]:Integer, [16,0,0,0,16]:Integer, [16,0,0,0,17]:Integer, [16,0,0,0,18]:Integer, [16,0,0,0,19]:Integer, [16,0,0,0,20]:Integer, [16,0,0,0,21]:Integer, [16,0,0,0,22]:Integer, [16,0,0,0,23]:Integer, [16,0,0,0,24]:Integer, [16,0,0,0,25]:Integer, [16,0,0,0,26]:Integer, [16,0,0,0,27]:Integer, [16,0,0,0,28]:Integer, [16,0,0,0,29]:Integer, [16,0,0,0,30]:Integer, [16,0,0,0,31]:Integer, 
[16,0,0,0,32]:Integer, [16,0,0,0,33]:Integer, [16,0,0,0,34]:Integer, [16,0,0,0,35]:Integer, [16,0,0,0,36]:Integer, [16,0,0,0,37]:Integer, [16,0,0,0,38]:Integer, [16,0,0,0,39]:Integer, [16,0,0,0,40]:Integer, [16,0,0,8]:Integer, [16,0,0,9]:Integer, [16,0,0,10]:Integer, [16,0,0,11]:Integer, [16,0,0,12]:Integer, [16,0,0,13]:Integer, [16,0,0,14]:Integer, [16,0,0,15]:Integer, [16,0,0,16]:Integer, [16,0,0,17]:Integer, [16,0,0,18]:Integer, [16,0,0,19]:Integer, [16,0,0,20]:Integer, [16,0,0,21]:Integer, [16,0,0,22]:Integer, [16,0,0,23]:Integer, [16,8]:Integer, [16,9]:Integer, [16,10]:Integer, [16,11]:Integer, [16,12]:Integer, [16,13]:Integer, [16,14]:Integer, [16,15]:Integer, [16,16]:Integer, [16,17]:Integer, [16,18]:Integer, [16,19]:Integer, [16,20]:Integer, [16,21]:Integer, [16,22]:Integer, [16,23]:Integer, [16,24]:Integer, [16,25]:Integer, [16,26]:Integer, [16,27]:Integer, [16,28]:Integer, [16,29]:Integer, [16,30]:Integer, [16,31]:Integer, [16,32]:Integer, [16,33]:Integer, [16,34]:Integer, [16,35]:Integer, [16,36]:Integer, [16,37]:Integer, [16,38]:Integer, [16,39]:Integer, [16,40]:Integer, [24]:Integer, [25]:Integer, [26]:Integer, [27]:Integer, [28]:Integer, [29]:Integer, [30]:Integer, [31]:Integer, [32]:Integer, [33]:Integer, [34]:Integer, [35]:Integer, [36]:Integer, [37]:Integer, [38]:Integer, [39]:Integer, [40]:Integer, [41]:Integer, [42]:Integer, [43]:Integer, [44]:Integer, [45]:Integer, [46]:Integer, [47]:Integer, [48]:Integer, [49]:Integer, [50]:Integer, [51]:Integer, [52]:Integer, [53]:Integer, [54]:Integer, [55]:Integer, [56]:Integer, [57]:Integer, [58]:Integer, [59]:Integer, [60]:Integer, [61]:Integer, [62]:Integer, [63]:Integer}
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/S3TWf/src/utils.jl:50
not handling more than 6 pointer lookups deep dt:{[]:Pointer, [0]:Pointer, [0,0]:Pointer, [0,0,0]:Integer, [0,8]:Integer, [0,9]:Integer, [0,10]:Integer, [0,11]:Integer, [0,12]:Integer, [0,13]:Integer, [0,14]:Integer, [0,15]:Integer, [0,16]:Integer, [0,17]:Integer, [0,18]:Integer, [0,19]:Integer, [0,20]:Integer, [0,21]:Integer, [0,22]:Integer, [0,23]:Integer, [0,24]:Integer, [0,25]:Integer, [0,26]:Integer, [0,27]:Integer, [0,28]:Integer, [0,29]:Integer, [0,30]:Integer, [0,31]:Integer, [0,32]:Integer, [0,33]:Integer, [0,34]:Integer, [0,35]:Integer, [0,36]:Integer, [0,37]:Integer, [0,38]:Integer, [0,39]:Integer, [0,40]:Integer, [8]:Pointer, [8,0]:Pointer, [8,0,0]:Pointer, [8,8]:Integer, [8,9]:Integer, [8,10]:Integer, [8,11]:Integer, [8,12]:Integer, [8,13]:Integer, [8,14]:Integer, [8,15]:Integer, [8,16]:Integer, [8,17]:Integer, [8,18]:Integer, [8,19]:Integer, [8,20]:Integer, [8,21]:Integer, [8,22]:Integer, [8,23]:Integer, [8,24]:Integer, [8,25]:Integer, [8,26]:Integer, [8,27]:Integer, [8,28]:Integer, [8,29]:Integer, [8,30]:Integer, [8,31]:Integer, [8,32]:Integer, [8,33]:Integer, [8,34]:Integer, [8,35]:Integer, [8,36]:Integer, [8,37]:Integer, [8,38]:Integer, [8,39]:Integer, [8,40]:Integer, [16]:Pointer, [16,0]:Pointer, [16,0,0]:Pointer, [16,0,0,0]:Pointer, [16,0,0,0,0]:Pointer, [16,0,0,0,0,0]:Integer, [16,0,0,0,0,1]:Integer, [16,0,0,0,0,2]:Integer, [16,0,0,0,0,3]:Integer, [16,0,0,0,0,4]:Integer, [16,0,0,0,0,5]:Integer, [16,0,0,0,0,6]:Integer, [16,0,0,0,0,7]:Integer, [16,0,0,0,8]:Integer, [16,0,0,0,9]:Integer, [16,0,0,0,10]:Integer, [16,0,0,0,11]:Integer, [16,0,0,0,12]:Integer, [16,0,0,0,13]:Integer, [16,0,0,0,14]:Integer, [16,0,0,0,15]:Integer, [16,0,0,0,16]:Integer, [16,0,0,0,17]:Integer, [16,0,0,0,18]:Integer, [16,0,0,0,19]:Integer, [16,0,0,0,20]:Integer, [16,0,0,0,21]:Integer, [16,0,0,0,22]:Integer, [16,0,0,0,23]:Integer, [16,0,0,0,24]:Integer, [16,0,0,0,25]:Integer, [16,0,0,0,26]:Integer, [16,0,0,0,27]:Integer, [16,0,0,0,28]:Integer, [16,0,0,0,29]:Integer, 
[16,0,0,0,30]:Integer, [16,0,0,0,31]:Integer, [16,0,0,0,32]:Integer, [16,0,0,0,33]:Integer, [16,0,0,0,34]:Integer, [16,0,0,0,35]:Integer, [16,0,0,0,36]:Integer, [16,0,0,0,37]:Integer, [16,0,0,0,38]:Integer, [16,0,0,0,39]:Integer, [16,0,0,0,40]:Integer, [16,0,0,8]:Integer, [16,0,0,9]:Integer, [16,0,0,10]:Integer, [16,0,0,11]:Integer, [16,0,0,12]:Integer, [16,0,0,13]:Integer, [16,0,0,14]:Integer, [16,0,0,15]:Integer, [16,0,0,16]:Integer, [16,0,0,17]:Integer, [16,0,0,18]:Integer, [16,0,0,19]:Integer, [16,0,0,20]:Integer, [16,0,0,21]:Integer, [16,0,0,22]:Integer, [16,0,0,23]:Integer, [16,8]:Integer, [16,9]:Integer, [16,10]:Integer, [16,11]:Integer, [16,12]:Integer, [16,13]:Integer, [16,14]:Integer, [16,15]:Integer, [16,16]:Integer, [16,17]:Integer, [16,18]:Integer, [16,19]:Integer, [16,20]:Integer, [16,21]:Integer, [16,22]:Integer, [16,23]:Integer, [16,24]:Integer, [16,25]:Integer, [16,26]:Integer, [16,27]:Integer, [16,28]:Integer, [16,29]:Integer, [16,30]:Integer, [16,31]:Integer, [16,32]:Integer, [16,33]:Integer, [16,34]:Integer, [16,35]:Integer, [16,36]:Integer, [16,37]:Integer, [16,38]:Integer, [16,39]:Integer, [16,40]:Integer, [24]:Integer, [25]:Integer, [26]:Integer, [27]:Integer, [28]:Integer, [29]:Integer, [30]:Integer, [31]:Integer, [32]:Integer, [33]:Integer, [34]:Integer, [35]:Integer, [36]:Integer, [37]:Integer, [38]:Integer, [39]:Integer, [40]:Integer, [41]:Integer, [42]:Integer, [43]:Integer, [44]:Integer, [45]:Integer, [46]:Integer, [47]:Integer, [48]:Integer, [49]:Integer, [50]:Integer, [51]:Integer, [52]:Integer, [53]:Integer, [54]:Integer, [55]:Integer, [56]:Integer, [57]:Integer, [58]:Integer, [59]:Integer, [60]:Integer, [61]:Integer, [62]:Integer, [63]:Integer} only(56): 
┌ Warning: TypeAnalysisDepthLimit
│   store {} addrspace(10)* %.fca.0.0.1.7.extract, {} addrspace(10)* addrspace(10)* %.fca.0.0.1.7.gep, align 8, !dbg !19
│ {[]:Pointer, [0]:Pointer, [0,0]:Pointer, [0,0,0]:Pointer, [0,0,0,0]:Integer, [0,0,8]:Integer, [0,0,9]:Integer, [0,0,10]:Integer, [0,0,11]:Integer, [0,0,12]:Integer, [0,0,13]:Integer, [0,0,14]:Integer, [0,0,15]:Integer, [0,0,16]:Integer, [0,0,17]:Integer, [0,0,18]:Integer, [0,0,19]:Integer, [0,0,20]:Integer, [0,0,21]:Integer, [0,0,22]:Integer, [0,0,23]:Integer, [0,0,24]:Integer, [0,0,25]:Integer, [0,0,26]:Integer, [0,0,27]:Integer, [0,0,28]:Integer, [0,0,29]:Integer, [0,0,30]:Integer, [0,0,31]:Integer, [0,0,32]:Integer, [0,0,33]:Integer, [0,0,34]:Integer, [0,0,35]:Integer, [0,0,36]:Integer, [0,0,37]:Integer, [0,0,38]:Integer, [0,0,39]:Integer, [0,0,40]:Integer, [0,8]:Pointer, [0,8,0]:Pointer, [0,8,0,0]:Pointer, [0,8,8]:Integer, [0,8,9]:Integer, [0,8,10]:Integer, [0,8,11]:Integer, [0,8,12]:Integer, [0,8,13]:Integer, [0,8,14]:Integer, [0,8,15]:Integer, [0,8,16]:Integer, [0,8,17]:Integer, [0,8,18]:Integer, [0,8,19]:Integer, [0,8,20]:Integer, [0,8,21]:Integer, [0,8,22]:Integer, [0,8,23]:Integer, [0,8,24]:Integer, [0,8,25]:Integer, [0,8,26]:Integer, [0,8,27]:Integer, [0,8,28]:Integer, [0,8,29]:Integer, [0,8,30]:Integer, [0,8,31]:Integer, [0,8,32]:Integer, [0,8,33]:Integer, [0,8,34]:Integer, [0,8,35]:Integer, [0,8,36]:Integer, [0,8,37]:Integer, [0,8,38]:Integer, [0,8,39]:Integer, [0,8,40]:Integer, [0,16]:Pointer, [0,16,0]:Pointer, [0,16,0,0]:Pointer, [0,16,0,0,0]:Pointer, [0,16,0,0,0,0]:Pointer, [0,16,0,0,0,8]:Integer, [0,16,0,0,0,9]:Integer, [0,16,0,0,0,10]:Integer, [0,16,0,0,0,11]:Integer, [0,16,0,0,0,12]:Integer, [0,16,0,0,0,13]:Integer, [0,16,0,0,0,14]:Integer, [0,16,0,0,0,15]:Integer, [0,16,0,0,0,16]:Integer, [0,16,0,0,0,17]:Integer, [0,16,0,0,0,18]:Integer, [0,16,0,0,0,19]:Integer, [0,16,0,0,0,20]:Integer, [0,16,0,0,0,21]:Integer, [0,16,0,0,0,22]:Integer, [0,16,0,0,0,23]:Integer, [0,16,0,0,0,24]:Integer, [0,16,0,0,0,25]:Integer, [0,16,0,0,0,26]:Integer, [0,16,0,0,0,27]:Integer, [0,16,0,0,0,28]:Integer, [0,16,0,0,0,29]:Integer, [0,16,0,0,0,30]:Integer, 
[0,16,0,0,0,31]:Integer, [0,16,0,0,0,32]:Integer, [0,16,0,0,0,33]:Integer, [0,16,0,0,0,34]:Integer, [0,16,0,0,0,35]:Integer, [0,16,0,0,0,36]:Integer, [0,16,0,0,0,37]:Integer, [0,16,0,0,0,38]:Integer, [0,16,0,0,0,39]:Integer, [0,16,0,0,0,40]:Integer, [0,16,0,0,8]:Integer, [0,16,0,0,9]:Integer, [0,16,0,0,10]:Integer, [0,16,0,0,11]:Integer, [0,16,0,0,12]:Integer, [0,16,0,0,13]:Integer, [0,16,0,0,14]:Integer, [0,16,0,0,15]:Integer, [0,16,0,0,16]:Integer, [0,16,0,0,17]:Integer, [0,16,0,0,18]:Integer, [0,16,0,0,19]:Integer, [0,16,0,0,20]:Integer, [0,16,0,0,21]:Integer, [0,16,0,0,22]:Integer, [0,16,0,0,23]:Integer, [0,16,8]:Integer, [0,16,9]:Integer, [0,16,10]:Integer, [0,16,11]:Integer, [0,16,12]:Integer, [0,16,13]:Integer, [0,16,14]:Integer, [0,16,15]:Integer, [0,16,16]:Integer, [0,16,17]:Integer, [0,16,18]:Integer, [0,16,19]:Integer, [0,16,20]:Integer, [0,16,21]:Integer, [0,16,22]:Integer, [0,16,23]:Integer, [0,16,24]:Integer, [0,16,25]:Integer, [0,16,26]:Integer, [0,16,27]:Integer, [0,16,28]:Integer, [0,16,29]:Integer, [0,16,30]:Integer, [0,16,31]:Integer, [0,16,32]:Integer, [0,16,33]:Integer, [0,16,34]:Integer, [0,16,35]:Integer, [0,16,36]:Integer, [0,16,37]:Integer, [0,16,38]:Integer, [0,16,39]:Integer, [0,16,40]:Integer, [0,24]:Integer, [0,25]:Integer, [0,26]:Integer, [0,27]:Integer, [0,28]:Integer, [0,29]:Integer, [0,30]:Integer, [0,31]:Integer, [0,32]:Integer, [0,33]:Integer, [0,34]:Integer, [0,35]:Integer, [0,36]:Integer, [0,37]:Integer, [0,38]:Integer, [0,39]:Integer, [0,40]:Integer, [0,41]:Integer, [0,42]:Integer, [0,43]:Integer, [0,44]:Integer, [0,45]:Integer, [0,46]:Integer, [0,47]:Integer, [0,48]:Integer, [0,49]:Integer, [0,50]:Integer, [0,51]:Integer, [0,52]:Integer, [0,53]:Integer, [0,54]:Integer, [0,55]:Integer, [0,56]:Integer, [0,57]:Integer, [0,58]:Integer, [0,59]:Integer, [0,60]:Integer, [0,61]:Integer, [0,62]:Integer, [0,63]:Integer}
│ 
│ Stacktrace:
│  [1] Fix1
│    @ ./operators.jl:0
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/S3TWf/src/utils.jl:50

So, apart from these warnings, the remaining question (as in #659) is why we see the segfaults/different behaviour in the multithreaded case. The only difference is that in the multithreaded case the variable structure is put into a wrapper to ensure that @threads for ... loops in the models accumulate the density correctly. Otherwise, the underlying logic and calls are the same in both cases. https://github.com/EnzymeAD/Enzyme.jl/issues/659#issuecomment-1464653374 contains some additional explanations and links.
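The general idea behind such a wrapper can be sketched in plain Base Julia. If every iteration of a Threads.@threads loop added to one shared log-density field, concurrent updates would race; giving each thread its own accumulator slot and summing the slots at the end avoids that. PerThreadLogp, add_logp!, total_logp and loglik_threaded are hypothetical names: this illustrates the pattern, not DynamicPPL's actual wrapper type, and it assumes :static scheduling so that Threads.threadid() does not change during an iteration.

```julia
using Base.Threads: @threads, threadid, nthreads

# Hypothetical wrapper: one log-density accumulator per thread, so that
# concurrent updates inside @threads loops never write to the same slot.
struct PerThreadLogp
    logps::Vector{Float64}
end

# Size the buffer by the largest possible thread id. Threads.maxthreadid
# exists on Julia 1.9 and later; older versions fall back to nthreads().
function PerThreadLogp()
    n = isdefined(Threads, :maxthreadid) ? Threads.maxthreadid() : nthreads()
    return PerThreadLogp(zeros(n))
end

# Each thread only touches its own slot.
add_logp!(acc::PerThreadLogp, lp::Real) = (acc.logps[threadid()] += lp; acc)
total_logp(acc::PerThreadLogp) = sum(acc.logps)

# Toy "model": a Normal(μ, 1) log-likelihood (up to a constant) accumulated in parallel.
# :static scheduling pins each iteration to one thread, so threadid() is stable.
function loglik_threaded(xs::Vector{Float64}, μ::Float64)
    acc = PerThreadLogp()
    @threads :static for i in eachindex(xs)
        add_logp!(acc, -0.5 * (xs[i] - μ)^2)
    end
    return total_logp(acc)
end

xs = collect(range(-1.0, 1.0; length=1000))
serial = sum(-0.5 * (x - 0.3)^2 for x in xs)
threaded = loglik_threaded(xs, 0.3)
println(isapprox(serial, threaded))  # true (summation order differs, so compare approximately)
```

Start Julia with e.g. julia --threads=4 to exercise more than one slot.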

wsmoses commented 1 year ago

Great find!

Unfortunately, I'll still need a minimal example to be able to resolve this, so please try to simplify it in the same way if you can!

wsmoses commented 1 year ago

I reduced it down to the following, which still needs to be reduced a lot more to be able to debug.

@devmotion if you can assist, you'd be a lot faster than me, since I have no idea what any of these libraries/etc. are xD

using Distributions, DynamicPPL, LogDensityProblems, LogDensityProblemsAD, Enzyme, LinearAlgebra
using Turing
using Turing.AbstractMCMC

using AdvancedHMC

@model function model()
    m ~ Normal(0, 1)
    s ~ InverseGamma()
    x ~ Normal(m, s)
end

using Random

mod = model() | (; x=0.5)
alg = Turing.NUTS{Turing.EnzymeAD}()
spl = Sampler(alg, mod)

vi = DynamicPPL.default_varinfo(Random.GLOBAL_RNG, mod, spl)

vi = link!!(vi, spl, mod)

# Extract parameters.
theta = vi[spl]

# Create a Hamiltonian.
metricT = Turing.Inference.getmetricT(spl.alg)
metric = metricT(length(theta))
ℓ = LogDensityProblemsAD.ADgradient(
    Turing.LogDensityFunction(vi, mod, spl, DynamicPPL.DefaultContext())
)
logπ = Base.Fix1(LogDensityProblems.logdensity, ℓ)
∂logπ∂θ(x) = LogDensityProblems.logdensity_and_gradient(ℓ, x)
hamiltonian = AdvancedHMC.Hamiltonian(metric, logπ, ∂logπ∂θ)

# Compute phase point z.
# r = rand(Random.GLOBAL_RNG, metricT, size(metric)...)
# r ./=
# r ./= metric.sqrtM⁻¹
# AdvancedHMC.rand(Random.GLOBAL_RNG, metric, hamiltonian.kinetic)
# AdvancedHMC.
# rand(Random.GLOBAL_RNG, metric, hamiltonian.kinetic)

AdvancedHMC.phasepoint(hamiltonian, theta, rand(Random.GLOBAL_RNG, metric, hamiltonian.kinetic))

# AbstractMCMC.step(Random.GLOBAL_RNG, mod, alg)
# mymcmcsample(Random.GLOBAL_RNG, mod, alg, 10)
# sample(model() | (; x=0.5), NUTS{Turing.EnzymeAD}(), 10)

devmotion commented 1 year ago

I'm happy to help but unfortunately it might take a few days before I find time for some more debugging.

wsmoses commented 1 year ago

@devmotion any luck?

sethaxen commented 1 year ago

@wsmoses here's a smaller example:

using Enzyme
using Turing.LogDensityProblems
using Turing.Distributions
using Turing: DynamicPPL, NUTS

DynamicPPL.@model function model()
    m ~ Normal(0, 1)
    s ~ InverseGamma()
    x ~ Normal(m, s)
end

mod = model()
sampler = DynamicPPL.Sampler(NUTS())
vi = DynamicPPL.VarInfo(mod)
vi = DynamicPPL.link!!(vi, sampler, mod)
ℓ = DynamicPPL.LogDensityFunction(mod, vi, DynamicPPL.DefaultContext())

x = vi[sampler]  # Vector{Float64}
∂ℓ_∂x = zero(x)
LogDensityProblems.logdensity(ℓ, x)  # works
Enzyme.autodiff(
    Reverse,
    LogDensityProblems.logdensity,
    Const(ℓ),
    Duplicated(x, ∂ℓ_∂x),
)

On Enzyme v0.10, this segfaults for me regardless of whether JULIA_NUM_THREADS is set or not. On Enzyme v0.11, it prints out a bunch of warnings (same as before), and I get the following error:

ERROR: MethodError: no method matching callconv!(::Ptr{LLVM.API.LLVMOpaqueValue}, ::UInt32)

Closest candidates are:
  callconv!(::Union{LLVM.CallBrInst, LLVM.CallInst, LLVM.InvokeInst}, ::Any)
   @ LLVM ~/.julia/packages/LLVM/TLGyi/src/core/instructions.jl:155
  callconv!(::LLVM.Function, ::Any)
   @ LLVM ~/.julia/packages/LLVM/TLGyi/src/core/function.jl:27

Stacktrace:
  [1] jl_array_ptr_copy_fwd(B::Ptr{LLVM.API.LLVMOpaqueBuilder}, OrigCI::Ptr{LLVM.API.LLVMOpaqueValue}, gutils::Ptr{Nothing}, normalR::Ptr{Ptr{LLVM.API.LLVMOpaqueValue}}, shadowR::Ptr{Ptr{LLVM.API.LLVMOpaqueValue}})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/EncRR/src/compiler.jl:4873
  [2] jl_array_ptr_copy_augfwd(B::Ptr{LLVM.API.LLVMOpaqueBuilder}, OrigCI::Ptr{LLVM.API.LLVMOpaqueValue}, gutils::Ptr{Nothing}, normalR::Ptr{Ptr{LLVM.API.LLVMOpaqueValue}}, shadowR::Ptr{Ptr{LLVM.API.LLVMOpaqueValue}}, tapeR::Ptr{Ptr{LLVM.API.LLVMOpaqueValue}})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/EncRR/src/compiler.jl:4892
  [3] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{Enzyme.API.CDIFFE_TYPE}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, width::Int64, additionalArg::Ptr{Nothing}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{Bool}, augmented::Ptr{Nothing}, atomicAdd::Bool)
    @ Enzyme.API ~/.julia/packages/Enzyme/EncRR/src/api.jl:124
  [4] enzyme!(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::Tuple{Bool, Bool, Bool}, returnPrimal::Bool, jlrules::Vector{String}, expectedTapeType::Type)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/EncRR/src/compiler.jl:6680
  [5] codegen(output::Symbol, job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, ctx::LLVM.ThreadSafeContext, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/EncRR/src/compiler.jl:7921
  [6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, ctx::Nothing, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/EncRR/src/compiler.jl:8434
  [7] _thunk
    @ ~/.julia/packages/Enzyme/EncRR/src/compiler.jl:8431 [inlined]
  [8] cached_compilation
    @ ~/.julia/packages/Enzyme/EncRR/src/compiler.jl:8469 [inlined]
  [9] #s286#175
    @ ~/.julia/packages/Enzyme/EncRR/src/compiler.jl:8527 [inlined]
 [10] var"#s286#175"(FA::Any, A::Any, TT::Any, Mode::Any, ModifiedBetween::Any, width::Any, ReturnPrimal::Any, ShadowInit::Any, World::Any, ::Any, ::Any, ::Any, ::Any, tt::Any, ::Any, ::Any, ::Any, ::Any, ::Any)
    @ Enzyme.Compiler ./none:0
 [11] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core ./boot.jl:602
 [12] thunk
    @ ~/.julia/packages/Enzyme/EncRR/src/compiler.jl:8486 [inlined]
 [13] autodiff
    @ ~/.julia/packages/Enzyme/EncRR/src/Enzyme.jl:199 [inlined]
 [14] autodiff
    @ ~/.julia/packages/Enzyme/EncRR/src/Enzyme.jl:228 [inlined]
 [15] autodiff(::EnzymeCore.ReverseMode{false}, ::typeof(LogDensityProblems.logdensity), ::Const{DynamicPPL.LogDensityFunction{DynamicPPL.TypedVarInfo{NamedTuple{(:m, :s, :x), Tuple{DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:m, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{AbstractPPL.VarName{:m, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:s, Setfield.IdentityLens}, Int64}, Vector{InverseGamma{Float64}}, Vector{AbstractPPL.VarName{:s, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:x, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{AbstractPPL.VarName{:x, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, DynamicPPL.Model{typeof(model), (), (), (), Tuple{}, Tuple{}, DynamicPPL.DefaultContext}, DynamicPPL.DefaultContext}}, ::Duplicated{Vector{Float64}})
    @ Enzyme ~/.julia/packages/Enzyme/EncRR/src/Enzyme.jl:214
 [16] top-level scope
    @ REPL[33]:1

Note that this now uses the release version of Turing, not the branch that glues it to Enzyme (that branch is not yet compatible with Enzyme v0.11).

julia> using Pkg; Pkg.status()
Status `/tmp/jl_eLfrOK/Project.toml`
  [7da242da] Enzyme v0.11.0
  [fce5fe82] Turing v0.24.3

julia> versioninfo()
Julia Version 1.9.0-rc2
Commit 72aec423c2a (2023-04-01 10:41 UTC)
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 8 × 11th Gen Intel(R) Core(TM) i5-1135G7 @ 2.40GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-14.0.6 (ORCJIT, tigerlake)
  Threads: 1 on 8 virtual cores
Environment:
  JULIA_CMDSTAN_HOME = /home/sethaxen/software/cmdstan/2.30.1/
  JULIA_EDITOR = code
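As an independent check on the ∂ℓ_∂x returned by the autodiff call above, the log density of this small model can be written out by hand. This is a minimal sketch, not DynamicPPL's implementation. It assumes that `mod` is the conditioned model `model() | (; x=0.5)` from the original report, that `InverseGamma()` means shape 1 and scale 1 (the Distributions.jl default), that `vi[sampler]` holds the unconstrained parameters in the order [m, log(s)], and that the linked log density includes the log-Jacobian of s = exp(y). The helper names `logdensity_ref`, `gradient_ref` and `fd_gradient` are hypothetical.

```julia
# Hypothetical reference for the model in this issue:
#   m ~ Normal(0, 1); s ~ InverseGamma(1, 1); x ~ Normal(m, s), with x = 0.5.
# Assumes unconstrained parameters θ = [m, y] with s = exp(y), and that the
# log-Jacobian term (+y) of that transform is part of the log density.
function logdensity_ref(θ; x = 0.5)
    m, y = θ
    # logpdf(Normal(0,1), m) + logpdf(InverseGamma(1,1), s) + y + logpdf(Normal(m,s), x)
    return -m^2 / 2 - log(2π) - 2y - exp(-y) - (x - m)^2 * exp(-2y) / 2
end

# Hand-derived gradient of logdensity_ref with respect to θ = [m, y].
function gradient_ref(θ; x = 0.5)
    m, y = θ
    return [-m + (x - m) * exp(-2y),
            -2 + exp(-y) + (x - m)^2 * exp(-2y)]
end

# Central finite differences, used only to sanity-check gradient_ref itself.
function fd_gradient(f, θ; h = 1e-6)
    return [(f(θ .+ h .* e) - f(θ .- h .* e)) / (2h)
            for e in ([1.0, 0.0], [0.0, 1.0])]
end

θ = [0.3, -0.2]
isapprox(gradient_ref(θ), fd_gradient(logdensity_ref, θ); atol = 1e-6)  # true
```

If those assumptions hold, `logdensity_ref(x)` should agree with `LogDensityProblems.logdensity(ℓ, x)` and `gradient_ref(x)` with `∂ℓ_∂x` after a successful `Enzyme.autodiff` call; this has not been checked against DynamicPPL here.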
wsmoses commented 1 year ago

I've just fixed the 0.11 error you saw on main. Your test is now back to its previous behavior: it works fine single-threaded but segfaults multithreaded.

Unfortunately, this means additional minimization is required.

devmotion commented 1 year ago

> @devmotion any luck?

No, I had to postpone it. I'll probably return to it in ~2 weeks.

wsmoses commented 1 year ago

Should be solved by https://github.com/EnzymeAD/Enzyme.jl/pull/772; please reopen if it persists.

sethaxen commented 1 year ago

For me, the example in https://github.com/EnzymeAD/Enzyme.jl/issues/650#issuecomment-1508673707 still segfaults (even without JULIA_NUM_THREADS set).

Edit: I don't have permissions to reopen.

wsmoses commented 1 year ago

@sethaxen Can you make a minimal reproducer out of that comment?

wsmoses commented 1 year ago

Unfortunately, it appeared to work on my system after the fix.

sethaxen commented 1 year ago

> @sethaxen Can you make a minimal reproducer out of that comment?

I'll try, but I'm not very familiar with DynamicPPL's internals.

wsmoses commented 1 year ago

A GC error is currently preventing us from cutting a new release (which has numerous fixes). Any progress on minimizing this will help us find and fix the issue (and then cut the release).

So far we've failed to reproduce the CI GC failure locally =/

wsmoses commented 1 year ago

A significant GC-related fix has now landed on main. Can you retry to see if it's resolved?

unfortunately again I could not reproduce your issue even without that fix

devmotion commented 1 year ago

> unfortunately again I could not reproduce your issue even without that fix

Very strange. On my side Julia still segfaults when running the example in https://github.com/EnzymeAD/Enzyme.jl/issues/650#issuecomment-1508673707 (after filling up the terminal with low-level output). I used Julia 1.9.1, the Turing#dw/enzyme branch, and Enzyme#main + Enzyme_jll#main.

devmotion commented 1 year ago

And in the same way, the example in the OP (https://github.com/EnzymeAD/Enzyme.jl/issues/650#issue-1610288893) still emits warnings and segfaults.

wsmoses commented 1 year ago

To confirm can you show Julia version, what commit of enzyme you are on and the result of st?

wsmoses commented 1 year ago

Also can you upload the output log?

devmotion commented 1 year ago

> To confirm can you show Julia version, what commit of enzyme you are on and the result of st?

(jl_hZ755n) pkg> st
Status `/tmp/jl_hZ755n/Project.toml`
  [7da242da] Enzyme v0.11.1 `https://github.com/EnzymeAD/Enzyme.jl.git#main`
  [fce5fe82] Turing v0.26.2 `https://github.com/TuringLang/Turing.jl.git#dw/enzyme`
  [7cc45869] Enzyme_jll v0.0.70+0 `https://github.com/JuliaBinaryWrappers/Enzyme_jll.jl.git#main`

julia> versioninfo()
Julia Version 1.9.1
Commit 147bdf428cd (2023-06-07 08:27 UTC)
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 8 × 11th Gen Intel(R) Core(TM) i7-1165G7 @ 2.80GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-14.0.6 (ORCJIT, tigerlake)
  Threads: 8 on 8 virtual cores
Environment:
  JULIA_PKG_USE_CLI_GIT = true

> Also can you upload the output log?

Hmm, I'll try to figure out how to save the output; it seems to be too much for my terminal's scrollback buffer. The last parts shown in the display are:

[0,19]:Integer, [0,20]:Integer, [0,21]:Integer, [0,22]:Integer, [0,23]:Integer, [0,24]:Integer, [0,25]:Integer, [0,26]:Integer, [0,27]:Integer, [0,28]:Integer, [0,29]:Integer, [0,30]:Integer, [0,31]:Integer, [0,32]:Integer, [0,33]:Inpace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, i64, i64, i1, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*,   %.unpack5 = extractvalue { { i8*, i8*, i64, i64 }, {} addrspace(10)*, {} addrspace(10)*, { i64, i1 }, { i64, i1 }, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, i8*, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)* addrspace(10)*, {} addrspace(10)* addrspace(10)*, {} addrspace(10)* addrspace(10)*, {} addrspace(10)* addrspace(10)*, i1, {} addrspace(10)*, i1, i1*, i1*, {} addrspace(10)* addrspace(10)* }, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, i64, i64, i1, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*,   %".fca.0.0.0.0.gep'ipg" = getelementptr { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } }, { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } } addrspace(10)* %"'ipc", i64 0, i32 0, i  %.fca.0.0.0.0.gep = getelementptr { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } }, { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } } addrspace(10)* %12, i64 0, i32 0, i32 0, i64   %".fca.0.0.0.1.gep'ipg" = getelementptr { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } }, { { [2 x 
[8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } } addrspace(10)* %"'ipc", i64 0, i32 0, i  %.fca.0.0.0.1.gep = getelementptr { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } }, { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } } addrspace(10)* %12, i64 0, i32 0, i32 0, i64   %".fca.0.0.0.2.gep'ipg" = getelementptr { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } }, { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } } addrspace(10)* %"'ipc", i64 0, i32 0, i  %.fca.0.0.0.2.gep = getelementptr { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } }, { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } } addrspace(10)* %12, i64 0, i32 0, i32 0, i64   %".fca.0.0.0.3.gep'ipg" = getelementptr { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } }, { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } } addrspace(10)* %"'ipc", i64 0, i32 0, i  %.fca.0.0.0.3.gep = getelementptr { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } }, { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } } addrspace(10)* %12, i64 0, i32 0, i32 0, i64   %".fca.0.0.0.4.gep'ipg" = getelementptr { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } }, { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } } addrspace(10)* %"'ipc", i64 0, i32 0, i  %.fca.0.0.0.4.gep = getelementptr { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } }, { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } } addrspace(10)* %12, i64 0, i32 0, i32 0, i64   %".fca.0.0.0.5.gep'ipg" = getelementptr { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } }, { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } } addrspace(10)* %"'ipc", i64 0, i32 0, i  %.fca.0.0.0.5.gep = 
getelementptr { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } }, { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } } addrspace(10)* %12, i64 0, i32 0, i32 0, i64   %".fca.0.0.0.6.gep'ipg" = getelementptr { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } }, { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } } addrspace(10)* %"'ipc", i64 0, i32 0, i  %.fca.0.0.0.6.gep = getelementptr { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } }, { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } } addrspace(10)* %12, i64 0, i32 0, i32 0, i64   %".fca.0.0.0.7.gep'ipg" = getelementptr { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } }, { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } } addrspace(10)* %"'ipc", i64 0, i32 0, i  %.fca.0.0.0.7.gep = getelementptr { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } }, { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } } addrspace(10)* %12, i64 0, i32 0, i32 0, i64   %".fca.0.0.1.0.gep'ipg" = getelementptr { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } }, { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } } addrspace(10)* %"'ipc", i64 0, i32 0, i  %.fca.0.0.1.0.gep = getelementptr { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } }, { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } } addrspace(10)* %12, i64 0, i32 0, i32 0, i64   %".fca.0.0.1.1.gep'ipg" = getelementptr { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } }, { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } } addrspace(10)* %"'ipc", i64 0, i32 0, i  %.fca.0.0.1.1.gep = getelementptr { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } }, { { [2 x [8 x {} addrspace(10)*]], 
{} addrspace(10)*, {} addrspace(10)* } } addrspace(10)* %12, i64 0, i32 0, i32 0, i64   %".fca.0.0.1.2.gep'ipg" = getelementptr { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } }, { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } } addrspace(10)* %"'ipc", i64 0, i32 0, i  %.fca.0.0.1.2.gep = getelementptr { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } }, { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } } addrspace(10)* %12, i64 0, i32 0, i32 0, i64   %".fca.0.0.1.3.gep'ipg" = getelementptr { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } }, { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } } addrspace(10)* %"'ipc", i64 0, i32 0, i  %.fca.0.0.1.3.gep = getelementptr { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } }, { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } } addrspace(10)* %12, i64 0, i32 0, i32 0, i64   %".fca.0.0.1.4.gep'ipg" = getelementptr { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } }, { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } } addrspace(10)* %"'ipc", i64 0, i32 0, i  %.fca.0.0.1.4.gep = getelementptr { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } }, { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } } addrspace(10)* %12, i64 0, i32 0, i32 0, i64   %".fca.0.0.1.5.gep'ipg" = getelementptr { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } }, { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } } addrspace(10)* %"'ipc", i64 0, i32 0, i  %.fca.0.0.1.5.gep = getelementptr { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } }, { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } } addrspace(10)* %12, i64 0, i32 0, i32 0, i64   %".fca.0.0.1.6.gep'ipg" = getelementptr 
{ { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } }, { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } } addrspace(10)* %"'ipc", i64 0, i32 0, i  %.fca.0.0.1.6.gep = getelementptr { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } }, { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } } addrspace(10)* %12, i64 0, i32 0, i32 0, i64   %".fca.0.0.1.7.gep'ipg" = getelementptr { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } }, { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } } addrspace(10)* %"'ipc", i64 0, i32 0, i  %.fca.0.0.1.7.gep = getelementptr { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } }, { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } } addrspace(10)* %12, i64 0, i32 0, i32 0, i64   %".fca.0.1.gep'ipg" = getelementptr { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } }, { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } } addrspace(10)* %"'ipc", i64 0, i32 0, i32 1  %.fca.0.1.gep = getelementptr { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } }, { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } } addrspace(10)* %12, i64 0, i32 0, i32 1, !dbg !93  %".fca.0.2.gep'ipg" = getelementptr { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } }, { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } } addrspace(10)* %"'ipc", i64 0, i32 0, i32 2  %.fca.0.2.gep = getelementptr { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } }, { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } } addrspace(10)* %12, i64 0, i32 0, i32 2, !dbg !93  %tapeArg26 = extractvalue { { i8*, i8*, i64, i64 }, {} addrspace(10)*, {} addrspace(10)*, { i64, i1 }, { i64, i1 }, {} addrspace(10)*, {} addrspace(10)*, {} 
addrspace(10)*, {} addrspace(10)*, i8*, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)* addrspace(10)*, {} addrspace(10)* addrspace(10)*, {} addrspace(10)* addrspace(10)*, {} addrspace(10)* addrspace(10)*, i1, {} addrspace(10)*, i1, i1*, i1*, {} addrspace(10)* addrspace(10)* }, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, i64, i64, i1, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*,  call fastcc void @diffejulia__copyto_impl__3639({} addrspace(10)* align 16 %21, {} addrspace(10)* align 16 %22, {} addrspace(10)* align 16 %17, {} addrspace(10)* align 16 %18, i64 signext %20, { i8*, i8*, i64, i64 } %tapeArg8), !dbg  call fastcc void @diffejulia___cat_offset1__3648({} addrspace(10)* align 16 %17, {} addrspace(10)* align 16 %18, [1 x i64] addrspace(11)* nocapture readonly align 8 %19, {} addrspace(10)* align 16 %15, {} addrspace(10)* align 16 %"'  %"'ipc25_unwrap" = addrspacecast { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } } addrspace(10)* %"'ipc" to { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } } addrspace(11)*, !dbg   %tapeArg26_unwrap = extractvalue { { i8*, i8*, i64, i64 }, {} addrspace(10)*, {} addrspace(10)*, { i64, i1 }, { i64, i1 }, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, i8*, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)* addrspace(10)*, {} addrspace(10)* addrspace(10)*, {} addrspace(10)* addrspace(10)*, {} addrspace(10)* addrspace(10)*, i1, {} addrspace(10)*, i1, i1*, i1*, {} addrspace(10)* addrspace(10)* }, {} addrspace(10)*, {} 
addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, i64, i64, i1, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspac  call fastcc void @diffejulia__all_3637({ { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(10)* } } addrspace(11)* nocapture readonly align 8 %_unwrap, { { [2 x [8 x {} addrspace(10)*]], {} addrspace(10)*, {} addrspace(

but that seems incomplete: a lot of what showed up was not preserved by my terminal.

wsmoses commented 1 year ago

Can you put your test in a file and redirect its output to a file?

Something like julia myfile.jl &> out.txt

Also, can you use the version of Enzyme_jll that corresponds to the Enzyme.jl you are using? The two versions must match, and checking out the main branch of a JLL does not build the underlying binary from the main branch of its source.

devmotion commented 1 year ago

I uploaded the Project+Manifest, the scripts, and their output: https://gist.github.com/devmotion/02c94f3ad0d24e5cfdb1f42dbdef33aa

wsmoses commented 1 year ago

Looking at the log, this is not a segfault, but an assertion error.

 pp:   %_replacementE = phi [1 x i64] , !dbg !128 of   %26 = call fastcc [1 x i64] @julia___cat_offset1__3279({} addrspace(10)* noundef nonnull align 16 dereferenceable(40) %25, {} addrspace(10)* noundef nonnull align 16 dereferenceable(40) %13) #109, !dbg !149
julia: /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:8028: void GradientUtils::eraseFictiousPHIs(): Assertion `pp->getNumUses() == 0' failed.

Nevertheless, investigating.

devmotion commented 1 year ago

Ah, indeed, I should have inspected the logs 😅 Unfortunately, my terminal could not cope with the output, so this line never showed up, and I just assumed that Julia had crashed with a segfault.

wsmoses commented 1 year ago

This code triggers the assertion:

using Distributions, DynamicPPL, LogDensityProblems, LogDensityProblemsAD, Enzyme, LinearAlgebra
using Turing
using Turing.AbstractMCMC
using AdvancedHMC
using Random

@model function model()
    m ~ Normal(0, 1)
    s ~ InverseGamma()
    x ~ Normal(m, s)
end

mod = model() | (; x=0.5)
alg = Turing.NUTS{Turing.EnzymeAD}()
spl = Sampler(alg, mod)

vi = DynamicPPL.default_varinfo(Random.GLOBAL_RNG, mod, spl)
vi = link!!(vi, spl, mod)

# Extract parameters.
theta = vi[spl]

# Create a Hamiltonian.
metricT = Turing.Inference.getmetricT(spl.alg)
metric = metricT(length(theta))
ℓ = LogDensityProblemsAD.ADgradient(
    Turing.LogDensityFunction(vi, mod, spl, DynamicPPL.DefaultContext())
)
logπ = Base.Fix1(LogDensityProblems.logdensity, ℓ)
∂logπ∂θ(x) = LogDensityProblems.logdensity_and_gradient(ℓ, x)
hamiltonian = AdvancedHMC.Hamiltonian(metric, logπ, ∂logπ∂θ)

# Compute phase point z; this evaluates the gradient through ∂logπ∂θ.
AdvancedHMC.phasepoint(hamiltonian, theta, rand(Random.GLOBAL_RNG, metric, hamiltonian.kinetic))

# Higher-level entry points:
# AbstractMCMC.step(Random.GLOBAL_RNG, mod, alg)
# sample(model() | (; x=0.5), NUTS{Turing.EnzymeAD}(), 10)
wsmoses commented 1 year ago

The particular bug appears to come from istrans; do you know where that might be?

wsmoses commented 1 year ago
MethodInstance for DynamicPPL.istrans(::TypedVarInfo{NamedTuple{(:m, :s), Tuple{DynamicPPL.Metadata{Dict{VarName{:m, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:m, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{VarName{:s, Setfield.IdentityLens}, Int64}, Vector{InverseGamma{Float64}}, Vector{VarName{:s, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64})
wsmoses commented 1 year ago
using Distributions, DynamicPPL
#, LogDensityProblems
#using Turing
using Enzyme
# using Turing.AbstractMCMC
using Setfield

# using AdvancedHMC

Enzyme.autodiff_thunk(Enzyme.ReverseSplitWithPrimal,
                      Const{typeof(DynamicPPL.istrans)},
                      Const,
                      Duplicated{TypedVarInfo{NamedTuple{(:m, :s),
                                    Tuple{DynamicPPL.Metadata{Dict{VarName{:m, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:m, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}},
                                          DynamicPPL.Metadata{Dict{VarName{:s, Setfield.IdentityLens}, Int64}, Vector{InverseGamma{Float64}}, Vector{VarName{:s, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}
                                         }}, Float64}
                     })
wsmoses commented 1 year ago
using Distributions, DynamicPPL
#, LogDensityProblems
#using Turing
using Enzyme
# using Turing.AbstractMCMC
using Setfield

# using AdvancedHMC

function bad(vi::AbstractVarInfo)
    return keys(vi)
end

Enzyme.autodiff_thunk(Enzyme.ReverseSplitWithPrimal,
                      Const{typeof(bad)},
                      Const,
                      Duplicated{TypedVarInfo{NamedTuple{(:m, :s),
                                    Tuple{DynamicPPL.Metadata{Dict{VarName{:m, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:m, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}},
                                          DynamicPPL.Metadata{Dict{VarName{:s, Setfield.IdentityLens}, Int64}, Vector{InverseGamma{Float64}}, Vector{VarName{:s, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}
                                         }}, Float64}
                     })
wsmoses commented 1 year ago
using Distributions, DynamicPPL
#, LogDensityProblems
#using Turing
using Enzyme
# using Turing.AbstractMCMC
using Setfield

struct MVarInfo{Tmeta} <: DynamicPPL.AbstractVarInfo
    metadata::Tmeta
end

const MTypedVarInfo = MVarInfo{<:NamedTuple}

@generated function bad(vi::MTypedVarInfo{<:NamedTuple{names}}) where {names}
    expr = Expr(:call)
    push!(expr.args, :vcat)

    for n in names
        push!(expr.args, :(vi.metadata.$n.vns))
    end

    return expr
end

Enzyme.autodiff_thunk(Enzyme.ReverseSplitWithPrimal,
                      Const{typeof(bad)},
                      Const,
                      Duplicated{MTypedVarInfo{NamedTuple{(:m,),
                                    Tuple{DynamicPPL.Metadata{Dict{VarName{:m, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:m, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}
                                         }}}
                     })
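Outside of Enzyme, the generated `bad` above just unrolls into a single `vcat` over each variable's `vns` vector, which is roughly what `keys(vi)` computes for a typed VarInfo. The sketch below runs that primal in plain Julia. `MMetadata` is a hypothetical stand-in that models only the `vns` field of `DynamicPPL.Metadata`, and `MVarInfo` here does not subtype `DynamicPPL.AbstractVarInfo`; symbols replace `VarName`s purely for illustration.

```julia
# Hypothetical stand-ins: MMetadata models only the `vns` field of
# DynamicPPL.Metadata; MVarInfo is not a DynamicPPL.AbstractVarInfo here.
struct MMetadata{T}
    vns::Vector{T}
end

struct MVarInfo{Tmeta}
    metadata::Tmeta
end

# At compile time, unroll into vcat(vi.metadata.<name>.vns, ...) over the
# NamedTuple's field names, mirroring the reduced `bad` above.
@generated function bad(vi::MVarInfo{<:NamedTuple{names}}) where {names}
    expr = Expr(:call, :vcat)
    for n in names
        push!(expr.args, :(vi.metadata.$n.vns))
    end
    return expr
end

vi = MVarInfo((m = MMetadata([:m1]), s = MMetadata([:s1, :s2])))
bad(vi)  # [:m1, :s1, :s2]
```

Because the loop runs over the type parameter `names`, each distinct set of variable names produces its own specialized `vcat` call.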
wsmoses commented 1 year ago

I fixed the allocation oddity; the reproducer is now back to:

using Distributions, DynamicPPL
#, LogDensityProblems
#using Turing
using Enzyme
# using Turing.AbstractMCMC
using Setfield

# using AdvancedHMC

function bad(vi::AbstractVarInfo)
    return keys(vi)
end

Enzyme.autodiff_thunk(Enzyme.ReverseSplitWithPrimal,
                      Const{typeof(bad)},
                      Const,
                      Duplicated{TypedVarInfo{NamedTuple{(:m, :s),
                                    Tuple{DynamicPPL.Metadata{Dict{VarName{:m, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:m, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}},
                                          DynamicPPL.Metadata{Dict{VarName{:s, Setfield.IdentityLens}, Int64}, Vector{InverseGamma{Float64}}, Vector{VarName{:s, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}
                                         }}, Float64}
                     })
wsmoses commented 1 year ago
using Enzyme

# Enzyme.API.printunnecessary!(true)
# Enzyme.API.printdiffuse!(true)
# Enzyme.API.printall!(true)

function bad(vi)
    return vcat(vi.m, vi.s)
end

struct MVarName{sym}
end

Enzyme.autodiff_thunk(Enzyme.ReverseSplitWithPrimal,
                      Const{typeof(bad)},
                      Const,
                      Duplicated{NamedTuple{(:m, :s),
                                    Tuple{Vector{MVarName{:m}},
                                          Vector{MVarName{:s}}
                                         }}
                     })
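The final reduction differentiates only a `vcat` of two vectors whose element types differ. As a plain-Julia sketch (no Enzyme; the concrete vector contents are made up for illustration), running the same primal shows that the concatenated vector ends up with a non-concrete element type, since `vcat` has to fall back to a common supertype of `MVarName{:m}` and `MVarName{:s}`.

```julia
# Mirrors the struct and function in the reduction above; the inputs are
# hypothetical and only exercise the primal (no differentiation).
struct MVarName{sym} end

bad(vi) = vcat(vi.m, vi.s)

vi = (m = [MVarName{:m}()], s = [MVarName{:s}(), MVarName{:s}()])
r = bad(vi)

# The two element types differ, so vcat promotes to a common supertype
# and the result's element type is abstract rather than concrete.
length(r)                  # 3
isconcretetype(eltype(r))  # false
```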
wsmoses commented 1 year ago

https://fwd.gymni.ch/3sTs3B

wsmoses commented 1 year ago

https://fwd.gymni.ch/SINV5U

wsmoses commented 1 year ago

Okay @sethaxen @devmotion, the latest push to main fixes all of the above and runs fine, at least for me locally. Note that it currently requires enabling runtime activity (if you first run it without runtime activity enabled, it will throw an error saying that you need to enable it).

I'm going to close this for now; please reopen if it persists.

sethaxen commented 1 year ago

Indeed, I can confirm that the examples in this issue now work for me on main. Thanks, @wsmoses for all your help!

wsmoses commented 1 year ago

Nice, I've just released all these fixes (and more) as v0.11.2, so feel free to test it out.