EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License
455 stars 63 forks source link

sqrt in CUDA kernel gives error #426

Closed jgreener64 closed 2 years ago

jgreener64 commented 2 years ago

When I run the following code, I get an error. I am on Julia 1.8.0, Enzyme 0.10.4, and CUDA 3.12.0.

using Enzyme, CUDA, StaticArrays

"""
    Atom{T}

Per-atom interaction parameters consumed by `force`: a size parameter `σ` and an
energy (well-depth) parameter `ϵ`, both stored with the same element type `T`.
"""
struct Atom{T}
    σ::T   # size parameter, combined by arithmetic mean in `force`
    ϵ::T   # energy parameter, combined by geometric mean in `force`
end

"""
    force(c1, c2, a1, a2)

Return the Lennard-Jones pair force for particles at positions `c1` and `c2` with
parameter objects `a1` and `a2` (anything with `σ` and `ϵ` fields, e.g. `Atom`).

The per-pair parameters use Lorentz–Berthelot mixing (`σ` arithmetic mean, `ϵ`
geometric mean). The result is the scalar magnitude
`24ϵ/r² * (2(σ/r)^12 - (σ/r)^6)` times the separation vector `c2 - c1`.
"""
function force(c1, c2, a1, a2)
    separation = c2 - c1
    inv_r2 = inv(sum(abs2, separation))

    # Mixed pair parameters.
    σ_pair = (a1.σ + a2.σ) / 2
    ϵ_pair = sqrt(a1.ϵ * a2.ϵ)

    # (σ/r)^6, built from the inverse squared distance to avoid a square root.
    sr6 = (σ_pair^2 * inv_r2)^3
    magnitude = (24 * ϵ_pair * inv_r2) * (2 * sr6^2 - sr6)
    return magnitude * separation
end

"""
    kernel(C, A)

GPU kernel in which thread `i` evaluates `force` for the neighbouring pair
`(i, i + 1)` of positions `C` and parameters `A`. Threads with no right-hand
neighbour do nothing.

NOTE(review): the computed force is discarded and nothing is written to `C` or `A`,
so the kernel has no observable output. That is enough to reproduce the compile
error, but presumably no gradient would reach the shadow arrays.
"""
function kernel(C, A)
    tid = threadIdx().x
    # Guard: the last element has no partner at tid + 1.
    tid < length(C) || return nothing
    force(C[tid], C[tid + 1], A[tid], A[tid + 1])
    return nothing
end

"""
    grad_kernel(C, dC, A, dA)

GPU kernel that differentiates `kernel(C, A)` with Enzyme's deferred API.
`dC` and `dA` are the shadow arrays paired with `C` and `A`. The return
activity of `kernel` is `Const`.
"""
function grad_kernel(C, dC, A, dA)
    shadowed_C = Duplicated(C, dC)
    shadowed_A = Duplicated(A, dA)
    Enzyme.autodiff_deferred(kernel, Const, shadowed_C, shadowed_A)
    return nothing
end

# Driver: random positions plus identical per-atom parameters on the GPU, and
# zero-initialised shadow arrays for Enzyme.
# A single count keeps C and A the same length. `kernel` indexes A with indices
# derived from length(C), so a mismatch would read out of bounds.
n_atoms = 10
C  = cu(rand(SVector{3, Float32}, n_atoms))
dC = zero(C)  # zero of a CuArray is already a CuArray; no extra `cu` needed
A  = cu([Atom(1.0f0, 1.0f0) for _ in 1:n_atoms])
dA = cu([Atom(0.0f0, 0.0f0) for _ in 1:n_atoms])

@cuda threads=length(C) grad_kernel(C, dC, A, dA)

Interestingly, if I replace the line `ϵ = sqrt(a1.ϵ * a2.ϵ)` with `ϵ = sqrt(a1.ϵ)` or `ϵ = sqrt(a2.ϵ)`, then it works.

The error is:

