EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License
428 stars 59 forks source link

Illegal type analysis - BigFloat #1621

Closed mhauru closed 1 day ago

mhauru commented 2 weeks ago
module MWE

import DynamicPPL
import AbstractPPL
import Accessors
using Distributions: MvNormal

using Enzyme

#Enzyme.API.runtimeActivity!(true)

# Reduced assume step used by the reproduction (reviewer's reading of the name —
# confirm against DynamicPPL). It calls `DynamicPPL.invlink_with_logpdf` and discards
# both results, reads the value stored under `vn`, adds a literal 0.0 to the running
# log-probability, and returns the tuple `(r, vi)`.
function tilde_assume!!(right, vn, vi)
    # Both return values are thrown away. The stack trace in this report reaches the
    # failing broadcast through this call (invlink_with_logpdf -> logpdf -> sqmahal).
    _, _ = DynamicPPL.invlink_with_logpdf(vi, vn, right)
    r = vi[vn]
    logp = 0.0
    # In-place update of the accumulator held in `vi.logp`; the increment is 0.0.
    vi.logp[] += logp
    return r, vi
end

# Two variable names that share the symbol `:x` and differ only in their optic:
# `_[:, 1]` for the first and `_[:, 2]` for the second. Both are non-const
# module-level globals.
x_varname1 = AbstractPPL.VarName{:x}((Accessors.@optic _[:, 1]))
x_varname2 = AbstractPPL.VarName{:x}((Accessors.@optic _[:, 2]))

# Model body of the reproduction. `TV` is taken from the type parameter of the
# `TypeWrap` argument and is used as the constructor for the 2x2 storage `x`.
# In this file it is called with `TV = Matrix{Real}` (see `g`); the issue discussion
# reports that the failure goes away when that is changed to `Matrix{Float64}`.
#
# Returns `(nothing, __varinfo__)`, where `__varinfo__` is whatever the second
# `tilde_assume!!` call returned.
function satellite_model_matrix(__varinfo__::DynamicPPL.AbstractVarInfo, ::(DynamicPPL.TypeWrap){TV}) where {TV}
    # 2x2 matrix [0.1 0.0; 0.0 0.1], passed as the second argument of both MvNormals.
    P0 = vcat([0.1 0.0], [0.0 0.1])
    # Uninitialised 2x2 container of type `TV`; both columns are written below.
    x = TV(undef, 2, 2)

    v1 = MvNormal([0.0, 0.0], P0)
    v3, __varinfo__ = tilde_assume!!(v1, x_varname1, __varinfo__)
    x[:, 1] .= v3

    # The second distribution takes its first argument from the column just written.
    v1 = MvNormal(x[:, 1], P0)
    v4, __varinfo__ = tilde_assume!!(v1, x_varname2, __varinfo__)
    x[:, 2] .= v4
    return nothing, __varinfo__
end

# Build the module-level `vi` that `g` later reads as a global: start from an empty
# `VarInfo`, register both variable names with the value [0.0, 0.0] and an
# `MvNormal([0.0, 0.0], P0)` distribution, then convert with `TypedVarInfo`.
vi = DynamicPPL.VarInfo()
P0 = vcat([0.1 0.0], [0.0 0.1])
v1 = [0.0, 0.0]
DynamicPPL.push!!(vi, x_varname1, v1, MvNormal([0.0, 0.0], P0), DynamicPPL.SampleFromPrior())
DynamicPPL.push!!(vi, x_varname2, v1, MvNormal([0.0, 0.0], P0), DynamicPPL.SampleFromPrior())
# Rebinds `vi` to the result of the conversion.
vi = DynamicPPL.TypedVarInfo(vi)

# Function handed to Enzyme. Builds a new varinfo from the global `vi` and the input
# vector `x` via `DynamicPPL.unflatten`, runs the model with `Matrix{Real}` as the
# storage type, and returns `DynamicPPL.getlogp` of the resulting varinfo.
function g(x)
    context = DynamicPPL.DefaultContext()
    # `vi` is the non-const module-level global defined above.
    vi_new = DynamicPPL.unflatten(vi, context, x)
    # `Matrix{Real}` is the abstractly-typed storage the issue discussion points to.
    _, wrapper_new = satellite_model_matrix(vi_new, DynamicPPL.TypeWrap{Matrix{Real}}())
    return DynamicPPL.getlogp(wrapper_new)
