EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License
422 stars 58 forks source link

Autodiff Deferred Thunk is broken #1417

Closed wsmoses closed 1 month ago

wsmoses commented 2 months ago

@vchuravy @michel2323 I'm not sure what's happening here: the non-deferred version succeeds, but the deferred one segfaults. This is a minimization of the segfault from https://github.com/JuliaGPU/KernelAbstractions.jl/pull/476

using KernelAbstractions
using Test
using Enzyme
using EnzymeCore
# using KernelAbstractions.EnzymeExt
# Debugging aid: turn on Enzyme's `printall` flag (presumably dumps the IR Enzyme
# processes, to help locate the miscompilation) — TODO confirm exact output.
Enzyme.API.printall!(true)

# KernelAbstractions kernel: each work-item squares one element of `A` in place.
@kernel function square!(A)
    # Global linear index of this work-item within the launch ndrange.
    I = @index(Global, Linear)
    # Bounds checks are skipped; this assumes the ndrange never exceeds length(A)
    # (the launch below uses ndrange = (1,) on a 64-element array).
    @inbounds A[I] *= A[I]
end

    # Primal input holding 1.0, 2.0, ..., 64.0, and its shadow `dA` filled with ones.
    A = collect(Float64, 1:64)
    dA = fill(1.0, 64)

    import KernelAbstractions: Kernel, StaticSize, launch_config, __groupsize, __groupindex, blocks, mkcontext, CompilerMetadata, CPU

    # Augmented-forward wrapper that is run as the kernel body (via
    # `similar(kernel, aug_fwd)` below). It builds an Enzyme split-mode reverse thunk
    # for `f(ctx, args...)` and runs only the augmented forward pass.
    #
    # Arguments:
    #   ctx                    - first argument to `f`; presumably the KA kernel context
    #                            inserted by the launch (TODO confirm)
    #   f                      - function to differentiate, annotated Const
    #   ::Val{ModifiedBetween} - Bool tuple forwarded to ReverseSplitModified; the call
    #                            site passes 3 entries, presumably for (f, ctx, args...)
    #   args...                - activity-annotated arguments (e.g. Duplicated)
    # Returns `nothing`; the forward-pass result is indexed and then discarded.
    function aug_fwd(ctx, f::FT, ::Val{ModifiedBetween}, args...) where {ModifiedBetween, FT}
        # autodiff_deferred_thunk takes the tape type explicitly, so compute it first.
        TapeType = EnzymeCore.tape_type(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
        @show TapeType
        # Deferred thunk: this is the path that segfaults in this reproducer.
        forward, reverse = EnzymeCore.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), TapeType, Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)

        # The non-deferred version below works (swap it in to compare):
        # forward, reverse = EnzymeCore.autodiff_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
        # Run only the augmented forward pass; the reverse thunk is never called here.
        forward(Const(f), Const(ctx), args...)[1]
        return nothing
    end

    # Requested launch shape: a single work-item. `nothing` presumably leaves the
    # workgroup size to launch_config (TODO confirm).
    ndrange = (1,)
    workgroupsize = nothing
    # Instantiate the kernel for the CPU backend, wrapped as a Const Enzyme activity.
    func = Enzyme.Const(square!(CPU()))
        kernel = func.val
        f = kernel.f

        # Resolve concrete launch parameters; iterspace and dynamic are unused below.
        ndrange, workgroupsize, iterspace, dynamic = launch_config(kernel, ndrange, workgroupsize)

        # TODO autodiff_deferred on the func.val
        # Modified-between flags: three entries, only the last (the array) is true.
        ModifiedBetween = Val((false, false, true))

        # Kernel object with the same backend/configuration but `aug_fwd` as its body
        # (presumed semantics of KernelAbstractions' `similar(::Kernel, f)`).
        aug_kernel = similar(kernel, aug_fwd)

        # Launch: aug_fwd receives (ctx, f, ModifiedBetween, Duplicated(A, dA)).
        aug_kernel(f, ModifiedBetween, Duplicated(A, dA); ndrange, workgroupsize)
vchuravy commented 2 months ago

Do you have a backtrace for the fault?

wsmoses commented 2 months ago