ERROR: LoadError: Enzyme compilation failed.
Current scope: 
; Function Attrs: mustprogress willreturn
define internal fastcc void @preprocess_julia_force_4161([1 x [3 x float]]* noalias nocapture nofree noundef nonnull writeonly sret([1 x [3 x float]]) align 4 dereferenceable(12) %0, [1 x [3 x float]] addrspace(11)* nocapture nofree noundef nonnull readonly align 4 dereferenceable(12) %1, [1 x [3 x float]] addrspace(11)* nocapture nofree noundef nonnull readonly align 4 dereferenceable(12) %2, [2 x float] addrspace(11)* nocapture nofree noundef nonnull readonly align 4 dereferenceable(8) %3, [2 x float] addrspace(11)* nocapture nofree noundef nonnull readonly align 4 dereferenceable(8) %4) unnamed_addr #11 !dbg !439 {
top:
  %5 = call {}*** @julia.get_pgcstack() #12
  %6 = getelementptr inbounds [1 x [3 x float]], [1 x [3 x float]] addrspace(11)* %2, i64 0, i64 0, i64 0, !dbg !440
  %7 = getelementptr inbounds [1 x [3 x float]], [1 x [3 x float]] addrspace(11)* %1, i64 0, i64 0, i64 0, !dbg !440
  %8 = load float, float addrspace(11)* %6, align 4, !dbg !447, !tbaa !67, !invariant.load !4
  %9 = load float, float addrspace(11)* %7, align 4, !dbg !447, !tbaa !67, !invariant.load !4
  %10 = fsub float %8, %9, !dbg !447
  %11 = getelementptr inbounds [1 x [3 x float]], [1 x [3 x float]] addrspace(11)* %2, i64 0, i64 0, i64 1, !dbg !440
  %12 = getelementptr inbounds [1 x [3 x float]], [1 x [3 x float]] addrspace(11)* %1, i64 0, i64 0, i64 1, !dbg !440
  %13 = load float, float addrspace(11)* %11, align 4, !dbg !447, !tbaa !67, !invariant.load !4
  %14 = load float, float addrspace(11)* %12, align 4, !dbg !447, !tbaa !67, !invariant.load !4
  %15 = fsub float %13, %14, !dbg !447
  %16 = getelementptr inbounds [1 x [3 x float]], [1 x [3 x float]] addrspace(11)* %2, i64 0, i64 0, i64 2, !dbg !440
  %17 = getelementptr inbounds [1 x [3 x float]], [1 x [3 x float]] addrspace(11)* %1, i64 0, i64 0, i64 2, !dbg !440
  %18 = load float, float addrspace(11)* %16, align 4, !dbg !447, !tbaa !67, !invariant.load !4
  %19 = load float, float addrspace(11)* %17, align 4, !dbg !447, !tbaa !67, !invariant.load !4
  %20 = fsub float %18, %19, !dbg !447
  %21 = fmul float %10, %10, !dbg !448
  %22 = fmul float %15, %15, !dbg !448
  %23 = fadd float %21, %22, !dbg !456
  %24 = fmul float %20, %20, !dbg !448
  %25 = fadd float %23, %24, !dbg !456
  %26 = fdiv float 1.000000e+00, %25, !dbg !457
  %27 = getelementptr inbounds [2 x float], [2 x float] addrspace(11)* %3, i64 0, i64 0, !dbg !459
  %28 = getelementptr inbounds [2 x float], [2 x float] addrspace(11)* %4, i64 0, i64 0, !dbg !459
  %29 = load float, float addrspace(11)* %27, align 4, !dbg !461, !tbaa !67, !invariant.load !4
  %30 = load float, float addrspace(11)* %28, align 4, !dbg !461, !tbaa !67, !invariant.load !4
  %31 = fadd float %29, %30, !dbg !461
  %32 = fmul float %31, 5.000000e-01, !dbg !462
  %33 = getelementptr inbounds [2 x float], [2 x float] addrspace(11)* %3, i64 0, i64 1, !dbg !464
  %34 = getelementptr inbounds [2 x float], [2 x float] addrspace(11)* %4, i64 0, i64 1, !dbg !464
  %35 = load float, float addrspace(11)* %33, align 4, !dbg !466, !tbaa !67, !invariant.load !4
  %36 = load float, float addrspace(11)* %34, align 4, !dbg !466, !tbaa !67, !invariant.load !4
  %37 = fmul float %35, %36, !dbg !466
  %38 = call i32 @__nvvm_reflect(i8* noundef getelementptr inbounds ([11 x i8], [11 x i8]* @.str, i64 0, i64 0)) #13, !dbg !467
  %.not = icmp eq i32 %38, 0, !dbg !467
  br i1 %.not, label %45, label %39, !dbg !467

39:                                               ; preds = %top
  %40 = call i32 @__nvvm_reflect(i8* noundef getelementptr inbounds ([17 x i8], [17 x i8]* @.str.2, i64 0, i64 0)) #13, !dbg !467
  %.not9 = icmp eq i32 %40, 0, !dbg !467
  br i1 %.not9, label %43, label %41, !dbg !467

41:                                               ; preds = %39
  %42 = call float @llvm.nvvm.sqrt.rn.ftz.f(float %37) #13, !dbg !467
  br label %__nv_sqrtf.exit, !dbg !467

43:                                               ; preds = %39
  %44 = call float @llvm.nvvm.sqrt.approx.ftz.f(float %37) #13, !dbg !467
  br label %__nv_sqrtf.exit, !dbg !467

45:                                               ; preds = %top
  %46 = call i32 @__nvvm_reflect(i8* noundef getelementptr inbounds ([17 x i8], [17 x i8]* @.str.2, i64 0, i64 0)) #13, !dbg !467
  %.not8 = icmp eq i32 %46, 0, !dbg !467
  br i1 %.not8, label %49, label %47, !dbg !467

47:                                               ; preds = %45
  %48 = call float @llvm.sqrt.f32(float %37) #12, !dbg !467
  br label %__nv_sqrtf.exit, !dbg !467

49:                                               ; preds = %45
  %50 = call float @llvm.nvvm.sqrt.approx.f(float %37) #13, !dbg !467
  br label %__nv_sqrtf.exit, !dbg !467

__nv_sqrtf.exit:                                  ; preds = %49, %47, %43, %41
  %.0.i = phi float [ %42, %41 ], [ %44, %43 ], [ %48, %47 ], [ %50, %49 ], !dbg !467
  %51 = fmul float %32, %32, !dbg !468
  %52 = fmul float %26, %51, !dbg !471
  %53 = fmul float %52, %52, !dbg !472
  %54 = fmul float %52, %53, !dbg !472
  %55 = fmul float %.0.i, 2.400000e+01, !dbg !475
  %56 = fmul float %26, %55, !dbg !479
  %57 = fmul float %54, %54, !dbg !480
  %58 = fmul float %57, 2.000000e+00, !dbg !482
  %59 = fsub float %58, %54, !dbg !484
  %60 = fmul float %59, %56, !dbg !485
  %61 = fmul float %10, %60, !dbg !486
  %62 = fmul float %15, %60, !dbg !486
  %63 = fmul float %20, %60, !dbg !486
  %.sroa.0.sroa.0.0..sroa.0.0..sroa_cast1.sroa_idx = getelementptr inbounds [1 x [3 x float]], [1 x [3 x float]]* %0, i64 0, i64 0, i64 0, !dbg !492
  store float %61, float* %.sroa.0.sroa.0.0..sroa.0.0..sroa_cast1.sroa_idx, align 4, !dbg !492
  %.sroa.0.sroa.2.0..sroa.0.0..sroa_cast1.sroa_idx6 = getelementptr inbounds [1 x [3 x float]], [1 x [3 x float]]* %0, i64 0, i64 0, i64 1, !dbg !492
  store float %62, float* %.sroa.0.sroa.2.0..sroa.0.0..sroa_cast1.sroa_idx6, align 4, !dbg !492
  %.sroa.0.sroa.3.0..sroa.0.0..sroa_cast1.sroa_idx7 = getelementptr inbounds [1 x [3 x float]], [1 x [3 x float]]* %0, i64 0, i64 0, i64 2, !dbg !492
  store float %63, float* %.sroa.0.sroa.3.0..sroa.0.0..sroa_cast1.sroa_idx7, align 4, !dbg !492
  ret void, !dbg !492
}