end

# Driver: differentiate `g` at x = [1, 1, 1, 1] in reverse mode, also requesting the
# primal value. The shadow passed in `Duplicated` is a zero vector of the same shape.
# This is the call that raises the "illegal type analysis" error shown under "Output".
x = [1.0, 1.0, 1.0, 1.0]
Enzyme.autodiff(ReverseWithPrimal, g, Active, Enzyme.Duplicated(x, zero(x)))
# Disabled ForwardDiff call on the same function, left in by the reporter.
# using ForwardDiff
# @show ForwardDiff.gradient(g, x)

end

Output:

┌ Warning: Unknown floating point type
│   T = BigFloat
└ @ Enzyme ~/.julia/packages/GPUCompiler/Y4hSX/src/utils.jl:59
┌ Warning: Unknown floating point type
│   T = BigFloat
└ @ Enzyme ~/.julia/packages/GPUCompiler/Y4hSX/src/utils.jl:59
┌ Warning: Unknown floating point type
│   T = BigFloat
└ @ Enzyme ~/.julia/packages/GPUCompiler/Y4hSX/src/utils.jl:59
┌ Warning: Unknown floating point type
│   T = BigFloat
└ @ Enzyme ~/.julia/packages/GPUCompiler/Y4hSX/src/utils.jl:59
┌ Warning: Unknown floating point type
│   T = BigFloat
└ @ Enzyme ~/.julia/packages/GPUCompiler/Y4hSX/src/utils.jl:59
┌ Warning: Unknown floating point type
│   T = BigFloat
└ @ Enzyme ~/.julia/packages/GPUCompiler/Y4hSX/src/utils.jl:59
┌ Warning: Unknown floating point type
│   T = BigFloat
└ @ Enzyme ~/.julia/packages/GPUCompiler/Y4hSX/src/utils.jl:59
┌ Warning: Unknown floating point type
│   T = BigFloat
└ @ Enzyme ~/.julia/packages/GPUCompiler/Y4hSX/src/utils.jl:59
ERROR: LoadError: Enzyme compilation failed due to illegal type analysis.
Current scope:
; Function Attrs: mustprogress willreturn
define internal fastcc noalias nonnull dereferenceable(40) "enzyme_type"="{[-1]:Pointer}" {} addrspace(10)* @preprocess_julia___1_28309(i64 signext "enzyme_inactive" "enzyme_type"="{[-1]:Integer}" "enzymejl_parmtype"="5091561456" "enzymejl_parmtype_ref"="0" %0) unnamed_addr #32 !dbg !686 {
top:
  %1 = call {}*** @julia.get_pgcstack() #33
  %ptls_field5 = getelementptr inbounds {}**, {}*** %1, i64 2
  %2 = bitcast {}*** %ptls_field5 to i64***
  %ptls_load67 = load i64**, i64*** %2, align 8, !tbaa !14
  %3 = getelementptr inbounds i64*, i64** %ptls_load67, i64 2
  %safepoint = load i64*, i64** %3, align 8, !tbaa !18
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint) #33, !dbg !687
  fence syncscope("singlethread") seq_cst
  %4 = icmp sgt i64 %0, 0, !dbg !688
  br i1 %4, label %L6, label %L3, !dbg !689

L3:                                               ; preds = %top
  %5 = call noalias nonnull "enzyme_inactive" "enzyme_type"="{[-1]:Pointer, [-1,-1]:Integer}" {} addrspace(10)* @ijl_box_int64(i64 signext %0) #34, !dbg !689
  %6 = call nonnull "enzyme_type"="{[-1]:Pointer}" {} addrspace(10)* ({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32, {} addrspace(10)*)*, {} addrspace(10)*, {} addrspace(10)*, ...) @julia.call2({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32, {} addrspace(10)*)* noundef nonnull @ijl_invoke, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 5063918096 to {}*) to {} addrspace(10)*), {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 5021202192 to {}*) to {} addrspace(10)*), {} addrspace(10)* nonnull %5, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 5101639360 to {}*) to {} addrspace(10)*)) #35, !dbg !689
  %7 = addrspacecast {} addrspace(10)* %6 to {} addrspace(12)*, !dbg !689
  call void @ijl_throw({} addrspace(12)* %7) #36, !dbg !689
  unreachable, !dbg !689

