EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

Enzyme troubles with ElectrochemicalKinetics #931

Closed rkurchin closed 1 year ago

rkurchin commented 1 year ago

After @vchuravy's amazing PASC talk, I was convinced to finally take the plunge and convert my package, ElectrochemicalKinetics.jl, to Enzyme – it does a lot of scalar AD so I anticipate a nice performance boost. And indeed, right out of the box, on a simpler model from the package, I see about the same speed as ForwardDiff and about 90% fewer allocations 🎉

However, on a more complicated model that requires computation of an integral, there seems to be some type inference issue that's borking things. Here's the code to run to reproduce:

using ElectrochemicalKinetics
using Enzyme

bv = ButlerVolmer(300, 0.5) # simple model
mhcd = MarcusHushChidseyDOS(10000, 0.3, string(dirname(pathof(ElectrochemicalKinetics)), "/../data/DOSes/Cu_111_dos.txt")) # more complicated one

# sanity check that just evaluating the functions works
V_test = 0.2
rate_constant(V_test, bv)
rate_constant(V_test, mhcd)

# okay now differentiate...this works fine (in both forward and reverse mode):
autodiff(Forward, rate_constant, Duplicated, Duplicated(V_test, 1.0), Const(bv))

# but this crashes the REPL:
autodiff(Forward, rate_constant, Duplicated, Duplicated(V_test, 1.0), Const(mhcd))

The outputs of this on Julia 1.7, for both Enzyme 0.11.2 and #main, are attached here. @DhairyaLGandhi (who's been my AD wizard with this package generally, many thanks as always) also tried it on 1.9 and saw something similar (also attached). Attachments: out.txt, out_enzymemain.txt, out_19.txt

Many thanks in advance for any advice, and of course LMK if any other information would be helpful!

wsmoses commented 1 year ago
define nonnull {} addrspace(10)* @preprocess_julia_rate_constant_2350_inner.1(double %0, {} %1) local_unnamed_addr #8 !dbg !44 {
julia_rate_constant_2350_inner.exit:
  %2 = alloca {}, align 8, !dbg !50
  %3 = call {}*** @julia.get_pgcstack() #9
  call void @llvm.dbg.value(metadata double %0, metadata !47, metadata !DIExpression()) #9, !dbg !51
  call void @llvm.dbg.declare(metadata {} addrspace(10)* %4, metadata !48, metadata !DIExpression(DW_OP_deref)) #9, !dbg !53
  %4 = addrspacecast {}* %2 to {} addrspace(10)*, !dbg !50
  %current_task3.i2 = getelementptr inbounds {}**, {}*** %3, i64 -13, !dbg !53
  %current_task3.i = bitcast {}*** %current_task3.i2 to {}**, !dbg !53
  %5 = call noalias nonnull {} addrspace(10)* @julia.gc_alloc_obj({}** nonnull %current_task3.i, i64 noundef 8, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 140737069053024 to {}*) to {} addrspace(10)*)) #10, !dbg !53
  %6 = bitcast {} addrspace(10)* %5 to double addrspace(10)*, !dbg !53
  store double %0, double addrspace(10)* %6, align 8, !dbg !53, !tbaa !28, !alias.scope !34, !noalias !54
  %7 = call nonnull {} addrspace(10)* ({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32)*, {} addrspace(10)*, ...) @julia.call({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32)* noundef nonnull @ijl_apply_generic, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 140732114786976 to {}*) to {} addrspace(10)*), {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140737212136096 to {}*) to {} addrspace(10)*), {} addrspace(10)* undef, {} addrspace(10)* undef, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140737067902800 to {}*) to {} addrspace(10)*), {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140737050554576 to {}*) to {} addrspace(10)*), {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140737050488544 to {}*) to {} addrspace(10)*), {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140737050948816 to {}*) to {} addrspace(10)*), {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140736992754816 to {}*) to {} addrspace(10)*), {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140732114719728 to {}*) to {} addrspace(10)*), {} addrspace(10)* nonnull %5, {} addrspace(10)* %4) #11, !dbg !53
  ret {} addrspace(10)* %7, !dbg !50
}

@vchuravy why would Julia generate a {} arg?

wsmoses commented 1 year ago
julia> using Enzyme

julia> using ElectrochemicalKinetics

julia> using LLVM

julia> ctx=LLVM.Context()
Context(Ptr{LLVM.API.LLVMOpaqueContext} @0x000055fc86d9bf30)

julia> convert(LLVMType, Float64; ctx)
double

julia> convert(LLVMType, Float64; ctx, allow_boxed=true)
double

julia> convert(LLVMType, MarcusHushChidseyDOS{Float64}; ctx, allow_boxed=true)
{ double, double, [4 x {} addrspace(10)*] }
wsmoses commented 1 year ago
(actualRetType, args, prargs, entry_f) = 
(Any,
  Any[(cc = GPUCompiler.BITS_VALUE, typ = Float64, arg_i = 2, codegen = (typ = double, i = 1)), (cc = GPUCompiler.BITS_REF, typ = MarcusHushChidseyDOS{Float64}, arg_i = 3, codegen = (typ = {} addrspace(10)*, i = 2))]
,
  Any[(cc = GPUCompiler.GHOST, typ = typeof(rate_constant), arg_i = 1),
  (cc = GPUCompiler.BITS_VALUE, typ = Float64, arg_i = 2, codegen = (typ = double, i = 1)),
  (cc = GPUCompiler.BITS_REF, typ = MarcusHushChidseyDOS{Float64}, arg_i = 3, codegen = (typ = {} addrspace(10)*, i = 2))]

  , 
define nonnull {} addrspace(10)* @julia_rate_constant_1031(double %0, {} addrspace(10)* noundef nonnull readonly align 8 dereferenceable(48) %1) #0 !dbg !4 {
wsmoses commented 1 year ago

Julia itself forces this to be type-unstable here (note the Any return type above).
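
A quick way to spot this kind of instability without reading LLVM IR is to ask Julia which return type it infers. The sketch below uses Base.return_types on two hypothetical toy structs (LooseModel and TightModel are illustrations, not types from ElectrochemicalKinetics.jl). Whether an abstractly typed field is what produces the Any return type in rate_constant is an assumption; this thread does not confirm it.

```julia
# Hypothetical toy types, not taken from ElectrochemicalKinetics.jl.
struct LooseModel
    f::Function      # abstract field type: the result of m.f(x) cannot be inferred
end

struct TightModel{F}
    f::F             # parametric field type: concrete once F is known
end

# Same call pattern for both models.
call_model(x, m) = m.f(x)

# Ask the compiler which return types it infers for each argument signature.
loose_rt = Base.return_types(call_model, (Float64, LooseModel))
tight_rt = Base.return_types(call_model, (Float64, TightModel{typeof(sin)}))

println("abstract field:   ", loose_rt)
println("parametric field: ", tight_rt)
```

An inferred return type of Any usually means the call is dispatched at runtime, which is what the ijl_apply_generic call in the IR above corresponds to.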

wsmoses commented 1 year ago
julia> @code_llvm rate_constant(V_test, mhcd)
;  @ /home/wmoses/.julia/packages/ElectrochemicalKinetics/xwjPU/src/kinetic_models/MarcusHushChidseyDOS.jl:53 within `rate_constant`
; Function Attrs: sspstrong
define nonnull {}* @julia_rate_constant_612(double %0, {}* noundef nonnull readonly align 8 dereferenceable(48) %1) #0 {
top:
  %2 = alloca [11 x {}*], align 8
  %gcframe21 = alloca [3 x {}*], align 16
  %gcframe21.sub = getelementptr inbounds [3 x {}*], [3 x {}*]* %gcframe21, i64 0, i64 0
  %.sub = getelementptr inbounds [11 x {}*], [11 x {}*]* %2, i64 0, i64 0
  %3 = bitcast [3 x {}*]* %gcframe21 to i8*
  call void @llvm.memset.p0i8.i32(i8* noundef nonnull align 16 dereferenceable(24) %3, i8 0, i32 24, i1 false)
  %thread_ptr = call i8* asm "movq %fs:0, $0", "=r"() #8
  %ppgcstack_i8 = getelementptr i8, i8* %thread_ptr, i64 -8
  %ppgcstack = bitcast i8* %ppgcstack_i8 to {}****
  %pgcstack = load {}***, {}**** %ppgcstack, align 8
;  @ /home/wmoses/.julia/packages/ElectrochemicalKinetics/xwjPU/src/kinetic_models/MarcusHushChidseyDOS.jl within `rate_constant`
  %4 = bitcast [3 x {}*]* %gcframe21 to i64*
  store i64 4, i64* %4, align 16
  %5 = getelementptr inbounds [3 x {}*], [3 x {}*]* %gcframe21, i64 0, i64 1
  %6 = bitcast {}** %5 to {}***
  %7 = load {}**, {}*** %pgcstack, align 8
  store {}** %7, {}*** %6, align 8
  %8 = bitcast {}*** %pgcstack to {}***
  store {}** %gcframe21.sub, {}*** %8, align 8
;  @ /home/wmoses/.julia/packages/ElectrochemicalKinetics/xwjPU/src/kinetic_models/MarcusHushChidseyDOS.jl:53 within `rate_constant`
; ┌ @ Base.jl:37 within `getproperty`
   %9 = bitcast {}* %1 to { double, double, [4 x {}*] }*
   %.elt = getelementptr inbounds { double, double, [4 x {}*] }, { double, double, [4 x {}*] }* %9, i64 0, i32 2, i64 0
   %.unpack = load {}*, {}** %.elt, align 8
   %.not = icmp eq {}* %.unpack, null
   br i1 %.not, label %fail, label %pass2

fail:                                             ; preds = %top
   call void @ijl_throw({}* inttoptr (i64 139667828664448 to {}*))
   unreachable

pass2:                                            ; preds = %top
   %.elt16 = getelementptr inbounds { double, double, [4 x {}*] }, { double, double, [4 x {}*] }* %9, i64 0, i32 2, i64 3
   %.unpack17 = load {}*, {}** %.elt16, align 8
   %.elt6 = getelementptr inbounds { double, double, [4 x {}*] }, { double, double, [4 x {}*] }* %9, i64 0, i32 2, i64 2
   %.unpack7 = load {}*, {}** %.elt6, align 8
; └
  %ptls_field22 = getelementptr inbounds {}**, {}*** %pgcstack, i64 2
  %10 = bitcast {}*** %ptls_field22 to i8**
  %ptls_load2324 = load i8*, i8** %10, align 8
  %11 = call noalias nonnull {}* @ijl_gc_pool_alloc(i8* %ptls_load2324, i32 1392, i32 16) #4
  %12 = bitcast {}* %11 to i64*
  %13 = getelementptr inbounds i64, i64* %12, i64 -1
  store atomic i64 139667829814368, i64* %13 unordered, align 8
  %14 = bitcast {}* %11 to double*
  store double %0, double* %14, align 8
  %15 = getelementptr inbounds [3 x {}*], [3 x {}*]* %gcframe21, i64 0, i64 2
  store {}* %11, {}** %15, align 16
  store {}* inttoptr (i64 139667972897440 to {}*), {}** %.sub, align 8
  %16 = getelementptr inbounds [11 x {}*], [11 x {}*]* %2, i64 0, i64 1
  store {}* %.unpack7, {}** %16, align 8
  %17 = getelementptr inbounds [11 x {}*], [11 x {}*]* %2, i64 0, i64 2
  store {}* %.unpack17, {}** %17, align 8
  %18 = getelementptr inbounds [11 x {}*], [11 x {}*]* %2, i64 0, i64 3
  store {}* inttoptr (i64 139667828664144 to {}*), {}** %18, align 8
  %19 = getelementptr inbounds [11 x {}*], [11 x {}*]* %2, i64 0, i64 4
  store {}* inttoptr (i64 139667811315920 to {}*), {}** %19, align 8
  %20 = getelementptr inbounds [11 x {}*], [11 x {}*]* %2, i64 0, i64 5
  store {}* inttoptr (i64 139667811249888 to {}*), {}** %20, align 8
  %21 = getelementptr inbounds [11 x {}*], [11 x {}*]* %2, i64 0, i64 6
  store {}* inttoptr (i64 139667811710160 to {}*), {}** %21, align 8
  %22 = getelementptr inbounds [11 x {}*], [11 x {}*]* %2, i64 0, i64 7
  store {}* inttoptr (i64 139667753516160 to {}*), {}** %22, align 8
  %23 = getelementptr inbounds [11 x {}*], [11 x {}*]* %2, i64 0, i64 8
  store {}* inttoptr (i64 139662875481072 to {}*), {}** %23, align 8
  %24 = getelementptr inbounds [11 x {}*], [11 x {}*]* %2, i64 0, i64 9
  store {}* %11, {}** %24, align 8
  %25 = getelementptr inbounds [11 x {}*], [11 x {}*]* %2, i64 0, i64 10
  store {}* %1, {}** %25, align 8
  %26 = call nonnull {}* @ijl_apply_generic({}* inttoptr (i64 139662875548320 to {}*), {}** nonnull %.sub, i32 11)
  %27 = load {}*, {}** %5, align 8
  %28 = bitcast {}*** %pgcstack to {}**
  store {}* %27, {}** %28, align 8
  ret {}* %26
}
rkurchin commented 1 year ago

Thanks for looking into this! For my own reference, what exactly is causing this {} arg thing to happen? I'm not able to figure out from the traces here where in the rate_constant function it's happening, but if there's a not-too-invasive fix on my end, I'm totally open to that, too!

wsmoses commented 1 year ago

Basically, there is a bug in GPUCompiler that makes it parse the calling convention for that type incorrectly.

wsmoses commented 1 year ago

On main (and the latest GPUCompiler.jl tag), this specific issue should be fixed, but your full code hits a different issue that requires a jll bump. Will reopen now.

wsmoses commented 1 year ago

Should now be fixed on main, but please reopen if it persists.

Note that you will get an error message about mixed activity (mixedTypeActivity). You can either enable runtimeActivity, as the error message suggests, or rewrite the generator that produces the mixed-activity closure; the backtrace in the error points to it as the source of the issue. The rewrite is the better option in the long run.

rkurchin commented 1 year ago

So the backtrace points me to the line that is actually doing the quadrature sum: sum(w .* f.(n)) (full context)

I'm not sure I completely understand what the issue is/what I need to change to avoid this issue, can you point me in the right direction?

Full backtrace:

julia> autodiff(Forward, rate_constant, Duplicated, Duplicated(0.2, 1.0), Const(mhcd))
ERROR: Enzyme execution failed.
Mismatched activity for:   store {} addrspace(10)* %8, {} addrspace(10)** %.fca.2.gep, align 8, !dbg !557, !noalias !323 const val: {} addrspace(10)* %8
Type tree: {[-1]:Pointer, [-1,0]:Float@double, [-1,8]:Float@double, [-1,16]:Pointer, [-1,24]:Pointer, [-1,32]:Pointer, [-1,40]:Pointer}
You may be using a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/#Activity-of-temporary-storage). If not, please open an issue, and either rewrite this variable to not be conditionally active or use Enzyme.API.runtimeActivity!(true) as a workaround for now

Stacktrace:
  [1] _broadcast_getindex_evalf
    @ ./broadcast.jl:683
  [2] _broadcast_getindex
    @ ./broadcast.jl:656
  [3] _getindex
    @ ./broadcast.jl:680
  [4] _getindex
    @ ./broadcast.jl:679
  [5] _broadcast_getindex
    @ ./broadcast.jl:655
  [6] getindex
    @ ./broadcast.jl:610
  [7] copy
    @ ./broadcast.jl:912
  [8] materialize
    @ ./broadcast.jl:873
  [9] #rate_constant#35
    @ /Volumes/Data/git_repos/ElectrochemicalKinetics.jl/src/kinetic_models/MarcusHushChidseyDOS.jl:81
 [10] #rate_constant#35
    @ /Volumes/Data/git_repos/ElectrochemicalKinetics.jl/src/kinetic_models/MarcusHushChidseyDOS.jl:0

Stacktrace:
 [1] throwerr(cstr::Cstring)
   @ Enzyme.Compiler ~/.julia/packages/Enzyme/Auoie/src/compiler.jl:2790

My best read of the bit of the docs that it links to is that if you're modifying the variable you pass in as Const, even if you don't need the derivative with respect to it, then that's what causes this. But in this case, I'm using parameters of the mhcd object, but not modifying it, so I don't see how it's analogous. Maybe I'm misunderstanding the source of the issue, though?

rkurchin commented 1 year ago

Also, when I try the other workaround (Enzyme.API.runtimeActivity!(true)), I still get the same error.

wsmoses commented 1 year ago

Re runtime activity: you have to enable it right after loading Enzyme, before any autodiff calls are compiled (it's a global property).
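
The ordering described here can be sketched as follows. This is a minimal illustration, assuming the Enzyme version discussed in this thread, where the flag is set through Enzyme.API.runtimeActivity! as quoted in the error message; other Enzyme releases may expose it differently. toy_rate is a hypothetical stand-in for rate_constant.

```julia
using Enzyme

# Global flag: set it immediately after loading Enzyme, in a fresh session,
# before the first autodiff call triggers compilation.
# API name taken from the error message quoted in this thread.
Enzyme.API.runtimeActivity!(true)

# Hypothetical toy function standing in for rate_constant.
toy_rate(V) = exp(-V^2)

# Same call shape as in the original report.
autodiff(Forward, toy_rate, Duplicated, Duplicated(0.2, 1.0))
```

If autodiff has already been called on the function in the current session, restarting Julia and setting the flag first is the safest way to make sure it takes effect.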

wsmoses commented 1 year ago

The relevant part of the docs I think are here: https://enzyme.mit.edu/julia/stable/api/#Enzyme.API.runtimeActivity!-Tuple{Bool}

In this specific case, what's happening is that a constant variable (presumably mhcd or a child of that variable) is being stored into an active data structure, which can lead to the issues described there. My guess is that the broadcast operation inside sum is creating a closure over all the variables involved, which includes some constant variables and some active variables, and that is what triggers the error.

My guess is that if you wrote it with a for loop, the issue may be alleviated (and perhaps run faster).
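
To make the closure point concrete, here is a small sketch in plain Julia (no Enzyme needed). ToyModel, toy_integrand, kernel, and quad_loop are hypothetical names, not functions from ElectrochemicalKinetics.jl. It shows that a closure is a struct holding every captured variable, here both the model (constant) and the voltage (differentiated), and one way to write the quadrature sum so that no such closure is created. Whether this restructuring removes the Enzyme error is not confirmed in this thread.

```julia
# Hypothetical stand-ins; names are illustrative only.
struct ToyModel
    A::Float64
    width::Float64
end

# Returns a closure capturing both the model (constant) and V (differentiated).
toy_integrand(m::ToyModel, V) = E -> m.A * exp(-(E - V)^2 / m.width)

model = ToyModel(2.0, 0.3)
V = 0.2
f = toy_integrand(model, V)

# A closure is a struct whose fields are the captured variables.
captured = fieldnames(typeof(f))
println("captured variables: ", captured)

# Toy quadrature nodes and weights (midpoint rule on [-1, 1], 20 cells).
n = collect(range(-0.95, 0.95; length = 20))
w = fill(0.1, 20)

# Broadcast form: the closure is applied to every node.
s_broadcast = sum(w .* f.(n))

# Loop form without a closure: model and V are passed explicitly.
kernel(m::ToyModel, V, E) = m.A * exp(-(E - V)^2 / m.width)

function quad_loop(m, V, n, w)
    s = zero(eltype(w))
    for i in eachindex(n, w)
        s += w[i] * kernel(m, V, n[i])
    end
    return s
end

s_loop = quad_loop(model, V, n, w)
println((s_broadcast, s_loop))
```

Both forms compute the same sum; the difference is only in whether the constant model and the active voltage end up stored together in one object.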

rkurchin commented 1 year ago

I rewrote that broadcast sum line like so:

s = zero(eltype(w))
for i=1:length(w)
    @inbounds s += w[i] * f(n[i])
end
s

And I still get the error, though the stacktrace is a lot shorter:

julia> autodiff(Forward, rate_constant, Duplicated, Duplicated(0.2, 1.0), Const(mhcd))
ERROR: Enzyme execution failed.
Mismatched activity for:   store {} addrspace(10)* %8, {} addrspace(10)** %.fca.2.gep, align 8, !dbg !150, !noalias !151 const val: {} addrspace(10)* %8
Type tree: {[-1]:Pointer, [-1,0]:Float@double, [-1,8]:Float@double, [-1,16]:Pointer, [-1,24]:Pointer, [-1,32]:Pointer, [-1,40]:Pointer}
You may be using a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/#Activity-of-temporary-storage). If not, please open an issue, and either rewrite this variable to not be conditionally active or use Enzyme.API.runtimeActivity!(true) as a workaround for now

Stacktrace:
 [1] #rate_constant#35
   @ /Volumes/Data/git_repos/ElectrochemicalKinetics.jl/src/kinetic_models/MarcusHushChidseyDOS.jl:83
 [2] #rate_constant#35
   @ /Volumes/Data/git_repos/ElectrochemicalKinetics.jl/src/kinetic_models/MarcusHushChidseyDOS.jl:0

Stacktrace:
 [1] throwerr(cstr::Cstring)
   @ Enzyme.Compiler ~/.julia/packages/Enzyme/Auoie/src/compiler.jl:2790

FWIW, the immediately preceding lines to the section copied above are:

n, w = scale(E_min, E_max)
f = integrand(model, V_app, args...; T = T, kwargs...)

Is it possible that those things being generated within the function are causing the issue?

rkurchin commented 1 year ago

I also just noticed that I get different values for the derivative of the integrand (which does run in both cases) with the runtime activity turned on vs. off, so I assume the one with it on is the one to trust in general...

wsmoses commented 1 year ago

Oh, if you get different results (i.e. the run with it off doesn't throw an error but returns a different result), that is a definite bug; please open another issue for it!
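
When two AD configurations disagree, an AD-independent reference value helps decide which one to trust. Below is a minimal sketch using a central finite difference in plain Julia; toy and dtoy are hypothetical stand-ins for the integrand and its known derivative, and the step size h is an assumption that may need tuning for other functions.

```julia
# Central finite difference: a reference derivative that does not use AD.
central_diff(f, x; h = 1e-6) = (f(x + h) - f(x - h)) / (2h)

# Hypothetical toy function with a known analytic derivative.
toy(V) = exp(-V^2)
dtoy(V) = -2V * exp(-V^2)

fd = central_diff(toy, 0.2)
exact = dtoy(0.2)
println((fd, exact))
```

The same central_diff call can be pointed at the real integrand and compared against the derivative reported with the runtime activity flag on and off.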