; Function Attrs: mustprogress willreturn
define internal fastcc void @preprocess_julia_force_4161([1 x [3 x float]]* noalias nocapture nofree noundef nonnull writeonly sret([1 x [3 x float]]) align 4 dereferenceable(12) %0, [1 x [3 x float]] addrspace(11)* nocapture nofree noundef nonnull readonly align 4 dereferenceable(12) %1, [1 x [3 x float]] addrspace(11)* nocapture nofree noundef nonnull readonly align 4 dereferenceable(12) %2, [2 x float] addrspace(11)* nocapture nofree noundef nonnull readonly align 4 dereferenceable(8) %3, [2 x float] addrspace(11)* nocapture nofree noundef nonnull readonly align 4 dereferenceable(8) %4) unnamed_addr #11 !dbg !439 {
top:
  %5 = call {}*** @julia.get_pgcstack() #12
  %6 = getelementptr inbounds [1 x [3 x float]], [1 x [3 x float]] addrspace(11)* %2, i64 0, i64 0, i64 0, !dbg !440
  %7 = getelementptr inbounds [1 x [3 x float]], [1 x [3 x float]] addrspace(11)* %1, i64 0, i64 0, i64 0, !dbg !440
  %8 = load float, float addrspace(11)* %6, align 4, !dbg !447, !tbaa !67, !invariant.load !4
  %9 = load float, float addrspace(11)* %7, align 4, !dbg !447, !tbaa !67, !invariant.load !4
  %10 = fsub float %8, %9, !dbg !447
  %11 = getelementptr inbounds [1 x [3 x float]], [1 x [3 x float]] addrspace(11)* %2, i64 0, i64 0, i64 1, !dbg !440
  %12 = getelementptr inbounds [1 x [3 x float]], [1 x [3 x float]] addrspace(11)* %1, i64 0, i64 0, i64 1, !dbg !440
  %13 = load float, float addrspace(11)* %11, align 4, !dbg !447, !tbaa !67, !invariant.load !4
  %14 = load float, float addrspace(11)* %12, align 4, !dbg !447, !tbaa !67, !invariant.load !4
  %15 = fsub float %13, %14, !dbg !447
  %16 = getelementptr inbounds [1 x [3 x float]], [1 x [3 x float]] addrspace(11)* %2, i64 0, i64 0, i64 2, !dbg !440
  %17 = getelementptr inbounds [1 x [3 x float]], [1 x [3 x float]] addrspace(11)* %1, i64 0, i64 0, i64 2, !dbg !440
  %18 = load float, float addrspace(11)* %16, align 4, !dbg !447, !tbaa !67, !invariant.load !4
  %19 = load float, float addrspace(11)* %17, align 4, !dbg !447, !tbaa !67, !invariant.load !4
  %20 = fsub float %18, %19, !dbg !447
  %21 = fmul float %10, %10, !dbg !448
  %22 = fmul float %15, %15, !dbg !448
  %23 = fadd float %21, %22, !dbg !456
  %24 = fmul float %20, %20, !dbg !448
  %25 = fadd float %23, %24, !dbg !456
  %26 = fdiv float 1.000000e+00, %25, !dbg !457
  %27 = getelementptr inbounds [2 x float], [2 x float] addrspace(11)* %3, i64 0, i64 0, !dbg !459
  %28 = getelementptr inbounds [2 x float], [2 x float] addrspace(11)* %4, i64 0, i64 0, !dbg !459
  %29 = load float, float addrspace(11)* %27, align 4, !dbg !461, !tbaa !67, !invariant.load !4
  %30 = load float, float addrspace(11)* %28, align 4, !dbg !461, !tbaa !67, !invariant.load !4
  %31 = fadd float %29, %30, !dbg !461
  %32 = fmul float %31, 5.000000e-01, !dbg !462
  %33 = getelementptr inbounds [2 x float], [2 x float] addrspace(11)* %3, i64 0, i64 1, !dbg !464
  %34 = getelementptr inbounds [2 x float], [2 x float] addrspace(11)* %4, i64 0, i64 1, !dbg !464
  %35 = load float, float addrspace(11)* %33, align 4, !dbg !466, !tbaa !67, !invariant.load !4
  %36 = load float, float addrspace(11)* %34, align 4, !dbg !466, !tbaa !67, !invariant.load !4
  %37 = fmul float %35, %36, !dbg !466
  %38 = call i32 @__nvvm_reflect(i8* noundef getelementptr inbounds ([11 x i8], [11 x i8]* @.str, i64 0, i64 0)) #13, !dbg !467
  %.not = icmp eq i32 %38, 0, !dbg !467
  br i1 %.not, label %45, label %39, !dbg !467

39:                                               ; preds = %top
  %40 = call i32 @__nvvm_reflect(i8* noundef getelementptr inbounds ([17 x i8], [17 x i8]* @.str.2, i64 0, i64 0)) #13, !dbg !467
  %.not9 = icmp eq i32 %40, 0, !dbg !467
  br i1 %.not9, label %43, label %41, !dbg !467

41:                                               ; preds = %39
  %42 = call float @llvm.nvvm.sqrt.rn.ftz.f(float %37) #13, !dbg !467
  br label %__nv_sqrtf.exit, !dbg !467

43:                                               ; preds = %39
  %44 = call float @llvm.nvvm.sqrt.approx.ftz.f(float %37) #13, !dbg !467
  br label %__nv_sqrtf.exit, !dbg !467

45:                                               ; preds = %top
  %46 = call i32 @__nvvm_reflect(i8* noundef getelementptr inbounds ([17 x i8], [17 x i8]* @.str.2, i64 0, i64 0)) #13, !dbg !467
  %.not8 = icmp eq i32 %46, 0, !dbg !467
  br i1 %.not8, label %49, label %47, !dbg !467