L6:                                               ; preds = %top
  %current_task14 = getelementptr inbounds {}**, {}*** %1, i64 -14
  %current_task1 = bitcast {}*** %current_task14 to {}**
  %8 = call i64 @mpfr_custom_get_size(i64 %0) #33, !dbg !690
  %9 = add i64 %8, 7, !dbg !691
  %10 = and i64 %9, -8, !dbg !694
  %11 = call nonnull "enzyme_type"="{[-1]:Pointer}" {} addrspace(10)* (i64, ...) @ijl_alloc_string(i64 %10) #33, !dbg !697
  %12 = addrspacecast {} addrspace(10)* %11 to {} addrspace(11)*, !dbg !698
  %13 = call nonnull {}* @julia.pointer_from_objref({} addrspace(11)* %12) #37, !dbg !698
  %14 = bitcast {}* %13 to {} addrspace(10)**, !dbg !698
  %15 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)** %14, i64 1, !dbg !698
  %string_ptr = ptrtoint {} addrspace(10)** %15 to i64, !dbg !698
  %newstruct = call noalias nonnull dereferenceable(40) "enzyme_type"="{[-1]:Pointer}" {} addrspace(10)* @julia.gc_alloc_obj({}** nonnull %current_task1, i64 noundef 40, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 5023625904 to {}*) to {} addrspace(10)*)) #38, !dbg !700
  %16 = addrspacecast {} addrspace(10)* %newstruct to {} addrspace(10)* addrspace(11)*, !dbg !700
  %17 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %16, i64 4, !dbg !700
  store {} addrspace(10)* null, {} addrspace(10)* addrspace(11)* %17, align 8, !dbg !700, !tbaa !45, !alias.scope !49, !noalias !702
  %18 = addrspacecast {} addrspace(10)* %newstruct to i64 addrspace(11)*, !dbg !700
  store i64 %0, i64 addrspace(11)* %18, align 8, !dbg !700, !tbaa !45, !alias.scope !49, !noalias !702
  %19 = addrspacecast {} addrspace(10)* %newstruct to i8 addrspace(11)*, !dbg !700
  %20 = getelementptr inbounds i8, i8 addrspace(11)* %19, i64 8, !dbg !700
  %memcpy_refined_dst = bitcast i8 addrspace(11)* %20 to i32 addrspace(11)*, !dbg !700
  store i32 1, i32 addrspace(11)* %memcpy_refined_dst, align 8, !dbg !700, !tbaa !45, !alias.scope !49, !noalias !702
  %21 = getelementptr inbounds i8, i8 addrspace(11)* %19, i64 16, !dbg !700
  %memcpy_refined_dst3 = bitcast i8 addrspace(11)* %21 to i64 addrspace(11)*, !dbg !700
  store i64 -9223372036854775806, i64 addrspace(11)* %memcpy_refined_dst3, align 8, !dbg !700, !tbaa !45, !alias.scope !49, !noalias !702
  %22 = getelementptr inbounds i8, i8 addrspace(11)* %19, i64 24, !dbg !700
  %23 = bitcast i8 addrspace(11)* %22 to i64 addrspace(11)*, !dbg !700
  store i64 %string_ptr, i64 addrspace(11)* %23, align 8, !dbg !700, !tbaa !45, !alias.scope !49, !noalias !702
  %24 = getelementptr inbounds i8, i8 addrspace(11)* %19, i64 32, !dbg !700
  %25 = bitcast i8 addrspace(11)* %24 to {} addrspace(10)* addrspace(11)*, !dbg !700
  store atomic {} addrspace(10)* %11, {} addrspace(10)* addrspace(11)* %25 release, align 8, !dbg !700, !tbaa !45, !alias.scope !49, !noalias !702
  ret {} addrspace(10)* %newstruct, !dbg !701
}

 Type analysis state:
