EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License
455 stars 63 forks source link

Error from untaken branch with Float32 #886

Closed jgreener64 closed 1 year ago

jgreener64 commented 1 year ago

I am on Julia 1.9.0, Enzyme at commit 9487eb8349fd7d907403533b53c23822313897bb, and StaticArrays v1.5.25. Apologies for the slightly strange example. The following code errors:

using Enzyme, StaticArrays

# Return `c2 - c1`. The third argument `side_length` is accepted but never used
# in the body.
vector_1D(c1, c2, side_length) = c2 - c1

# Store `pe_sum` in `pe_vec[1]` and return `nothing`.
#
# `pe_sum` starts at `zero(T)` and is only added to inside the `n_threads > 1`
# branch, so when `n_threads <= 1` the value written is `zero(T)`.
# `coords` is accepted but never read in the body.
function f!(pe_vec, coords, boundary, n_threads, ::Val{T}) where T
    pe_sum = zero(T)
    if n_threads > 1 # This branch is not taken but the code in it still causes problems
        # One accumulator slot per thread id, each initialised to zero(T).
        pe_sum_chunks = [zero(T) for _ in 1:n_threads]
        Threads.@threads for thread_id in 1:n_threads
            # NOTE(review): `c1` and `c2` are neither arguments nor locals of `f!`,
            # so they would be looked up as globals; no such globals are defined in
            # this snippet — confirm whether `coords` elements were intended.
            dr = vector_1D.(c1, c2, boundary)
            pe_sum_chunks[thread_id] += sum(dr)
        end
        # Combine the per-slot partial sums into the total.
        pe_sum += sum(pe_sum_chunks)
    end
    pe_vec[1] = pe_sum
    return nothing
end

# Driver for the reproducer: builds Float32 inputs and differentiates `f!`
# in reverse mode.
T = Float32
# Single-element output buffer; `f!` writes its result to `pe_vec[1]`.
pe_vec = [zero(T)]
coords = [SVector(T(1.0), T(1.0), T(1.0)), SVector(T(2.0), T(2.0), T(2.0))]
boundary = SVector(T(4.0), T(4.0), T(4.0))
# With `n_threads = 1` the `n_threads > 1` branch inside `f!` is not entered.
n_threads = 1
# `pe_vec` and `coords` are wrapped in `Duplicated` together with a second array
# (`[one(T)]` and `zero(coords)` respectively); `boundary`, `n_threads` and
# `Val(T)` are wrapped in `Const`. The bare `Const` after `f!` is passed as a type.
autodiff(
    Enzyme.Reverse,
    f!,
    Const,
    Duplicated(pe_vec, [one(T)]),
    Duplicated(coords, zero(coords)),
    Const(boundary),
    Const(n_threads),
    Const(Val(T)),
)
ERROR: MethodError: no method matching any_jltypes(::Type{Tuple{NamedTuple{(Symbol("1"), Symbol("2"), Symbol("3")), Tuple{Tuple{Tuple{Float32, Float32, Float32}}, Any, Tuple{UInt64, UInt64}}}, NamedTuple{(Symbol("1"), Symbol("2"), Symbol("3"), Symbol("4"), Symbol("5"), Symbol("6"), Symbol("7"), Symbol("8"), Symbol("9"), Symbol("10"), Symbol("11"), Symbol("12"), Symbol("13"), Symbol("14"), Symbol("15"), Symbol("16"), Symbol("17"), Symbol("18"), Symbol("19"), Symbol("20"), Symbol("21"), Symbol("22"), Symbol("23"), Symbol("24"), Symbol("25"), Symbol("26"), Symbol("27"), Symbol("28"), Symbol("29"), Symbol("30"), Symbol("31"), Symbol("32"), Symbol("33"), Symbol("34")), Tuple{Any, Core.LLVMPtr{Core.LLVMPtr{Tuple{}, 0}, 0}, Core.LLVMPtr{Core.LLVMPtr{Tuple{}, 0}, 0}, Core.LLVMPtr{Core.LLVMPtr{Tuple{}, 0}, 0}, Any, Core.LLVMPtr{Core.LLVMPtr{Tuple{}, 0}, 0}, Core.LLVMPtr{Core.LLVMPtr{Tuple{}, 0}, 0}, Core.LLVMPtr{Core.LLVMPtr{Tuple{}, 0}, 0}, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, UInt32, Any, Core.LLVMPtr{Core.LLVMPtr{Tuple{}, 0}, 0}, Any, Core.LLVMPtr{Core.LLVMPtr{Tuple{}, 0}, 0}, Any, Core.LLVMPtr{Float32, 0}}}, NamedTuple{(Symbol("1"), Symbol("2"), Symbol("3")), Tuple{Tuple{Tuple{Float32, Float32, Float32}}, Any, Tuple{UInt64, UInt64}}}}})