47:                                               ; preds = %45
  %48 = call float @llvm.sqrt.f32(float %37) #12, !dbg !467
  br label %__nv_sqrtf.exit, !dbg !467

49:                                               ; preds = %45
  %50 = call float @llvm.nvvm.sqrt.approx.f(float %37) #13, !dbg !467
  br label %__nv_sqrtf.exit, !dbg !467

__nv_sqrtf.exit:                                  ; preds = %49, %47, %43, %41
  %.0.i = phi float [ %42, %41 ], [ %44, %43 ], [ %48, %47 ], [ %50, %49 ], !dbg !467
  %51 = fmul float %32, %32, !dbg !468
  %52 = fmul float %26, %51, !dbg !471
  %53 = fmul float %52, %52, !dbg !472
  %54 = fmul float %52, %53, !dbg !472
  %55 = fmul float %.0.i, 2.400000e+01, !dbg !475
  %56 = fmul float %26, %55, !dbg !479
  %57 = fmul float %54, %54, !dbg !480
  %58 = fmul float %57, 2.000000e+00, !dbg !482
  %59 = fsub float %58, %54, !dbg !484
  %60 = fmul float %59, %56, !dbg !485
  %61 = fmul float %10, %60, !dbg !486
  %62 = fmul float %15, %60, !dbg !486
  %63 = fmul float %20, %60, !dbg !486
  %.sroa.0.sroa.0.0..sroa.0.0..sroa_cast1.sroa_idx = getelementptr inbounds [1 x [3 x float]], [1 x [3 x float]]* %0, i64 0, i64 0, i64 0, !dbg !492
  store float %61, float* %.sroa.0.sroa.0.0..sroa.0.0..sroa_cast1.sroa_idx, align 4, !dbg !492
  %.sroa.0.sroa.2.0..sroa.0.0..sroa_cast1.sroa_idx6 = getelementptr inbounds [1 x [3 x float]], [1 x [3 x float]]* %0, i64 0, i64 0, i64 1, !dbg !492
  store float %62, float* %.sroa.0.sroa.2.0..sroa.0.0..sroa_cast1.sroa_idx6, align 4, !dbg !492
  %.sroa.0.sroa.3.0..sroa.0.0..sroa_cast1.sroa_idx7 = getelementptr inbounds [1 x [3 x float]], [1 x [3 x float]]* %0, i64 0, i64 0, i64 2, !dbg !492
  store float %63, float* %.sroa.0.sroa.3.0..sroa.0.0..sroa_cast1.sroa_idx7, align 4, !dbg !492
  ret void, !dbg !492
}

; Function Attrs: mustprogress willreturn
define internal fastcc void @diffejulia_force_4161([1 x [3 x float]]* noalias nocapture nofree noundef nonnull writeonly sret([1 x [3 x float]]) align 4 dereferenceable(12) %0, [1 x [3 x float]]* nocapture %"'", [1 x [3 x float]] addrspace(11)* nocapture nofree noundef nonnull readonly align 4 dereferenceable(12) %1, [1 x [3 x float]] addrspace(11)* nocapture %"'1", [1 x [3 x float]] addrspace(11)* nocapture nofree noundef nonnull readonly align 4 dereferenceable(12) %2, [1 x [3 x float]] addrspace(11)* nocapture %"'2", [2 x float] addrspace(11)* nocapture nofree noundef nonnull readonly align 4 dereferenceable(8) %3, [2 x float] addrspace(11)* nocapture %"'3", [2 x float] addrspace(11)* nocapture nofree noundef nonnull readonly align 4 dereferenceable(8) %4, [2 x float] addrspace(11)* nocapture %"'4") unnamed_addr #11 !dbg !493 {
top:
  %_replacementA = phi {}*** 
  %"'ipg46" = getelementptr inbounds [1 x [3 x float]], [1 x [3 x float]] addrspace(11)* %"'2", i64 0, i64 0, i64 0, !dbg !494
  %5 = getelementptr inbounds [1 x [3 x float]], [1 x [3 x float]] addrspace(11)* %2, i64 0, i64 0, i64 0, !dbg !494
  %"'ipg45" = getelementptr inbounds [1 x [3 x float]], [1 x [3 x float]] addrspace(11)* %"'1", i64 0, i64 0, i64 0, !dbg !494
  %6 = getelementptr inbounds [1 x [3 x float]], [1 x [3 x float]] addrspace(11)* %1, i64 0, i64 0, i64 0, !dbg !494
  %7 = load float, float addrspace(11)* %5, align 4, !dbg !501, !tbaa !67, !invariant.load !4
  %8 = load float, float addrspace(11)* %6, align 4, !dbg !501, !tbaa !67, !invariant.load !4
  %9 = fsub float %7, %8, !dbg !501
  %"'ipg42" = getelementptr inbounds [1 x [3 x float]], [1 x [3 x float]] addrspace(11)* %"'2", i64 0, i64 0, i64 1, !dbg !494
  %10 = getelementptr inbounds [1 x [3 x float]], [1 x [3 x float]] addrspace(11)* %2, i64 0, i64 0, i64 1, !dbg !494
  %"'ipg41" = getelementptr inbounds [1 x [3 x float]], [1 x [3 x float]] addrspace(11)* %"'1", i64 0, i64 0, i64 1, !dbg !494
  %11 = getelementptr inbounds [1 x [3 x float]], [1 x [3 x float]] addrspace(11)* %1, i64 0, i64 0, i64 1, !dbg !494
  %12 = load float, float addrspace(11)* %10, align 4, !dbg !501, !tbaa !67, !invariant.load !4
  %13 = load float, float addrspace(11)* %11, align 4, !dbg !501, !tbaa !67, !invariant.load !4
  %14 = fsub float %12, %13, !dbg !501
  %"'ipg38" = getelementptr inbounds [1 x [3 x float]], [1 x [3 x float]] addrspace(11)* %"'2", i64 0, i64 0, i64 2, !dbg !494
  %15 = getelementptr inbounds [1 x [3 x float]], [1 x [3 x float]] addrspace(11)* %2, i64 0, i64 0, i64 2, !dbg !494
  %"'ipg37" = getelementptr inbounds [1 x [3 x float]], [1 x [3 x float]] addrspace(11)* %"'1", i64 0, i64 0, i64 2, !dbg !494
  %16 = getelementptr inbounds [1 x [3 x float]], [1 x [3 x float]] addrspace(11)* %1, i64 0, i64 0, i64 2, !dbg !494
  %17 = load float, float addrspace(11)* %15, align 4, !dbg !501, !tbaa !67, !invariant.load !4
  %18 = load float, float addrspace(11)* %16, align 4, !dbg !501, !tbaa !67, !invariant.load !4
  %19 = fsub float %17, %18, !dbg !501
  %20 = fmul float %9, %9, !dbg !502
  %21 = fmul float %14, %14, !dbg !502
  %22 = fadd float %20, %21, !dbg !510
  %23 = fmul float %19, %19, !dbg !502
  %24 = fadd float %22, %23, !dbg !510
  %25 = fdiv float 1.000000e+00, %24, !dbg !511
  %"'ipg19" = getelementptr inbounds [2 x float], [2 x float] addrspace(11)* %"'3", i64 0, i64 0, !dbg !513
  %26 = getelementptr inbounds [2 x float], [2 x float] addrspace(11)* %3, i64 0, i64 0, !dbg !513
  %"'ipg18" = getelementptr inbounds [2 x float], [2 x float] addrspace(11)* %"'4", i64 0, i64 0, !dbg !513
  %27 = getelementptr inbounds [2 x float], [2 x float] addrspace(11)* %4, i64 0, i64 0, !dbg !513
  %28 = load float, float addrspace(11)* %26, align 4, !dbg !515, !tbaa !67, !invariant.load !4
  %29 = load float, float addrspace(11)* %27, align 4, !dbg !515, !tbaa !67, !invariant.load !4
  %30 = fadd float %28, %29, !dbg !515
  %31 = fmul float %30, 5.000000e-01, !dbg !516
  %"'ipg12" = getelementptr inbounds [2 x float], [2 x float] addrspace(11)* %"'3", i64 0, i64 1, !dbg !518
  %32 = getelementptr inbounds [2 x float], [2 x float] addrspace(11)* %3, i64 0, i64 1, !dbg !518
  %"'ipg" = getelementptr inbounds [2 x float], [2 x float] addrspace(11)* %"'4", i64 0, i64 1, !dbg !518
  %33 = getelementptr inbounds [2 x float], [2 x float] addrspace(11)* %4, i64 0, i64 1, !dbg !518
  %34 = load float, float addrspace(11)* %32, align 4, !dbg !520, !tbaa !67, !invariant.load !4
  %35 = load float, float addrspace(11)* %33, align 4, !dbg !520, !tbaa !67, !invariant.load !4
  %36 = fmul float %34, %35, !dbg !520
  %37 = call i32 @__nvvm_reflect(i8* noundef getelementptr inbounds ([11 x i8], [11 x i8]* @.str, i64 0, i64 0)) #12, !dbg !521
  %.not = icmp eq i32 %37, 0, !dbg !521
  br i1 %.not, label %44, label %38, !dbg !521