[307161] signal (11.1): Segmentation fault
in expression starting at /home/wmoses/git/KernelAbstractions.jl/test.jl:47
cpu_square! at /home/wmoses/git/KernelAbstractions.jl/src/macros.jl:285 [inlined]
cpu_square! at ./none:0 [inlined]
cpu_square! at ./none:0 [inlined]
augmented_julia_cpu_square__3572_inner_1wrap at ./none:0
macro expansion at /home/wmoses/git/Enzyme.jl/src/compiler.jl:5656 [inlined]
enzyme_call at /home/wmoses/git/Enzyme.jl/src/compiler.jl:5334 [inlined]
AugmentedForwardThunk at /home/wmoses/git/Enzyme.jl/src/compiler.jl:5223 [inlined]
aug_fwd at /home/wmoses/git/KernelAbstractions.jl/test.jl:26 [inlined]
__thread_run at /home/wmoses/git/KernelAbstractions.jl/src/cpu.jl:115
unknown function (ip: 0x7d8f18dc929e)
_jl_invoke at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined]
ijl_apply_generic at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:3076
__run at /home/wmoses/git/KernelAbstractions.jl/src/cpu.jl:82
unknown function (ip: 0x7d903bc8900d)
_jl_invoke at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined]
ijl_apply_generic at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:3076
#_#16 at /home/wmoses/git/KernelAbstractions.jl/src/cpu.jl:44
_jl_invoke at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined]
ijl_apply_generic at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:3076
jl_apply at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/julia.h:1982 [inlined]
do_apply at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/builtins.c:768
Kernel at /home/wmoses/git/KernelAbstractions.jl/src/cpu.jl:37
_jl_invoke at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined]
ijl_apply_generic at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:3076
jl_apply at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/julia.h:1982 [inlined]
do_call at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/interpreter.c:126
eval_value at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/interpreter.c:223
eval_stmt_value at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/interpreter.c:174 [inlined]
eval_body at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/interpreter.c:617
jl_interpret_toplevel_thunk at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/interpreter.c:775
jl_toplevel_eval_flex at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/toplevel.c:934
jl_toplevel_eval_flex at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/toplevel.c:877
ijl_toplevel_eval_in at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/toplevel.c:985
eval at ./boot.jl:385 [inlined]
include_string at ./loading.jl:2076
_jl_invoke at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined]
ijl_apply_generic at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:3076
_include at ./loading.jl:2136
include at ./Base.jl:495
jfptr_include_46403.1 at /home/wmoses/git/Enzyme.jl/julia-1.10.2/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined]
ijl_apply_generic at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:3076
exec_options at ./client.jl:318
_start at ./client.jl:552
jfptr__start_82738.1 at /home/wmoses/git/Enzyme.jl/julia-1.10.2/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined]
ijl_apply_generic at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:3076
jl_apply at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/julia.h:1982 [inlined]
true_main at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/jlapi.c:582
jl_repl_entrypoint at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/jlapi.c:731
main at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/cli/loader_exe.c:58
unknown function (ip: 0x7d903d029d8f)
__libc_start_main at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
unknown function (ip: 0x4010b8)
Allocations: 24285604 (Pool: 24248754; Big: 36850); GC: 36
Segmentation fault (core dumped)
vchuravy commented 1 month ago

KernelAbstractions-free (KA-free) reproducer:

# Square the first `len` entries of `A` in place. Returns `nothing`.
function kernel(len, A)
    for idx in 1:len
        A[idx] = A[idx] * A[idx]
    end
    return nothing
end

using Enzyme, EnzymeCore

# Primal input holding 1.0, 2.0, ..., 64.0, and its shadow `dA` filled with ones.
A = collect(Float64, 1:64)
dA = fill(1.0, 64)

# Build an Enzyme split-mode reverse thunk for `f(ctx, args...)` and run only its
# augmented forward pass.
#
# Arguments:
#   ctx                    - first argument to `f` (here the loop length passed to `kernel`)
#   f                      - function to differentiate, annotated Const
#   ::Val{ModifiedBetween} - Bool tuple forwarded to ReverseSplitModified; the call
#                            site passes 3 entries, presumably for (f, ctx, args...)
#   args...                - activity-annotated arguments (e.g. Duplicated)
# Returns `nothing`; the forward-pass result is indexed and then discarded.
function aug_fwd(ctx, f::FT, ::Val{ModifiedBetween}, args...) where {ModifiedBetween, FT}
    # autodiff_deferred_thunk takes the tape type explicitly, so compute it first.
    TapeType = EnzymeCore.tape_type(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
    # Deferred thunk: this is the path that segfaults in this reproducer.
    forward, reverse = EnzymeCore.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), TapeType, Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)

    # The non-deferred version below works (swap it in to compare):
    # forward, reverse = EnzymeCore.autodiff_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
    # Run only the augmented forward pass; the reverse thunk is never called here.
    forward(Const(f), Const(ctx), args...)[1]
    return nothing
end

# Modified-between flags: three entries (for kernel, len, A); only the array is true.
ModifiedBetween = Val((false, false, true))

# Direct call without KernelAbstractions: ctx = 64 becomes `len` in `kernel(len, A)`.
aug_fwd(64, kernel, ModifiedBetween, Duplicated(A, dA))