<analysis>
  %6 = call nonnull "enzyme_type"="{[-1]:Pointer}" {} addrspace(10)* ({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32, {} addrspace(10)*)*, {} addrspace(10)*, {} addrspace(10)*, ...) @julia.call2({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32, {} addrspace(10)*)* noundef nonnull @ijl_invoke, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 5063918096 to {}*) to {} addrspace(10)*), {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 5021202192 to {}*) to {} addrspace(10)*), {} addrspace(10)* nonnull %5, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 5101639360 to {}*) to {} addrspace(10)*)) #35, !dbg !24: {[-1]:Pointer}, intvals: {}
  %newstruct = call noalias nonnull dereferenceable(40) "enzyme_type"="{[-1]:Pointer}" {} addrspace(10)* @julia.gc_alloc_obj({}** nonnull %current_task1, i64 noundef 40, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 5023625904 to {}*) to {} addrspace(10)*)) #38, !dbg !42: {[-1]:Pointer, [-1,0]:Float@double}, intvals: {}
{} addrspace(10)* null: {[-1]:Pointer, [-1,-1]:Anything}, intvals: {0,}
  %1 = call {}*** @julia.get_pgcstack() #33: {}, intvals: {}
  %3 = getelementptr inbounds i64*, i64** %ptls_load67, i64 2: {[-1]:Pointer}, intvals: {}
  %11 = call nonnull "enzyme_type"="{[-1]:Pointer}" {} addrspace(10)* (i64, ...) @ijl_alloc_string(i64 %10) #33, !dbg !34: {[-1]:Pointer}, intvals: {}
  %8 = call i64 @mpfr_custom_get_size(i64 %0) #33, !dbg !25: {}, intvals: {}
  %5 = call noalias nonnull "enzyme_inactive" "enzyme_type"="{[-1]:Pointer, [-1,-1]:Integer}" {} addrspace(10)* @ijl_box_int64(i64 signext %0) #34, !dbg !24: {[-1]:Pointer, [-1,-1]:Integer}, intvals: {}
  %ptls_field5 = getelementptr inbounds {}**, {}*** %1, i64 2: {}, intvals: {}
  %15 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)** %14, i64 1, !dbg !37: {[-1]:Pointer}, intvals: {}
  %17 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %16, i64 4, !dbg !42: {[-1]:Pointer, [-1,0]:Pointer}, intvals: {}
  %current_task14 = getelementptr inbounds {}**, {}*** %1, i64 -14: {}, intvals: {}
  %13 = call nonnull {}* @julia.pointer_from_objref({} addrspace(11)* %12) #37, !dbg !37: {[-1]:Pointer}, intvals: {}
  %current_task1 = bitcast {}*** %current_task14 to {}**: {}, intvals: {}
  %14 = bitcast {}* %13 to {} addrspace(10)**, !dbg !37: {[-1]:Pointer}, intvals: {}
  %2 = bitcast {}*** %ptls_field5 to i64***: {[-1]:Pointer}, intvals: {}
  %ptls_load67 = load i64**, i64*** %2, align 8, !tbaa !14: {}, intvals: {}
  %string_ptr = ptrtoint {} addrspace(10)** %15 to i64, !dbg !37: {[-1]:Pointer}, intvals: {}
  %16 = addrspacecast {} addrspace(10)* %newstruct to {} addrspace(10)* addrspace(11)*, !dbg !42: {[-1]:Pointer, [-1,0]:Float@double}, intvals: {}
  %safepoint = load i64*, i64** %3, align 8, !tbaa !18: {}, intvals: {}
  %18 = addrspacecast {} addrspace(10)* %newstruct to i64 addrspace(11)*, !dbg !42: {[-1]:Pointer, [-1,0]:Float@double}, intvals: {}
  %12 = addrspacecast {} addrspace(10)* %11 to {} addrspace(11)*, !dbg !37: {[-1]:Pointer}, intvals: {}