38:                                               ; preds = %top
  %39 = call i32 @__nvvm_reflect(i8* noundef getelementptr inbounds ([17 x i8], [17 x i8]* @.str.2, i64 0, i64 0)) #12, !dbg !521
  %.not9 = icmp eq i32 %39, 0, !dbg !521
  br i1 %.not9, label %42, label %40, !dbg !521

40:                                               ; preds = %38
  %41 = call float @llvm.nvvm.sqrt.rn.ftz.f(float %36) #12, !dbg !521
  br label %__nv_sqrtf.exit, !dbg !521

42:                                               ; preds = %38
  %43 = call float @llvm.nvvm.sqrt.approx.ftz.f(float %36) #12, !dbg !521
  br label %__nv_sqrtf.exit, !dbg !521

44:                                               ; preds = %top
  %45 = call i32 @__nvvm_reflect(i8* noundef getelementptr inbounds ([17 x i8], [17 x i8]* @.str.2, i64 0, i64 0)) #12, !dbg !521
  %.not8 = icmp eq i32 %45, 0, !dbg !521
  br i1 %.not8, label %48, label %46, !dbg !521

46:                                               ; preds = %44
  %47 = call float @llvm.sqrt.f32(float %36) #13, !dbg !521
  br label %__nv_sqrtf.exit, !dbg !521

48:                                               ; preds = %44
  %49 = call float @llvm.nvvm.sqrt.approx.f(float %36) #12, !dbg !521
  br label %__nv_sqrtf.exit, !dbg !521

__nv_sqrtf.exit:                                  ; preds = %48, %46, %42, %40
  %.0.i = phi float [ %41, %40 ], [ %43, %42 ], [ %47, %46 ], [ %49, %48 ], !dbg !521
  %50 = fmul float %31, %31, !dbg !522
  %51 = fmul float %25, %50, !dbg !525
  %52 = fmul float %51, %51, !dbg !526
  %53 = fmul float %51, %52, !dbg !526
  %54 = fmul float %.0.i, 2.400000e+01, !dbg !529
  %55 = fmul float %25, %54, !dbg !533
  %56 = fmul float %53, %53, !dbg !534
  %57 = fmul float %56, 2.000000e+00, !dbg !536
  %58 = fsub float %57, %53, !dbg !538
  %59 = fmul float %58, %55, !dbg !539
  %60 = fmul float %9, %59, !dbg !540
  %61 = fmul float %14, %59, !dbg !540
  %62 = fmul float %19, %59, !dbg !540
  %.sroa.0.sroa.0.0..sroa.0.0..sroa_cast1.sroa_idx = getelementptr inbounds [1 x [3 x float]], [1 x [3 x float]]* %0, i64 0, i64 0, i64 0, !dbg !546
  store float %60, float* %.sroa.0.sroa.0.0..sroa.0.0..sroa_cast1.sroa_idx, align 4, !dbg !546
  %.sroa.0.sroa.2.0..sroa.0.0..sroa_cast1.sroa_idx6 = getelementptr inbounds [1 x [3 x float]], [1 x [3 x float]]* %0, i64 0, i64 0, i64 1, !dbg !546
  store float %61, float* %.sroa.0.sroa.2.0..sroa.0.0..sroa_cast1.sroa_idx6, align 4, !dbg !546
  %.sroa.0.sroa.3.0..sroa.0.0..sroa_cast1.sroa_idx7 = getelementptr inbounds [1 x [3 x float]], [1 x [3 x float]]* %0, i64 0, i64 0, i64 2, !dbg !546
  store float %62, float* %.sroa.0.sroa.3.0..sroa.0.0..sroa_cast1.sroa_idx7, align 4, !dbg !546
  br label %invert__nv_sqrtf.exit, !dbg !546