Closest candidates are:
  any_jltypes(::Union{LLVM.ArrayType, LLVM.VectorType})
   @ Enzyme ~/.julia/dev/Enzyme/src/compiler.jl:6025
  any_jltypes(::LLVM.FloatingPointType)
   @ Enzyme ~/.julia/dev/Enzyme/src/compiler.jl:6027
  any_jltypes(::Type{NamedTuple{A, B}}) where {A, B}
   @ Enzyme ~/.julia/dev/Enzyme/src/compiler.jl:6037
  ...

Stacktrace:
  [1] threadsfor_augfwd(B::Ptr{LLVM.API.LLVMOpaqueBuilder}, OrigCI::Ptr{LLVM.API.LLVMOpaqueValue}, gutils::Ptr{Nothing}, normalR::Ptr{Ptr{LLVM.API.LLVMOpaqueValue}}, shadowR::Ptr{Ptr{LLVM.API.LLVMOpaqueValue}}, tapeR::Ptr{Ptr{LLVM.API.LLVMOpaqueValue}})
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:3491
  [2] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{Enzyme.API.CDIFFE_TYPE}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, width::Int64, additionalArg::Ptr{Nothing}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{Bool}, augmented::Ptr{Nothing}, atomicAdd::Bool)
    @ Enzyme.API ~/.julia/dev/Enzyme/src/api.jl:128
  [3] enzyme!(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::NTuple{6, Bool}, returnPrimal::Bool, jlrules::Vector{String}, expectedTapeType::Type)
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:7691
  [4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, ctx::LLVM.ThreadSafeContext, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:9047
  [5] codegen
    @ ~/.julia/dev/Enzyme/src/compiler.jl:8655 [inlined]
  [6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, ctx::Nothing, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:9559
  [7] _thunk
    @ ~/.julia/dev/Enzyme/src/compiler.jl:9556 [inlined]
  [8] cached_compilation
    @ ~/.julia/dev/Enzyme/src/compiler.jl:9594 [inlined]
  [9] #s301#222
    @ ~/.julia/dev/Enzyme/src/compiler.jl:9652 [inlined]
 [10] var"#s301#222"(FA::Any, A::Any, TT::Any, Mode::Any, ModifiedBetween::Any, width::Any, ReturnPrimal::Any, ShadowInit::Any, World::Any, ::Any, ::Any, ::Any, ::Any, tt::Any, ::Any, ::Any, ::Any, ::Any, ::Any)
    @ Enzyme.Compiler ./none:0
 [11] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core ./boot.jl:602
 [12] thunk(::Val{0x00000000000082d8}, ::Type{Const{typeof(f!)}}, ::Type{Const}, tt::Type{Tuple{Duplicated{Vector{Float32}}, Duplicated{Vector{SVector{3, Float32}}}, Const{SVector{3, Float32}}, Const{Int64}, Const{Val{Float32}}}}, ::Val{Enzyme.API.DEM_ReverseModeCombined}, ::Val{1}, ::Val{(false, false, false, false, false, false)}, ::Val{false})
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:9611
 [13] autodiff(::EnzymeCore.ReverseMode{false}, ::Const{typeof(f!)}, ::Type{Const}, ::Duplicated{Vector{Float32}}, ::Vararg{Any})
    @ Enzyme ~/.julia/dev/Enzyme/src/Enzyme.jl:199
 [14] autodiff(::EnzymeCore.ReverseMode{false}, ::typeof(f!), ::Type, ::Duplicated{Vector{Float32}}, ::Vararg{Any})
    @ Enzyme ~/.julia/dev/Enzyme/src/Enzyme.jl:214
 [15] top-level scope
    @ REPL[8]:1

Any of the following removes the error:

wsmoses commented 1 year ago

@jgreener64 could you check whether https://github.com/EnzymeAD/Enzyme.jl/pull/889 fixes it for you?

jgreener64 commented 1 year ago

Yes that fixes it, thanks.