EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License
455 stars 64 forks source link

GPU kernel error #1129

Closed jgreener64 closed 1 year ago

jgreener64 commented 1 year ago

I am on Julia 1.9.2, Enzyme main (31012f8b4b5f6d5699ded0bb9e7c184c4991312b), CUDA 4.4.1, Atomix 0.1.0, StaticArrays 1.6.5. The following used to work but now errors:

using Enzyme, CUDA, Atomix, StaticArrays, LinearAlgebra

# Parameters of a periodic torsion interaction with `N` cosine terms.
# In `f` below, term `i` contributes
# `ks[i] + ks[i] * cos(periodicities[i] * θ - phases[i])` to the energy.
struct PeriodicTorsion{N, T, E}
    # Integer multiplier applied to the torsion angle in each term.
    periodicities::NTuple{N, Int}
    # Offset subtracted from the scaled angle in each term.
    phases::NTuple{N, T}
    # Amplitude (force constant) of each term.
    ks::NTuple{N, E}
end

"""
    zero(::PeriodicTorsion{N, T, E})

Return a `PeriodicTorsion{N, T, E}` whose periodicities are all `0` and whose
phases and force constants are `zero(T)` and `zero(E)` respectively.
"""
function Base.zero(::PeriodicTorsion{N, T, E}) where {N, T, E}
    zero_periodicities = ntuple(_ -> 0, N)
    zero_phases = ntuple(_ -> zero(T), N)
    zero_ks = ntuple(_ -> zero(E), N)
    return PeriodicTorsion{N, T, E}(zero_periodicities, zero_phases, zero_ks)
end

"""
    +(p1::PeriodicTorsion{N, T, E}, p2::PeriodicTorsion{N, T, E})

Add two torsions with identical type parameters. Phases and force constants
are summed element-wise; the periodicities are taken from `p1` unchanged.
"""
function Base.:+(p1::PeriodicTorsion{N, T, E}, p2::PeriodicTorsion{N, T, E}) where {N, T, E}
    summed_phases = p1.phases .+ p2.phases
    summed_ks = p1.ks .+ p2.ks
    # Periodicities are not summed: the left operand's values are kept.
    return PeriodicTorsion{N, T, E}(p1.periodicities, summed_phases, summed_ks)
end

"""
    torsion_angle(coords_i, coords_j, coords_k, coords_l)

Return the signed torsion (dihedral) angle, in radians, defined by the four
points `i`, `j`, `k`, `l`, computed with the two-argument `atan`.
"""
function torsion_angle(coords_i, coords_j, coords_k, coords_l)
    # Bond vectors along the chain i -> j -> k -> l.
    bond_ij = coords_j - coords_i
    bond_jk = coords_k - coords_j
    bond_kl = coords_l - coords_k

    # Normals of the planes spanned by (ij, jk) and (jk, kl).
    normal_a = cross(bond_ij, bond_jk)
    normal_b = cross(bond_jk, bond_kl)

    # Sine-like and cosine-like components of the angle between the normals.
    y = dot(cross(normal_a, normal_b), normalize(bond_jk))
    x = dot(normal_a, normal_b)
    return atan(y, x)
end

"""
    f(d::PeriodicTorsion{N}, coords_i, coords_j, coords_k, coords_l)

Return the periodic torsion energy for the four given points: the sum over the
`N` terms of `k + k * cos(periodicity * θ - phase)`, where `θ` is the torsion
angle of the points. The first term is evaluated outside the loop, so `d` must
hold at least one term.
"""
function f(d::PeriodicTorsion{N}, coords_i, coords_j, coords_k, coords_l) where N
    angle = torsion_angle(coords_i, coords_j, coords_k, coords_l)

    # Seed the accumulator with the first term so it has the term's type.
    amplitude = d.ks[1]
    total = amplitude + amplitude * cos((d.periodicities[1] * angle) - d.phases[1])

    # Accumulate the remaining terms in index order.
    term = 2
    while term <= N
        amplitude = d.ks[term]
        total += amplitude + amplitude * cos((d.periodicities[term] * angle) - d.phases[term])
        term += 1
    end
    return total
end