allocsForInversion:                               ; No predecessors!
  %"'de" = alloca float, align 4
  store float 0.000000e+00, float* %"'de", align 4
  %"'de10" = alloca float, align 4
  store float 0.000000e+00, float* %"'de10", align 4
  %"'de11" = alloca float, align 4
  store float 0.000000e+00, float* %"'de11", align 4
  %"'de13" = alloca float, align 4
  store float 0.000000e+00, float* %"'de13", align 4
  %"'de15" = alloca float, align 4
  store float 0.000000e+00, float* %"'de15", align 4
  %"'de16" = alloca float, align 4
  store float 0.000000e+00, float* %"'de16", align 4
  %"'de17" = alloca float, align 4
  store float 0.000000e+00, float* %"'de17", align 4
  %"'de20" = alloca float, align 4
  store float 0.000000e+00, float* %"'de20", align 4
  %"'de21" = alloca float, align 4
  store float 0.000000e+00, float* %"'de21", align 4
  %"'de22" = alloca float, align 4
  store float 0.000000e+00, float* %"'de22", align 4
  %"'de23" = alloca float, align 4
  store float 0.000000e+00, float* %"'de23", align 4
  %"'de26" = alloca float, align 4
  store float 0.000000e+00, float* %"'de26", align 4
  %"'de27" = alloca float, align 4
  store float 0.000000e+00, float* %"'de27", align 4
  %"'de28" = alloca float, align 4
  store float 0.000000e+00, float* %"'de28", align 4
  %"'de31" = alloca float, align 4
  store float 0.000000e+00, float* %"'de31", align 4
  %"'de34" = alloca float, align 4
  store float 0.000000e+00, float* %"'de34", align 4
  %"'de35" = alloca float, align 4
  store float 0.000000e+00, float* %"'de35", align 4
  %"'de36" = alloca float, align 4
  store float 0.000000e+00, float* %"'de36", align 4
  %"'de39" = alloca float, align 4
  store float 0.000000e+00, float* %"'de39", align 4
  %"'de40" = alloca float, align 4
  store float 0.000000e+00, float* %"'de40", align 4
  %"'de43" = alloca float, align 4
  store float 0.000000e+00, float* %"'de43", align 4
  %"'de44" = alloca float, align 4
  store float 0.000000e+00, float* %"'de44", align 4
  %"'de47" = alloca float, align 4
  store float 0.000000e+00, float* %"'de47", align 4