i64 7: {[-1]:Integer}, intvals: {7,}
i64 -8: {[-1]:Integer}, intvals: {-8,}
i64 0: {[-1]:Anything}, intvals: {0,}
i64 %0: {[-1]:Integer}, intvals: {}
  %4 = icmp sgt i64 %0, 0, !dbg !21: {[-1]:Integer}, intvals: {}
  %9 = add i64 %8, 7, !dbg !26: {}, intvals: {}
  %10 = and i64 %9, -8, !dbg !30: {}, intvals: {}
</analysis>

Illegal updateAnalysis prev:{[-1]:Pointer, [-1,0]:Float@double} new: {[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer, [-1,2]:Integer, [-1,3]:Integer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer}
val:   %18 = addrspacecast {} addrspace(10)* %newstruct to i64 addrspace(11)*, !dbg !42 origin=  store i64 %0, i64 addrspace(11)* %18, align 8, !dbg !42, !tbaa !45, !alias.scope !49, !noalias !52
MethodInstance for (::Base.MPFR.var"#_#1#2")(::Int64, ::Type{BigFloat})

Caused by:
Stacktrace:
 [1] _BigFloat
   @ ./mpfr.jl:119
 [2] _
   @ ./mpfr.jl:129

Stacktrace:
  [1] julia_error(cstr::Cstring, val::Ptr{…}, errtype::Enzyme.API.ErrorType, data::Ptr{…}, data2::Ptr{…}, B::Ptr{…})
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:1996
  [2] EnzymeCreateAugmentedPrimal(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnUsed::Bool, shadowReturnUsed::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, forceAnonymousTape::Bool, width::Int64, atomicAdd::Bool)
    @ Enzyme.API ~/.julia/dev/Enzyme/src/api.jl:192
  [3] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::NTuple{…}, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:3673
  [4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:5867
  [5] codegen
    @ ~/.julia/dev/Enzyme/src/compiler.jl:5143 [inlined]
  [6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:6674
  [7] _thunk
    @ ~/.julia/dev/Enzyme/src/compiler.jl:6674 [inlined]
  [8] cached_compilation
    @ ~/.julia/dev/Enzyme/src/compiler.jl:6712 [inlined]
  [9] (::Enzyme.Compiler.var"#28595#28596"{…})(ctx::LLVM.Context)
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:6781
 [10] JuliaContext(f::Enzyme.Compiler.var"#28595#28596"{…}; kwargs::@Kwargs{})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/driver.jl:52
 [11] JuliaContext(f::Function)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/driver.jl:42
 [12] #s2010#28594
    @ ~/.julia/dev/Enzyme/src/compiler.jl:6732 [inlined]
 [13]
    @ Enzyme.Compiler ./none:0
 [14] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
    @ Core ./boot.jl:602
 [15] runtime_generic_augfwd(activity::Type{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::typeof(Base.Broadcast.copyto_nonleaf!), df::Nothing, primal_1::Vector{…}, shadow_1_1::Vector{…}, primal_2::Base.Broadcast.Broadcasted{…}, shadow_2_1::Base.Broadcast.Broadcasted{…}, primal_3::Base.OneTo{…}, shadow_3_1::Nothing, primal_4::Int64, shadow_4_1::Nothing, primal_5::Int64, shadow_5_1::Nothing)
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/rules/jitrules.jl:307
 [16] copy
    @ ./broadcast.jl:950 [inlined]
 [17] materialize
    @ ./broadcast.jl:903 [inlined]
 [18] sqmahal
    @ ~/.julia/packages/Distributions/ji8PW/src/multivariate/mvnormal.jl:267
 [19] _logpdf
    @ ~/.julia/packages/Distributions/ji8PW/src/multivariate/mvnormal.jl:143
 [20] logpdf
    @ ~/.julia/packages/Distributions/ji8PW/src/common.jl:263 [inlined]
 [21] invlink_with_logpdf
    @ ~/.julia/packages/DynamicPPL/ACaKr/src/abstract_varinfo.jl:856
 [22] invlink_with_logpdf
    @ ~/.julia/packages/DynamicPPL/ACaKr/src/abstract_varinfo.jl:850 [inlined]
 [23] tilde_assume!!
    @ ~/projects/Enzyme-mwes/ref_minus_float/mwe.jl:16 [inlined]
 [24] tilde_assume!!
    @ ~/projects/Enzyme-mwes/ref_minus_float/mwe.jl:0 [inlined]
 [25] augmented_julia_tilde_assume___28212_inner_1wrap
    @ ~/projects/Enzyme-mwes/ref_minus_float/mwe.jl:0
 [26] macro expansion
    @ ~/.julia/dev/Enzyme/src/compiler.jl:6622 [inlined]
 [27] enzyme_call(::Val{…}, ::Ptr{…}, ::Type{…}, ::Val{…}, ::Val{…}, ::Type{…}, ::Type{…}, ::EnzymeCore.Const{…}, ::Type{…}, ::EnzymeCore.Duplicated{…}, ::EnzymeCore.Const{…}, ::EnzymeCore.Duplicated{…})
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:6223
 [28] (::Enzyme.Compiler.AugmentedForwardThunk{…})(::EnzymeCore.Const{…}, ::EnzymeCore.Duplicated{…}, ::Vararg{…})
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:6111
 [29] runtime_generic_augfwd(activity::Type{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::typeof(Main.MWE.tilde_assume!!), df::Nothing, primal_1::Distributions.MvNormal{…}, shadow_1_1::Distributions.MvNormal{…}, primal_2::AbstractPPL.VarName{…}, shadow_2_1::Nothing, primal_3::DynamicPPL.TypedVarInfo{…}, shadow_3_1::DynamicPPL.TypedVarInfo{…})
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/rules/jitrules.jl:311
 [30] satellite_model_matrix
    @ ~/projects/Enzyme-mwes/ref_minus_float/mwe.jl:35 [inlined]
 [31] satellite_model_matrix
    @ ~/projects/Enzyme-mwes/ref_minus_float/mwe.jl:0 [inlined]
 [32] augmented_julia_satellite_model_matrix_27657_inner_1wrap
    @ ~/projects/Enzyme-mwes/ref_minus_float/mwe.jl:0
 [33] macro expansion
    @ ~/.julia/dev/Enzyme/src/compiler.jl:6622 [inlined]
 [34] enzyme_call(::Val{…}, ::Ptr{…}, ::Type{…}, ::Val{…}, ::Val{…}, ::Type{…}, ::Type{…}, ::EnzymeCore.Const{…}, ::Type{…}, ::EnzymeCore.Duplicated{…}, ::EnzymeCore.Const{…})
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:6223
 [35] (::Enzyme.Compiler.AugmentedForwardThunk{…})(::EnzymeCore.Const{…}, ::EnzymeCore.Duplicated{…}, ::Vararg{…})
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:6111
 [36] runtime_generic_augfwd(activity::Type{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::typeof(Main.MWE.satellite_model_matrix), df::Nothing, primal_1::DynamicPPL.TypedVarInfo{…}, shadow_1_1::DynamicPPL.TypedVarInfo{…}, primal_2::DynamicPPL.TypeWrap{…}, shadow_2_1::Nothing)
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/rules/jitrules.jl:311
 [37] g
    @ ~/projects/Enzyme-mwes/ref_minus_float/mwe.jl:50 [inlined]
 [38] augmented_julia_g_27631wrap
    @ ~/projects/Enzyme-mwes/ref_minus_float/mwe.jl:0
 [39] macro expansion
    @ ~/.julia/dev/Enzyme/src/compiler.jl:6622 [inlined]
 [40] enzyme_call
    @ ~/.julia/dev/Enzyme/src/compiler.jl:6223 [inlined]
 [41] AugmentedForwardThunk
    @ ~/.julia/dev/Enzyme/src/compiler.jl:6111 [inlined]
 [42] autodiff
    @ ~/.julia/dev/Enzyme/src/Enzyme.jl:253 [inlined]
 [43] autodiff(mode::EnzymeCore.ReverseMode{…}, f::typeof(Main.MWE.g), ::Type{…}, args::EnzymeCore.Duplicated{…})
    @ Enzyme ~/.julia/dev/Enzyme/src/Enzyme.jl:321
 [44] top-level scope
    @ ~/projects/Enzyme-mwes/ref_minus_float/mwe.jl:55
 [45] include(fname::String)
    @ Base.MainInclude ./client.jl:489
 [46] top-level scope
    @ REPL[11]:1
in expression starting at /Users/mhauru/projects/Enzyme-mwes/ref_minus_float/mwe.jl:4
Some type information was truncated. Use `show(err)` to see complete types.

On 0.0.132 and latest Enzyme.jl main.

I appreciate that's not very minimal, but I'm done for the day and wanted to put this up here. I can try to minimise further tomorrow if that's helpful.

This started life as a reproduction of #1608, but with the latest update to Enzyme it stopped producing the error in #1608 and instead started spitting out the above.

wsmoses commented 2 weeks ago

Ah yeah, BigFloats aren't presently expected to work. Adding support for them is on our to-do list, but honestly it's not high priority at the moment.

mhauru commented 1 week ago

There shouldn't be any big floats coming into play in the MWE though, it's all multivariate normals with Float64s.

The problem seems to come from the same or related type instability as #1608 and https://github.com/TuringLang/Turing.jl/issues/2240. The TV variable in the above function has a type of Matrix{Real}, even though it should be inferrable to Matrix{Float64}, and if you fix it to Matrix{Float64} the issue goes away.

wsmoses commented 1 week ago

Hm, it definitely thinks there’s a code path that could construct a BigFloat — even if in practice it’s never taken.

If you can minimize this a bit more, I can work on making sure this error doesn’t happen.

mhauru commented 1 week ago

MWE that only depends on Accessors and Distributions:

module MWE

import Accessors
import Distributions

using Enzyme

#Enzyme.API.runtimeActivity!(true)

# Local replacement for the `VarName` type used in the first MWE. The symbol lives
# only in the type parameter `sym`; the single field holds the optic, whose concrete
# type becomes the second type parameter `T`.
struct VarName{sym,T}
    optic::T

    # Only inner constructor: the caller supplies `sym` explicitly, and `T` is
    # derived from the optic argument (which defaults to `identity`).
    function VarName{sym}(optic=identity) where {sym}
        return new{sym,typeof(optic)}(optic)
    end
end

# Equality for `VarName`: true when the optics compare equal and the two symbol
# type parameters compare equal. The optic comparison is evaluated first, and `&&`
# skips the symbol comparison when it is false.
function Base.:(==)(x::VarName{symx}, y::VarName{symy}) where {symx,symy}
    return x.optic == y.optic && symx == symy
end

# Field-less struct whose only content is the type parameter `T`; used below to pass
# a storage type (`Matrix{Real}`) into `satellite_model_matrix`.
struct TypeWrap{T} end

# Local replacement for the varinfo used in the first MWE: a container of values
# plus a mutable reference cell holding the accumulated log-probability.
struct VarInfo{Tval,Tlogp}
    # Indexed with the ranges 1:2 and 3:4 in `getindex` below.
    vals::Tval
    # Updated in place via `vi.logp[] += ...` in `tilde_assume!!`.
    logp::Base.RefValue{Tlogp}
end

# Return a copy of the entries of `vi.vals` that belong to `vn`: positions 1:2 when
# `vn == x_varname1`, positions 3:4 for any other name.
# NOTE(review): the name is not qualified as `Base.getindex`, so this looks like a
# module-local function rather than an extension of indexing syntax — confirm. Within
# this file it is only invoked by name, from `tilde_assume!!`.
function getindex(vi::VarInfo, vn::VarName)
    # Compares against the module-level global `x_varname1` using the `==` above.
    range = vn == x_varname1 ? (1:2) : (3:4)
    return copy(vi.vals[range])
end

# Build a new `VarInfo` that uses `x` as its values and a fresh `RefValue{eltype(x)}`
# initialised from the old log-probability; `old_vi` itself is not modified.
VarInfo(old_vi::VarInfo, x) = VarInfo(x, Base.RefValue{eltype(x)}(old_vi.logp[]))

# Reduced assume step for the dependency-light MWE. Evaluates
# `Distributions.logpdf(right, y)` at the fixed point y = [1.0, 1.0] and discards the
# result, reads the values stored under `vn`, adds a literal 0.0 to the running
# log-probability, and returns the tuple `(r, vi)`.
function tilde_assume!!(right, vn, vi)
    y = [1.0, 1.0]
    # Result is thrown away; the call is kept so the logpdf code path is compiled.
    _ = Distributions.logpdf(right, y)
    # Calls the `getindex` defined in this module.
    r = getindex(vi, vn)
    logp = 0.0
    # In-place update of the accumulator held in `vi.logp`; the increment is 0.0.
    vi.logp[] += logp
    return r, vi
end

# Two variable names that share the symbol `:x` and differ only in their optic:
# `_[:, 1]` for the first and `_[:, 2]` for the second. `getindex` above compares
# against `x_varname1` to choose which slice of the values to return.
x_varname1 = VarName{:x}((Accessors.@optic _[:, 1]))
x_varname2 = VarName{:x}((Accessors.@optic _[:, 2]))

# Model body of the reduced reproduction. `TV` is taken from the type parameter of
# the `TypeWrap` argument and is used as the constructor for the 2x2 storage `x`.
# In this file it is called with `TV = Matrix{Real}` (see `g`); the issue discussion
# reports that the failure goes away when that is changed to `Matrix{Float64}`.
#
# Returns `(nothing, __varinfo__)`, where `__varinfo__` is whatever the second
# `tilde_assume!!` call returned.
function satellite_model_matrix(__varinfo__, ::(TypeWrap){TV}) where {TV}
    # 2x2 matrix [0.1 0.0; 0.0 0.1], passed as the second argument of both MvNormals.
    P0 = vcat([0.1 0.0], [0.0 0.1])
    # Uninitialised 2x2 container of type `TV`; both columns are written below.
    x = TV(undef, 2, 2)

    v1 = Distributions.MvNormal([0.0, 0.0], P0)
    v3, __varinfo__ = tilde_assume!!(v1, x_varname1, __varinfo__)
    x[:, 1] .= v3

    # The second distribution takes its first argument from the column just written.
    v1 = Distributions.MvNormal(x[:, 1], P0)
    v4, __varinfo__ = tilde_assume!!(v1, x_varname2, __varinfo__)
    x[:, 2] .= v4
    return nothing, __varinfo__
end

# Module-level template varinfo read as a global by `g`: four zero values and a
# Float64 log-probability cell initialised to 0.0.
vi = VarInfo(
    [0.0, 0.0, 0.0, 0.0],
    Base.RefValue{Float64}(0.0),
)

# Function handed to Enzyme. Wraps the input vector `x` in a new `VarInfo` (copying
# the log-probability from the global `vi`), runs the model with `Matrix{Real}` as the
# storage type, and returns the accumulated log-probability.
function g(x)
    # `vi` is the non-const module-level global defined above.
    vi_new = VarInfo(vi, x)
    # `Matrix{Real}` is the abstractly-typed storage the issue discussion points to.
    _, wrapper_new = satellite_model_matrix(vi_new, TypeWrap{Matrix{Real}}())
    return wrapper_new.logp[]
end

# Driver: differentiate `g` at x = [1, 1, 1, 1] in reverse mode, also requesting the
# primal value. The shadow passed in `Duplicated` is a zero vector of the same shape.
x = [1.0, 1.0, 1.0, 1.0]
Enzyme.autodiff(ReverseWithPrimal, g, Active, Enzyme.Duplicated(x, zero(x)))
# Disabled ForwardDiff call on the same function, left in by the reporter.
# using ForwardDiff
# @show ForwardDiff.gradient(g, x)

end
wsmoses commented 1 day ago

Fixed by https://github.com/EnzymeAD/Enzyme.jl/pull/1658