# GPU kernel: each thread handles one interaction and atomically adds its
# energy to `energy[1]`.
#   energy      - array whose first element accumulates the total energy
#   coords_var  - array of coordinates, indexed by the values in the index arrays
#   is_var .. ls_var - index arrays giving the four points of each interaction
#   inters_var  - array with one interaction parameter object per entry
function kernel!(energy, coords_var, is_var, js_var, ks_var, ls_var, inters_var)
    # Wrap every input array in `CUDA.Const`; only `energy` is written to.
    # NOTE(review): presumably this marks them read-only for the device; confirm in the CUDA.jl docs.
    coords = CUDA.Const(coords_var)
    is = CUDA.Const(is_var)
    js = CUDA.Const(js_var)
    ks = CUDA.Const(ks_var)
    ls = CUDA.Const(ls_var)
    inters = CUDA.Const(inters_var)

    # 1-based linear thread index across blocks along the x dimension.
    inter_i = (blockIdx().x - 1) * blockDim().x + threadIdx().x

    # Threads whose index exceeds the number of interactions do nothing.
    # `@inbounds` skips bounds checks on the indexing inside the branch.
    @inbounds if inter_i <= length(is)
        i, j, k, l = is[inter_i], js[inter_i], ks[inter_i], ls[inter_i]
        pe = f(inters[inter_i], coords[i], coords[j], coords[k], coords[l])
        # Atomic add with :monotonic ordering, since all threads update the same element.
        Atomix.@atomic :monotonic energy[1] += pe
    end
    return nothing
end

# GPU kernel that differentiates `kernel!` in reverse mode from device code
# via `Enzyme.autodiff_deferred`.
# `energy`, `coords` and `inters` are passed as `Duplicated` together with
# their shadow arrays `d_energy`, `d_coords` and `d_inters`; the index arrays
# are passed as `Const`.
# NOTE(review): per Enzyme's convention the shadows should carry the gradient
# seed and results; confirm against the Enzyme.jl docs.
function grad_kernel!(energy, d_energy, coords, d_coords, is, js, ks, ls, inters, d_inters)
    Enzyme.autodiff_deferred(
        Enzyme.Reverse,
        kernel!,
        # Return activity annotation; `kernel!` returns `nothing`.
        Const,
        Duplicated(energy, d_energy),
        Duplicated(coords, d_coords),
        Const(is),
        Const(js),
        Const(ks),
        Const(ls),
        Duplicated(inters, d_inters),
    )
    return nothing
end

# One-element energy accumulator and its shadow, initialised to 0.0 and 1.0.
pe_vec = CuArray([0.0])
d_pe_vec = CuArray([1.0])
# Five 3D points and a zero-initialised shadow array of the same shape.
coords = CuArray([
    SVector(1.0, 1.0, 1.0),
    SVector(2.0, 2.0, 2.0),
    SVector(3.0, 3.0, 3.0),
    SVector(4.0, 4.0, 4.0),
    SVector(5.0, 5.0, 5.0),
])
d_coords = zero(coords)
# Two torsions, over points (1, 2, 3, 4) and (2, 3, 4, 5).
is = CuArray([1, 2])
js = CuArray([2, 3])
ks = CuArray([3, 4])
ls = CuArray([4, 5])
# Interaction parameters, one per torsion, each with two cosine terms.
inters = CuArray([
    PeriodicTorsion((1, 2), (0.0, 0.0), (100.0, 100.0)),
    PeriodicTorsion((1, 2), (0.0, 0.0), (100.0, 100.0)),
])
# All-zero shadow for `inters`.
d_inters = CuArray([
    PeriodicTorsion((0, 0), (0.0, 0.0), (0.0, 0.0)),
    PeriodicTorsion((0, 0), (0.0, 0.0), (0.0, 0.0)),
])

# Forward (primal) kernel launch with 128 threads.
CUDA.@sync @cuda threads=128 kernel!(pe_vec, coords, is, js, ks, ls, inters) # Works

# Gradient kernel launch; this is the call reported to error in this issue.
CUDA.@sync @cuda threads=128 grad_kernel!(
    pe_vec, d_pe_vec, coords, d_coords, is, js, ks, ls, inters, d_inters)

The error.txt and printall_error.txt are attached.