inverttop:                                        ; preds = %invert
  %63 = load float, float* %"'de", align 4
  %m0diffe = fmul fast float %63, %35
  %m1diffe = fmul fast float %63, %34
  store float 0.000000e+00, float* %"'de", align 4
  %64 = load float, float* %"'de10", align 4
  %65 = fadd fast float %64, %m0diffe
  store float %65, float* %"'de10", align 4
  %66 = load float, float* %"'de11", align 4
  %67 = fadd fast float %66, %m1diffe
  store float %67, float* %"'de11", align 4
  %68 = load float, float* %"'de11", align 4
  store float 0.000000e+00, float* %"'de11", align 4
  %69 = atomicrmw fadd float addrspace(11)* %"'ipg", float %68 monotonic, align 4
  %70 = load float, float* %"'de10", align 4
  store float 0.000000e+00, float* %"'de10", align 4
  %71 = atomicrmw fadd float addrspace(11)* %"'ipg12", float %70 monotonic, align 4
  %72 = load float, float* %"'de13", align 4
  %m0diffe14 = fmul fast float %72, 5.000000e-01
  store float 0.000000e+00, float* %"'de13", align 4
  %73 = load float, float* %"'de15", align 4
  %74 = fadd fast float %73, %m0diffe14
  store float %74, float* %"'de15", align 4
  %75 = load float, float* %"'de15", align 4
  store float 0.000000e+00, float* %"'de15", align 4
  %76 = load float, float* %"'de16", align 4
  %77 = fadd fast float %76, %75
  store float %77, float* %"'de16", align 4
  %78 = load float, float* %"'de17", align 4
  %79 = fadd fast float %78, %75
  store float %79, float* %"'de17", align 4
  %80 = load float, float* %"'de17", align 4
  store float 0.000000e+00, float* %"'de17", align 4
  %81 = atomicrmw fadd float addrspace(11)* %"'ipg18", float %80 monotonic, align 4
  %82 = load float, float* %"'de16", align 4
  store float 0.000000e+00, float* %"'de16", align 4
  %83 = atomicrmw fadd float addrspace(11)* %"'ipg19", float %82 monotonic, align 4
  %84 = load float, float* %"'de20", align 4
  %85 = fdiv fast float %84, %24
  %86 = fmul fast float %25, %85
  %87 = fneg fast float %86
  store float 0.000000e+00, float* %"'de20", align 4
  %88 = load float, float* %"'de21", align 4
  %89 = fadd fast float %88, %87
  store float %89, float* %"'de21", align 4
  %90 = load float, float* %"'de21", align 4
  store float 0.000000e+00, float* %"'de21", align 4
  %91 = load float, float* %"'de22", align 4
  %92 = fadd fast float %91, %90
  store float %92, float* %"'de22", align 4
  %93 = load float, float* %"'de23", align 4
  %94 = fadd fast float %93, %90
  store float %94, float* %"'de23", align 4
  %95 = load float, float* %"'de23", align 4
  %m0diffe24 = fmul fast float %95, %19
  %m1diffe25 = fmul fast float %95, %19
  store float 0.000000e+00, float* %"'de23", align 4
  %96 = load float, float* %"'de26", align 4
  %97 = fadd fast float %96, %m0diffe24
  store float %97, float* %"'de26", align 4
  %98 = load float, float* %"'de26", align 4
  %99 = fadd fast float %98, %m1diffe25
  store float %99, float* %"'de26", align 4
  %100 = load float, float* %"'de22", align 4
  store float 0.000000e+00, float* %"'de22", align 4
  %101 = load float, float* %"'de27", align 4
  %102 = fadd fast float %101, %100
  store float %102, float* %"'de27", align 4
  %103 = load float, float* %"'de28", align 4
  %104 = fadd fast float %103, %100
  store float %104, float* %"'de28", align 4
  %105 = load float, float* %"'de28", align 4
  %m0diffe29 = fmul fast float %105, %14
  %m1diffe30 = fmul fast float %105, %14
  store float 0.000000e+00, float* %"'de28", align 4
  %106 = load float, float* %"'de31", align 4
  %107 = fadd fast float %106, %m0diffe29
  store float %107, float* %"'de31", align 4
  %108 = load float, float* %"'de31", align 4
  %109 = fadd fast float %108, %m1diffe30
  store float %109, float* %"'de31", align 4
  %110 = load float, float* %"'de27", align 4
  %m0diffe32 = fmul fast float %110, %9
  %m1diffe33 = fmul fast float %110, %9
  store float 0.000000e+00, float* %"'de27", align 4
  %111 = load float, float* %"'de34", align 4
  %112 = fadd fast float %111, %m0diffe32
  store float %112, float* %"'de34", align 4
  %113 = load float, float* %"'de34", align 4
  %114 = fadd fast float %113, %m1diffe33
  store float %114, float* %"'de34", align 4
  %115 = load float, float* %"'de26", align 4
  %116 = fneg fast float %115
  store float 0.000000e+00, float* %"'de26", align 4
  %117 = load float, float* %"'de35", align 4
  %118 = fadd fast float %117, %115
  store float %118, float* %"'de35", align 4
  %119 = load float, float* %"'de36", align 4
  %120 = fadd fast float %119, %116
  store float %120, float* %"'de36", align 4
  %121 = load float, float* %"'de36", align 4
  store float 0.000000e+00, float* %"'de36", align 4
  %122 = atomicrmw fadd float addrspace(11)* %"'ipg37", float %121 monotonic, align 4
  %123 = load float, float* %"'de35", align 4
  store float 0.000000e+00, float* %"'de35", align 4
  %124 = atomicrmw fadd float addrspace(11)* %"'ipg38", float %123 monotonic, align 4
  %125 = load float, float* %"'de31", align 4
  %126 = fneg fast float %125
  store float 0.000000e+00, float* %"'de31", align 4
  %127 = load float, float* %"'de39", align 4
  %128 = fadd fast float %127, %125
  store float %128, float* %"'de39", align 4
  %129 = load float, float* %"'de40", align 4
  %130 = fadd fast float %129, %126
  store float %130, float* %"'de40", align 4
  %131 = load float, float* %"'de40", align 4
  store float 0.000000e+00, float* %"'de40", align 4
  %132 = atomicrmw fadd float addrspace(11)* %"'ipg41", float %131 monotonic, align 4
  %133 = load float, float* %"'de39", align 4
  store float 0.000000e+00, float* %"'de39", align 4
  %134 = atomicrmw fadd float addrspace(11)* %"'ipg42", float %133 monotonic, align 4
  %135 = load float, float* %"'de34", align 4
  %136 = fneg fast float %135
  store float 0.000000e+00, float* %"'de34", align 4
  %137 = load float, float* %"'de43", align 4
  %138 = fadd fast float %137, %135
  store float %138, float* %"'de43", align 4
  %139 = load float, float* %"'de44", align 4
  %140 = fadd fast float %139, %136
  store float %140, float* %"'de44", align 4
  %141 = load float, float* %"'de44", align 4
  store float 0.000000e+00, float* %"'de44", align 4
  %142 = atomicrmw fadd float addrspace(11)* %"'ipg45", float %141 monotonic, align 4
  %143 = load float, float* %"'de43", align 4
  store float 0.000000e+00, float* %"'de43", align 4
  %144 = atomicrmw fadd float addrspace(11)* %"'ipg46", float %143 monotonic, align 4
  ret void

invert:                                           ; No predecessors!
  br label %inverttop

invert5:                                          ; No predecessors!
  %145 = load float, float* %"'de47", align 4
  store float 0.000000e+00, float* %"'de47", align 4

invert6:                                          ; No predecessors!

invert7:                                          ; No predecessors!

invert8:                                          ; No predecessors!

invert9:                                          ; No predecessors!

invert__nv_sqrtf.exit:                            ; preds = %__nv_sqrtf.exit
}

cannot handle (reverse) unknown intrinsic
llvm.nvvm.sqrt.rn.ftz.f
  %42 = call float @llvm.nvvm.sqrt.rn.ftz.f(float %37) #13, !dbg !104

Stacktrace:
 [1] #sqrt
   @ ~/.julia/packages/CUDA/DfvRa/src/device/intrinsics/math.jl:220
 [2] force
   @ ~/dms/molly_dev/enzyme_err.jl:12