vchuravy commented 1 year ago
julia: /workspace/srcdir/Enzyme/enzyme/Enzyme/AdjointGenerator.h:5106: void AdjointGenerator<AugmentedReturnType>::recursivelyHandleSubfunction(llvm::CallInst&, llvm::Function*, const std::vector<bool>&, bool, DIFFE_TYPE, bool) [with AugmentedReturnType = const AugmentedReturn*]: Assertion `subdata' failed.
wsmoses commented 1 year ago

@jgreener64 can you retry on main now that I added: https://github.com/EnzymeAD/Enzyme.jl/pull/1131

jgreener64 commented 1 year ago

The error is similar to before, error_2.txt and printall_error_2.txt.

wsmoses commented 1 year ago

Yeah, it's the same error, and I see why that didn't fix it. Can you try: https://github.com/EnzymeAD/Enzyme.jl/pull/1132

jgreener64 commented 1 year ago

A similar error is thrown: error_3.txt and printall_error_3.txt.

jgreener64 commented 1 year ago

The MWE is fixed on the latest main. Now I get a related error, which seems to be due to acos; using asin instead also errors. See error.txt and printall_error.txt.

using Enzyme, CUDA, Atomix, StaticArrays, LinearAlgebra

# Parameters of a harmonic angle interaction. In `f` below the energy is
# `(k / 2) * (θ - θ0)^2`, where `θ` is the angle computed from the coordinates.
struct HarmonicAngle{K, D}
    # Force constant.
    k::K
    # Reference angle that `θ` is compared against.
    θ0::D
end

"""
    zero(::HarmonicAngle{K, D})

Return a `HarmonicAngle` whose force constant is `zero(K)` and whose reference
angle is `zero(D)`.
"""
function Base.zero(::HarmonicAngle{K, D}) where {K, D}
    return HarmonicAngle(zero(K), zero(D))
end

"""
    +(a1::HarmonicAngle, a2::HarmonicAngle)

Add two harmonic angles field by field: force constants are summed and
reference angles are summed.
"""
function Base.:+(a1::HarmonicAngle, a2::HarmonicAngle)
    k_sum = a1.k + a2.k
    θ0_sum = a1.θ0 + a2.θ0
    return HarmonicAngle(k_sum, θ0_sum)
end

"""
    f(a::HarmonicAngle, coords_i, coords_j, coords_k)

Return the harmonic angle energy `(a.k / 2) * (θ - a.θ0)^2`, where `θ` is the
angle at point `j` between the vectors pointing from `j` to `i` and from `j`
to `k`, obtained with `acos` of their normalised dot product.
"""
function f(a::HarmonicAngle, coords_i, coords_j, coords_k)
    # Both arms of the angle originate at the central point j.
    arm_i = coords_i - coords_j
    arm_k = coords_k - coords_j

    cos_angle = dot(arm_i, arm_k) / (norm(arm_i) * norm(arm_k))
    angle = acos(cos_angle)

    deviation = angle - a.θ0
    return (a.k / 2) * deviation ^ 2
end

# GPU kernel: each thread handles one three-point interaction and atomically
# adds its energy to `energy[1]`.
#   energy      - array whose first element accumulates the total energy
#   coords_var  - array of coordinates, indexed by the values in the index arrays
#   is_var .. ks_var - index arrays giving the three points of each interaction
#   inters_var  - array with one interaction parameter object per entry
function kernel!(energy, coords_var, is_var, js_var, ks_var, inters_var)
    # Wrap every input array in `CUDA.Const`; only `energy` is written to.
    # NOTE(review): presumably this marks them read-only for the device; confirm in the CUDA.jl docs.
    coords = CUDA.Const(coords_var)
    is = CUDA.Const(is_var)
    js = CUDA.Const(js_var)
    ks = CUDA.Const(ks_var)
    inters = CUDA.Const(inters_var)

    # 1-based linear thread index across blocks along the x dimension.
    inter_i = (blockIdx().x - 1) * blockDim().x + threadIdx().x

    # Threads whose index exceeds the number of interactions do nothing.
    # `@inbounds` skips bounds checks on the indexing inside the branch.
    @inbounds if inter_i <= length(is)
        i, j, k = is[inter_i], js[inter_i], ks[inter_i]
        pe = f(inters[inter_i], coords[i], coords[j], coords[k])
        # Atomic add with :monotonic ordering, since all threads update the same element.
        Atomix.@atomic :monotonic energy[1] += pe
    end
    return nothing
end

# GPU kernel that differentiates the three-point `kernel!` in reverse mode
# from device code via `Enzyme.autodiff_deferred`.
# `energy`, `coords` and `inters` are passed as `Duplicated` together with
# their shadow arrays `d_energy`, `d_coords` and `d_inters`; the index arrays
# are passed as `Const`.
# NOTE(review): per Enzyme's convention the shadows should carry the gradient
# seed and results; confirm against the Enzyme.jl docs.
function grad_kernel!(energy, d_energy, coords, d_coords, is, js, ks, inters, d_inters)
    Enzyme.autodiff_deferred(
        Enzyme.Reverse,
        kernel!,
        # Return activity annotation; `kernel!` returns `nothing`.
        Const,
        Duplicated(energy, d_energy),
        Duplicated(coords, d_coords),
        Const(is),
        Const(js),
        Const(ks),
        Duplicated(inters, d_inters),
    )
    return nothing
end

# One-element energy accumulator and its shadow, initialised to 0.0 and 1.0.
pe_vec = CuArray([0.0])
d_pe_vec = CuArray([1.0])
# Four 3D points and a zero-initialised shadow array of the same shape.
coords = CuArray([
    SVector(1.0, 1.0, 1.0),
    SVector(2.0, 2.1, 2.0),
    SVector(3.0, 3.2, 3.3),
    SVector(4.0, 4.1, 4.5),
])
d_coords = zero(coords)
# Two angles, over points (1, 2, 3) and (2, 3, 4).
is = CuArray([1, 2])
js = CuArray([2, 3])
ks = CuArray([3, 4])
# Interaction parameters: k = 100.0 and a reference angle of 90 degrees in radians.
inters = CuArray([
    HarmonicAngle(100.0, deg2rad(90.0)),
    HarmonicAngle(100.0, deg2rad(90.0)),
])
# All-zero shadow for `inters`.
d_inters = CuArray([
    HarmonicAngle(0.0, 0.0),
    HarmonicAngle(0.0, 0.0),
])

# Forward (primal) kernel launch with 128 threads.
CUDA.@sync @cuda threads=128 kernel!(pe_vec, coords, is, js, ks, inters) # Works

# Gradient kernel launch; this is the call reported to error in this issue.
CUDA.@sync @cuda threads=128 grad_kernel!(
    pe_vec, d_pe_vec, coords, d_coords, is, js, ks, inters, d_inters)
jgreener64 commented 12 months ago

I'm not sure if Enzyme_jll v0.0.93 is valid with Enzyme.jl main (4329953f6f707982f841c6ed6bc164dd5ae65094), since the Project.toml hasn't been bumped. However, if I jump the gun and try it, I get a new error with the above code; it might be a problem with norm:

ERROR: InvalidIRError: compiling MethodInstance for grad_kernel!(::CuDeviceVector{Float64, 1}, ::CuDeviceVector{Float64, 1}, ::CuDeviceVector{SVector{3, Float64}, 1}, ::CuDeviceVector{SVector{3, Float64}, 1}, ::CuDeviceVector{Int64, 1}, ::CuDeviceVector{Int64, 1}, ::CuDeviceVector{Int64, 1}, ::CuDeviceVector{HarmonicAngle{Float64, Float64}, 1}, ::CuDeviceVector{HarmonicAngle{Float64, Float64}, 1}) resulted in invalid LLVM IR
Reason: unsupported call through a literal pointer (call to )
Stacktrace:
  [1] #isnan
    @ ~/.julia/dev/CUDA/src/device/intrinsics/math.jl:205
  [2] #max
    @ ~/.julia/dev/CUDA/src/device/intrinsics/math.jl:338
  [3] maxabs_nested
    @ ~/.julia/packages/StaticArrays/yXGNL/src/linalg.jl:243
  [4] macro expansion
    @ ~/.julia/packages/StaticArrays/yXGNL/src/linalg.jl:257
  [5] _norm_scaled
    @ ~/.julia/packages/StaticArrays/yXGNL/src/linalg.jl:249
  [6] macro expansion
    @ ~/.julia/packages/StaticArrays/yXGNL/src/linalg.jl:279
  [7] _norm
    @ ~/.julia/packages/StaticArrays/yXGNL/src/linalg.jl:266
  [8] norm
    @ ~/.julia/packages/StaticArrays/yXGNL/src/linalg.jl:265
  [9] f
    @ ./REPL[5]:4
 [10] multiple call sites
    @ unknown:0
Hint: catch this exception as `err` and call `code_typed(err; interactive = true)` to introspect the erronous code with Cthulhu.jl
Stacktrace:
  [1] check_ir(job::GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams}, args::LLVM.Module)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/U36Ed/src/validation.jl:147
  [2] macro expansion
    @ ~/.julia/packages/GPUCompiler/U36Ed/src/driver.jl:440 [inlined]
  [3] macro expansion
    @ ~/.julia/packages/TimerOutputs/RsWnF/src/TimerOutput.jl:253 [inlined]
  [4] macro expansion
    @ ~/.julia/packages/GPUCompiler/U36Ed/src/driver.jl:439 [inlined]
  [5] emit_llvm(job::GPUCompiler.CompilerJob; libraries::Bool, toplevel::Bool, optimize::Bool, cleanup::Bool, only_entry::Bool, validate::Bool)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/U36Ed/src/utils.jl:92
  [6] emit_llvm
    @ ~/.julia/packages/GPUCompiler/U36Ed/src/utils.jl:86 [inlined]
  [7] codegen(output::Symbol, job::GPUCompiler.CompilerJob; libraries::Bool, toplevel::Bool, optimize::Bool, cleanup::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/U36Ed/src/driver.jl:129
  [8] compile(target::Symbol, job::GPUCompiler.CompilerJob; libraries::Bool, toplevel::Bool, optimize::Bool, cleanup::Bool, strip::Bool, validate::Bool, only_entry::Bool)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/U36Ed/src/driver.jl:106
  [9] compile
    @ ~/.julia/packages/GPUCompiler/U36Ed/src/driver.jl:98 [inlined]
 [10] #1075
    @ ~/.julia/dev/CUDA/src/compiler/compilation.jl:247 [inlined]
 [11] JuliaContext(f::CUDA.var"#1075#1077"{GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams}})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/U36Ed/src/driver.jl:47
 [12] compile(job::GPUCompiler.CompilerJob)
    @ CUDA ~/.julia/dev/CUDA/src/compiler/compilation.jl:246
 [13] actual_compilation(cache::Dict{Any, CuFunction}, src::Core.MethodInstance, world::UInt64, cfg::GPUCompiler.CompilerConfig{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams}, compiler::typeof(CUDA.compile), linker::typeof(CUDA.link))
    @ GPUCompiler ~/.julia/packages/GPUCompiler/U36Ed/src/execution.jl:125
 [14] cached_compilation(cache::Dict{Any, CuFunction}, src::Core.MethodInstance, cfg::GPUCompiler.CompilerConfig{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams}, compiler::Function, linker::Function)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/U36Ed/src/execution.jl:103
 [15] macro expansion
    @ ~/.julia/dev/CUDA/src/compiler/execution.jl:373 [inlined]
 [16] macro expansion
    @ ./lock.jl:267 [inlined]
 [17] cufunction(f::typeof(grad_kernel!), tt::Type{Tuple{CuDeviceVector{Float64, 1}, CuDeviceVector{Float64, 1}, CuDeviceVector{SVector{3, Float64}, 1}, CuDeviceVector{SVector{3, Float64}, 1}, CuDeviceVector{Int64, 1}, CuDeviceVector{Int64, 1}, CuDeviceVector{Int64, 1}, CuDeviceVector{HarmonicAngle{Float64, Float64}, 1}, CuDeviceVector{HarmonicAngle{Float64, Float64}, 1}}}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ CUDA ~/.julia/dev/CUDA/src/compiler/execution.jl:368
 [18] cufunction(f::typeof(grad_kernel!), tt::Type{Tuple{CuDeviceVector{Float64, 1}, CuDeviceVector{Float64, 1}, CuDeviceVector{SVector{3, Float64}, 1}, CuDeviceVector{SVector{3, Float64}, 1}, CuDeviceVector{Int64, 1}, CuDeviceVector{Int64, 1}, CuDeviceVector{Int64, 1}, CuDeviceVector{HarmonicAngle{Float64, Float64}, 1}, CuDeviceVector{HarmonicAngle{Float64, Float64}, 1}}})
    @ CUDA ~/.julia/dev/CUDA/src/compiler/execution.jl:365
 [19] macro expansion
    @ ~/.julia/dev/CUDA/src/compiler/execution.jl:104 [inlined]
 [20] top-level scope
    @ ~/.julia/dev/CUDA/src/utilities.jl:35
 [21] top-level scope
    @ ~/.julia/dev/CUDA/src/initialization.jl:208
jgreener64 commented 12 months ago

And the printall_error.txt.