Stacktrace:
  [1] julia_error(cstr::Cstring, val::Ptr{LLVM.API.LLVMOpaqueValue}, errtype::Enzyme.API.ErrorType, data::Ptr{Nothing})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/di3zM/src/compiler.jl:2636
  [2] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{Enzyme.API.CDIFFE_TYPE}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, width::Int64, additionalArg::Ptr{Nothing}, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{Bool}, augmented::Ptr{Nothing}, atomicAdd::Bool)
    @ Enzyme.API ~/.julia/packages/Enzyme/di3zM/src/api.jl:111
  [3] enzyme!(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams, GPUCompiler.FunctionSpec{typeof(kernel), Tuple{CuDeviceVector{SVector{3, Float32}, 1}, CuDeviceVector{Atom{Float32}, 1}}}}, mod::LLVM.Module, primalf::LLVM.Function, adjoint::GPUCompiler.FunctionSpec{typeof(kernel), Tuple{Duplicated{CuDeviceVector{SVector{3, Float32}, 1}}, Duplicated{CuDeviceVector{Atom{Float32}, 1}}}}, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, dupClosure::Bool, wrap::Bool, modifiedBetween::Bool, returnPrimal::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/di3zM/src/compiler.jl:3271
  [4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams, GPUCompiler.FunctionSpec{typeof(kernel), Tuple{CuDeviceVector{SVector{3, Float32}, 1}, CuDeviceVector{Atom{Float32}, 1}}}}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, ctx::LLVM.Context, strip::Bool, validate::Bool, only_entry::Bool, parent_job::GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams, GPUCompiler.FunctionSpec{typeof(grad_kernel), Tuple{CuDeviceVector{SVector{3, Float32}, 1}, CuDeviceVector{SVector{3, Float32}, 1}, CuDeviceVector{Atom{Float32}, 1}, CuDeviceVector{Atom{Float32}, 1}}}})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/di3zM/src/compiler.jl:4158
  [5] (::GPUCompiler.var"#114#117"{LLVM.Context, GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams, GPUCompiler.FunctionSpec{typeof(grad_kernel), Tuple{CuDeviceVector{SVector{3, Float32}, 1}, CuDeviceVector{SVector{3, Float32}, 1}, CuDeviceVector{Atom{Float32}, 1}, CuDeviceVector{Atom{Float32}, 1}}}}, GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams, GPUCompiler.FunctionSpec{typeof(kernel), Tuple{CuDeviceVector{SVector{3, Float32}, 1}, CuDeviceVector{Atom{Float32}, 1}}}}})()
    @ GPUCompiler ~/.julia/packages/GPUCompiler/jVY4I/src/driver.jl:296
  [6] get!(default::GPUCompiler.var"#114#117"{LLVM.Context, GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams, GPUCompiler.FunctionSpec{typeof(grad_kernel), Tuple{CuDeviceVector{SVector{3, Float32}, 1}, CuDeviceVector{SVector{3, Float32}, 1}, CuDeviceVector{Atom{Float32}, 1}, CuDeviceVector{Atom{Float32}, 1}}}}, GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams, GPUCompiler.FunctionSpec{typeof(kernel), Tuple{CuDeviceVector{SVector{3, Float32}, 1}, CuDeviceVector{Atom{Float32}, 1}}}}}, h::Dict{GPUCompiler.CompilerJob, String}, key::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams, GPUCompiler.FunctionSpec{typeof(kernel), Tuple{CuDeviceVector{SVector{3, Float32}, 1}, CuDeviceVector{Atom{Float32}, 1}}}})
    @ Base ./dict.jl:481
  [7] macro expansion
    @ ~/.julia/packages/GPUCompiler/jVY4I/src/driver.jl:295 [inlined]
  [8] emit_llvm(job::GPUCompiler.CompilerJob, method_instance::Any; libraries::Bool, deferred_codegen::Bool, optimize::Bool, cleanup::Bool, only_entry::Bool, ctx::LLVM.Context)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/jVY4I/src/utils.jl:64
  [9] cufunction_compile(job::GPUCompiler.CompilerJob, ctx::LLVM.Context)
    @ CUDA ~/.julia/packages/CUDA/DfvRa/src/compiler/execution.jl:353
 [10] #224
    @ ~/.julia/packages/CUDA/DfvRa/src/compiler/execution.jl:347 [inlined]
 [11] JuliaContext(f::CUDA.var"#224#225"{GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams, GPUCompiler.FunctionSpec{typeof(grad_kernel), Tuple{CuDeviceVector{SVector{3, Float32}, 1}, CuDeviceVector{SVector{3, Float32}, 1}, CuDeviceVector{Atom{Float32}, 1}, CuDeviceVector{Atom{Float32}, 1}}}}})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/jVY4I/src/driver.jl:76
 [12] cufunction_compile(job::GPUCompiler.CompilerJob)
    @ CUDA ~/.julia/packages/CUDA/DfvRa/src/compiler/execution.jl:346
 [13] cached_compilation(cache::Dict{UInt64, Any}, job::GPUCompiler.CompilerJob, compiler::typeof(CUDA.cufunction_compile), linker::typeof(CUDA.cufunction_link))
    @ GPUCompiler ~/.julia/packages/GPUCompiler/jVY4I/src/cache.jl:90
 [14] cufunction(f::typeof(grad_kernel), tt::Type{Tuple{CuDeviceVector{SVector{3, Float32}, 1}, CuDeviceVector{SVector{3, Float32}, 1}, CuDeviceVector{Atom{Float32}, 1}, CuDeviceVector{Atom{Float32}, 1}}}; name::Nothing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ CUDA ~/.julia/packages/CUDA/DfvRa/src/compiler/execution.jl:299
 [15] cufunction(f::typeof(grad_kernel), tt::Type{Tuple{CuDeviceVector{SVector{3, Float32}, 1}, CuDeviceVector{SVector{3, Float32}, 1}, CuDeviceVector{Atom{Float32}, 1}, CuDeviceVector{Atom{Float32}, 1}}})
    @ CUDA ~/.julia/packages/CUDA/DfvRa/src/compiler/execution.jl:292
 [16] top-level scope
    @ ~/.julia/packages/CUDA/DfvRa/src/compiler/execution.jl:102
wsmoses commented 2 years ago

Can you retry on the `main` branch? This looks like a bug that has already been fixed there but not yet released.

jgreener64 commented 2 years ago

You're right — the error no longer occurs on the `main` branch. Thanks for the quick response.