EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

Crash with LoopVectorization #98

Closed: oschulz closed this issue 2 years ago

oschulz commented 3 years ago

Enzyme (current master branch) currently crashes with LoopVectorization. The last line of this example

using Enzyme
using LoopVectorization

function mymul_simd!(R, A, B)
    @assert axes(A,2) == axes(B,1) && axes(R,1) == axes(A,1) && axes(R,2) == axes(B,2)
    @inbounds @simd for i in eachindex(R)
        R[i] = 0
    end
    @inbounds for j in axes(B, 2), i in axes(A, 1)
        @inbounds @simd for k in axes(A,2)
            R[i,j] += A[i,k] * B[k,j]
        end
    end
    nothing
end

function mymul_turbo!(R, A, B)
    @assert axes(A,2) == axes(B,1) && axes(R,1) == axes(A,1) && axes(R,2) == axes(B,2)
    @inbounds @turbo for i in eachindex(R)
        R[i] = 0
    end
    @inbounds @turbo for j in axes(B, 2), i in axes(A, 1), k in axes(A,2)
        R[i,j] += A[i,k] * B[k,j]
    end
    nothing
end

A = rand(500, 300)
B = rand(300, 700)
R = zeros(size(A,1), size(B,2))

@assert (fill!(R, NaN); mymul_simd!(R, A, B); R ≈ A * B)
@assert (fill!(R, NaN); mymul_turbo!(R, A, B); R ≈ A * B)

dA = similar(A); dB = similar(B); dR = similar(R)

fill!(R, NaN); fill!(dR, 1); fill!(dA, 0); fill!(dB, 0)
Enzyme.autodiff(mymul_simd!, Duplicated(R, dR), Duplicated(A, dA), Duplicated(B, dB))

fill!(R, NaN); fill!(dR, 1); fill!(dA, 0); fill!(dB, 0)
Enzyme.autodiff(mymul_turbo!, Duplicated(R, dR), Duplicated(A, dA), Duplicated(B, dB))

results in

  call:   %45 = call token (...) @llvm.julia.gc_preserve_begin({} addrspace(10)* nonnull %0), !dbg !77
 +   %53 = call nonnull align 8 {}* @julia.pointer_from_objref({} addrspace(11)* %52) #6, !dbg !240
 +   %3 = call {}*** @julia.ptls_states()
 +   %42 = call nonnull align 8 {}* @julia.pointer_from_objref({} addrspace(11)* %41) #6, !dbg !78
 +   call fastcc void @julia__turbo___6288(i64 signext %33, i64 signext %24, i64 signext %8, i64 zeroext %55, i64 zeroext %59, i64 zeroext %51, i64 signext %res.i5.i, i64 signext %res.i4.i, i64 signext %res.i6.i), !dbg !239
 +   %57 = call nonnull align 8 {}* @julia.pointer_from_objref({} addrspace(11)* %56) #6, !dbg !240
julia: /workspace/srcdir/Enzyme/enzyme/Enzyme/AdjointGenerator.h:3553: void AdjointGenerator<AugmentedReturnType>::visitCallInst(llvm::CallInst&) [with AugmentedReturnType = const AugmentedReturn*]: Assertion `uncacheable_args_map.find(&call) != uncacheable_args_map.end()' failed.

signal (6): Aborted
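To check a fixed build, it helps to know what the reverse pass should produce. With the seed `dR` filled with ones and `dA`, `dB` starting at zero, the shadows should accumulate the standard matmul adjoints, `dR * B'` and `A' * dR`. The sketch below is plain Julia with no Enzyme call; `reference_grads`, the variable names, and the small stand-in sizes (instead of 500x300 and 300x700) are assumptions made for illustration.

```julia
# Hypothetical helper: analytic reverse-mode gradients of R = A*B for a seed dR,
# i.e. what the shadows dA and dB should hold when they start at zero.
reference_grads(A, B, dR) = (dR * B', A' * dR)

# Small stand-ins for the matrices in the reproducer.
A_small = rand(5, 3); B_small = rand(3, 7)
seed = ones(size(A_small, 1), size(B_small, 2))   # like fill!(dR, 1)
dA_ref, dB_ref = reference_grads(A_small, B_small, seed)

# Sanity check one entry of dA with a central finite difference of sum(seed .* (A*B)).
loss(A, B) = sum(seed .* (A * B))
h = 1e-6
E = zeros(size(A_small)); E[2, 3] = h
fd = (loss(A_small + E, B_small) - loss(A_small - E, B_small)) / (2h)
```

Because the loss is linear in `A`, the central difference agrees with `dA_ref[2, 3]` up to rounding; the Enzyme shadows from a working build can be compared against `dA_ref` and `dB_ref` with `≈`.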
wsmoses commented 3 years ago

These PRs don't fully resolve it, but the immediate problem requires https://github.com/wsmoses/Enzyme/pull/247 to fix Julia fooling LLVM's intrinsic recognition, and then https://github.com/wsmoses/Enzyme/pull/249 (a poor man's solution for masked-load support; proper support is tracked as an issue here: https://github.com/wsmoses/Enzyme/issues/248).
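For context on what masked-load support has to handle: in the IR later in this thread, the loop remainder uses `@llvm.masked.load`/`@llvm.masked.store` with a mask of the low `len & 7` lanes (8 when that is zero), built via `bzhi`. The sketch below is a hypothetical scalar emulation in Julia; `tail_mask`, `masked_load`, and `masked_load_adjoint!` are names made up here, and this is not how Enzyme implements it. Masked-off lanes take the passthrough value and never touch memory, so in reverse mode only the active lanes accumulate into the shadow array, while the adjoint of the inactive lanes flows to the passthrough operand.

```julia
# Hypothetical scalar emulation of LLVM masked vector loads; not Enzyme's implementation.

# Mask for the final, partial vector of a loop over `len` elements, mirroring the IR:
# nactive = len & 7, with 0 meaning a full vector of 8 lanes.
function tail_mask(len::Int)
    r = len & 7
    nactive = r == 0 ? 8 : r
    return ntuple(l -> l <= nactive, Val(8))
end

# Primal: active lanes read x[offset + l]; inactive lanes return passthru and touch no memory.
function masked_load(x::AbstractVector{Float64}, offset::Int,
                     mask::NTuple{W,Bool}, passthru::NTuple{W,Float64}) where {W}
    return ntuple(l -> mask[l] ? x[offset + l] : passthru[l], Val(W))
end

# Reverse pass: active lanes accumulate into the shadow dx (again without touching
# masked-off memory); inactive lanes send their adjoint to the passthrough operand.
function masked_load_adjoint!(dx::AbstractVector{Float64}, offset::Int,
                              mask::NTuple{W,Bool}, dres::NTuple{W,Float64}) where {W}
    for l in 1:W
        if mask[l]
            dx[offset + l] += dres[l]
        end
    end
    return ntuple(l -> mask[l] ? 0.0 : dres[l], Val(W))
end
```

For example, with a 3-element `x`, `masked_load(x, 0, tail_mask(3), ntuple(_ -> 0.0, Val(8)))` reads only the first three lanes, which is what lets the remainder iteration run past the end of the array safely.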

This will also require completing the GC preserve begin/end support in this repo.

chriselrod commented 3 years ago

How well does it handle loops? In particular, I'm concerned about register use. LoopVectorization performs register tiling (it unrolls loops so that as many reused values as possible stay in registers). This optimization is key to performance on matmul and matmul-like loops.

My fear is that using just a few more registers in the loop could start causing spills, leading to a substantial drop in performance. However, I'm guessing this won't be an issue with reverse mode.
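To make the register-tiling concern concrete, here is a hand-written 2x2 register-tiled version of the matmul kernel from the reproducer. It is a hypothetical illustration, not code generated by LoopVectorization, and it assumes even dimensions so it can skip remainder handling. Each loaded `A` and `B` value is reused twice, and four accumulators stay live across the whole `k` loop; any extra values that a transformed loop keeps live in that same loop compete for the same registers.

```julia
# Hypothetical 2x2 register-tiled matmul (illustration only, not LoopVectorization output).
function mymul_tiled2x2!(R, A, B)
    M, K = size(A); N = size(B, 2)
    @assert size(B, 1) == K && size(R) == (M, N)
    @assert iseven(M) && iseven(N)   # assumption: no remainder handling
    @inbounds for j in 1:2:N, i in 1:2:M
        # Four accumulators that the compiler can keep in registers for the whole k loop.
        r11 = r21 = r12 = r22 = zero(eltype(R))
        for k in 1:K
            a1 = A[i, k]; a2 = A[i+1, k]   # each A value feeds two products
            b1 = B[k, j]; b2 = B[k, j+1]   # each B value feeds two products
            r11 = muladd(a1, b1, r11); r21 = muladd(a2, b1, r21)
            r12 = muladd(a1, b2, r12); r22 = muladd(a2, b2, r22)
        end
        R[i, j] = r11; R[i+1, j] = r21; R[i, j+1] = r12; R[i+1, j+1] = r22
    end
    return nothing
end
```

Larger tiles reuse each load more times but need more live accumulators, which is why the tile size is chosen to fit the register file.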

Could I add primals for the SIMD special functions? Note that LLVM automatically scalarizes @llvm.exp & co. when called with vector arguments, so that's a no-go from a performance perspective.

wsmoses commented 2 years ago

So what do you mean by adding primals for special SIMD functions? Mind giving a code snippet?

Otherwise, this issue is fully and properly resolved after a bump of the jll.

chriselrod commented 2 years ago

So what do you mean by adding primals for special SIMD functions? Mind giving a code snippet?

julia> using LoopVectorization

julia> function turbo_map!(f::F, y, x) where {F}
           @turbo for i in eachindex(y,x)
               y[i] = f(x[i])
           end
       end
turbo_map! (generic function with 1 method)

julia> code_llvm(turbo_map!, (typeof(exp), Vector{Float64}, Vector{Float64}), debuginfo=:none)

Produces this on a computer with AVX512:

define nonnull {}* @"japi1_turbo_map!_4126"({}* %0, {}** %1, i32 %2) #0 {
top:
  %3 = alloca [3 x {}*], align 8
  %gcframe200 = alloca [4 x {}*], align 16
  %gcframe200.sub = getelementptr inbounds [4 x {}*], [4 x {}*]* %gcframe200, i64 0, i64 0
  %.sub = getelementptr inbounds [3 x {}*], [3 x {}*]* %3, i64 0, i64 0
  %4 = bitcast [4 x {}*]* %gcframe200 to i8*
  call void @llvm.memset.p0i8.i32(i8* nonnull align 16 dereferenceable(32) %4, i8 0, i32 32, i1 false)
  %5 = alloca {}**, align 8
  store volatile {}** %1, {}*** %5, align 8
  %thread_ptr = call i8* asm "movq %fs:0, $0", "=r"() #10
  %ppgcstack_i8 = getelementptr i8, i8* %thread_ptr, i64 -8
  %ppgcstack = bitcast i8* %ppgcstack_i8 to {}****
  %pgcstack = load {}***, {}**** %ppgcstack, align 8
  %6 = bitcast [4 x {}*]* %gcframe200 to i64*
  store i64 8, i64* %6, align 16
  %7 = getelementptr inbounds [4 x {}*], [4 x {}*]* %gcframe200, i64 0, i64 1
  %8 = bitcast {}** %7 to {}***
  %9 = load {}**, {}*** %pgcstack, align 8
  store {}** %9, {}*** %8, align 8
  %10 = bitcast {}*** %pgcstack to {}***
  store {}** %gcframe200.sub, {}*** %10, align 8
  %11 = getelementptr inbounds {}*, {}** %1, i64 1
  %12 = load {}*, {}** %11, align 8
  %13 = getelementptr inbounds {}*, {}** %1, i64 2
  %14 = load {}*, {}** %13, align 8
  %15 = bitcast {}* %12 to {}**
  %16 = getelementptr inbounds {}*, {}** %15, i64 3
  %17 = bitcast {}** %16 to i64*
  %18 = load i64, i64* %17, align 8
  %19 = bitcast {}* %14 to {}**
  %20 = getelementptr inbounds {}*, {}** %19, i64 3
  %21 = bitcast {}** %20 to i64*
  %22 = load i64, i64* %21, align 8
  %.not = icmp eq i64 %18, %22
  br i1 %.not, label %L28, label %L11

L11:                                              ; preds = %top
  %ptls_field132 = getelementptr inbounds {}**, {}*** %pgcstack, i64 2305843009213693954
  %23 = bitcast {}*** %ptls_field132 to i8**
  %ptls_load133134 = load i8*, i8** %23, align 8
  %24 = call noalias nonnull {}* @ijl_gc_pool_alloc(i8* %ptls_load133134, i32 1392, i32 16) #2
  %25 = bitcast {}* %24 to i64*
  %26 = getelementptr inbounds i64, i64* %25, i64 -1
  store atomic i64 140310388431728, i64* %26 unordered, align 8
  store i64 %18, i64* %25, align 8
  %ptls_load4137138 = load i8*, i8** %23, align 8
  %27 = getelementptr inbounds [4 x {}*], [4 x {}*]* %gcframe200, i64 0, i64 3
  store {}* %24, {}** %27, align 8
  %28 = call noalias nonnull {}* @ijl_gc_pool_alloc(i8* %ptls_load4137138, i32 1392, i32 16) #2
  %29 = bitcast {}* %28 to i64*
  %30 = getelementptr inbounds i64, i64* %29, i64 -1
  store atomic i64 140310388431728, i64* %30 unordered, align 8
  store i64 %22, i64* %29, align 8
  %31 = getelementptr inbounds [4 x {}*], [4 x {}*]* %gcframe200, i64 0, i64 2
  store {}* %28, {}** %31, align 16
  store {}* inttoptr (i64 140310387260320 to {}*), {}** %.sub, align 8
  %32 = getelementptr inbounds [3 x {}*], [3 x {}*]* %3, i64 0, i64 1
  store {}* %24, {}** %32, align 8
  %33 = getelementptr inbounds [3 x {}*], [3 x {}*]* %3, i64 0, i64 2
  store {}* %28, {}** %33, align 8
  %34 = call nonnull {}* @ijl_invoke({}* inttoptr (i64 140310409716512 to {}*), {}** nonnull %.sub, i32 3, {}* inttoptr (i64 140310409714752 to {}*))
  call void @llvm.trap()
  unreachable

L28:                                              ; preds = %top
  %35 = bitcast {}* %12 to i8**
  %36 = load i8*, i8** %35, align 8
  %37 = bitcast {}* %14 to i8**
  %38 = load i8*, i8** %37, align 8
  %39 = icmp ne i64 %18, 0
  call void @llvm.assume(i1 %39)
  %40 = and i64 %18, 9223372036854775792
  %.not127139 = icmp eq i64 %40, 0
  br i1 %.not127139, label %L92, label %L39.preheader

L39.preheader:                                    ; preds = %L28
  %ptr.0.i79 = bitcast i8* %38 to double*
  %ptr.0.i87 = bitcast i8* %36 to double*
  br label %L39

L39:                                              ; preds = %L39, %L39.preheader
  %value_phi140 = phi i64 [ %res.i82, %L39 ], [ 0, %L39.preheader ]
  %ptr.1.i80 = getelementptr inbounds double, double* %ptr.0.i79, i64 %value_phi140
  %ptr.1.i125 = bitcast double* %ptr.1.i80 to <8 x double>*
  %res.i126 = load <8 x double>, <8 x double>* %ptr.1.i125, align 8
  %ptr.1.i122 = getelementptr inbounds double, double* %ptr.1.i80, i64 8
  %ptr.2.i123 = bitcast double* %ptr.1.i122 to <8 x double>*
  %res.i124 = load <8 x double>, <8 x double>* %ptr.2.i123, align 8
  %res.i121 = fmul nsz contract <8 x double> %res.i126, <double 0x40371547652B82FE, double 0x40371547652B82FE, double 0x40371547652B82FE, double 0x40371547652B82FE, double 0x40371547652B82FE, double 0x40371547652B82FE, double 0x40371547652B82FE, double 0x40371547652B82FE>
  %res.i120 = fmul nsz contract <8 x double> %res.i124, <double 0x40371547652B82FE, double 0x40371547652B82FE, double 0x40371547652B82FE, double 0x40371547652B82FE, double 0x40371547652B82FE, double 0x40371547652B82FE, double 0x40371547652B82FE, double 0x40371547652B82FE>
  %res.i119 = call fast <8 x double> @llvm.nearbyint.v8f64(<8 x double> %res.i121)
  %res.i118 = call fast <8 x double> @llvm.nearbyint.v8f64(<8 x double> %res.i120)
  %res.i117 = call nsz contract <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i119, <8 x double> <double 0xBFA62E42FEFA39EF, double 0xBFA62E42FEFA39EF, double 0xBFA62E42FEFA39EF, double 0xBFA62E42FEFA39EF, double 0xBFA62E42FEFA39EF, double 0xBFA62E42FEFA39EF, double 0xBFA62E42FEFA39EF, double 0xBFA62E42FEFA39EF>, <8 x double> %res.i126)
  %res.i116 = call nsz contract <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i118, <8 x double> <double 0xBFA62E42FEFA39EF, double 0xBFA62E42FEFA39EF, double 0xBFA62E42FEFA39EF, double 0xBFA62E42FEFA39EF, double 0xBFA62E42FEFA39EF, double 0xBFA62E42FEFA39EF, double 0xBFA62E42FEFA39EF, double 0xBFA62E42FEFA39EF>, <8 x double> %res.i124)
  %res.i115 = call nsz contract <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i119, <8 x double> <double 0xBC3ABC9E3B39803F, double 0xBC3ABC9E3B39803F, double 0xBC3ABC9E3B39803F, double 0xBC3ABC9E3B39803F, double 0xBC3ABC9E3B39803F, double 0xBC3ABC9E3B39803F, double 0xBC3ABC9E3B39803F, double 0xBC3ABC9E3B39803F>, <8 x double> %res.i117)
  %res.i114 = call nsz contract <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i118, <8 x double> <double 0xBC3ABC9E3B39803F, double 0xBC3ABC9E3B39803F, double 0xBC3ABC9E3B39803F, double 0xBC3ABC9E3B39803F, double 0xBC3ABC9E3B39803F, double 0xBC3ABC9E3B39803F, double 0xBC3ABC9E3B39803F, double 0xBC3ABC9E3B39803F>, <8 x double> %res.i116)
  %res.i113 = fptosi <8 x double> %res.i119 to <8 x i64>
  %res.i112 = fptosi <8 x double> %res.i118 to <8 x i64>
  %res.i111 = and <8 x i64> %res.i113, <i64 15, i64 15, i64 15, i64 15, i64 15, i64 15, i64 15, i64 15>
  %res.i110 = and <8 x i64> %res.i112, <i64 15, i64 15, i64 15, i64 15, i64 15, i64 15, i64 15, i64 15>
  %res.i109 = call reassoc nsz arcp contract afn <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i115, <8 x double> <double 0x3F56C1851427EDAD, double 0x3F56C1851427EDAD, double 0x3F56C1851427EDAD, double 0x3F56C1851427EDAD, double 0x3F56C1851427EDAD, double 0x3F56C1851427EDAD, double 0x3F56C1851427EDAD, double 0x3F56C1851427EDAD>, <8 x double> <double 0x3F811123CF1E04EF, double 0x3F811123CF1E04EF, double 0x3F811123CF1E04EF, double 0x3F811123CF1E04EF, double 0x3F811123CF1E04EF, double 0x3F811123CF1E04EF, double 0x3F811123CF1E04EF, double 0x3F811123CF1E04EF>)
  %res.i108 = call reassoc nsz arcp contract afn <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i114, <8 x double> <double 0x3F56C1851427EDAD, double 0x3F56C1851427EDAD, double 0x3F56C1851427EDAD, double 0x3F56C1851427EDAD, double 0x3F56C1851427EDAD, double 0x3F56C1851427EDAD, double 0x3F56C1851427EDAD, double 0x3F56C1851427EDAD>, <8 x double> <double 0x3F811123CF1E04EF, double 0x3F811123CF1E04EF, double 0x3F811123CF1E04EF, double 0x3F811123CF1E04EF, double 0x3F811123CF1E04EF, double 0x3F811123CF1E04EF, double 0x3F811123CF1E04EF, double 0x3F811123CF1E04EF>)
  %res.i107 = call reassoc nsz arcp contract afn <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i109, <8 x double> %res.i115, <8 x double> <double 0x3FA555555547135D, double 0x3FA555555547135D, double 0x3FA555555547135D, double 0x3FA555555547135D, double 0x3FA555555547135D, double 0x3FA555555547135D, double 0x3FA555555547135D, double 0x3FA555555547135D>)
  %res.i106 = call reassoc nsz arcp contract afn <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i108, <8 x double> %res.i114, <8 x double> <double 0x3FA555555547135D, double 0x3FA555555547135D, double 0x3FA555555547135D, double 0x3FA555555547135D, double 0x3FA555555547135D, double 0x3FA555555547135D, double 0x3FA555555547135D, double 0x3FA555555547135D>)
  %res.i105 = call reassoc nsz arcp contract afn <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i107, <8 x double> %res.i115, <8 x double> <double 0x3FC555555547D378, double 0x3FC555555547D378, double 0x3FC555555547D378, double 0x3FC555555547D378, double 0x3FC555555547D378, double 0x3FC555555547D378, double 0x3FC555555547D378, double 0x3FC555555547D378>)
  %res.i104 = call reassoc nsz arcp contract afn <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i106, <8 x double> %res.i114, <8 x double> <double 0x3FC555555547D378, double 0x3FC555555547D378, double 0x3FC555555547D378, double 0x3FC555555547D378, double 0x3FC555555547D378, double 0x3FC555555547D378, double 0x3FC555555547D378, double 0x3FC555555547D378>)
  %res.i103 = call reassoc nsz arcp contract afn <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i105, <8 x double> %res.i115, <8 x double> <double 0x3FE0000000000004, double 0x3FE0000000000004, double 0x3FE0000000000004, double 0x3FE0000000000004, double 0x3FE0000000000004, double 0x3FE0000000000004, double 0x3FE0000000000004, double 0x3FE0000000000004>)
  %res.i102 = call reassoc nsz arcp contract afn <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i104, <8 x double> %res.i114, <8 x double> <double 0x3FE0000000000004, double 0x3FE0000000000004, double 0x3FE0000000000004, double 0x3FE0000000000004, double 0x3FE0000000000004, double 0x3FE0000000000004, double 0x3FE0000000000004, double 0x3FE0000000000004>)
  %res.i101 = call reassoc nsz arcp contract afn <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i103, <8 x double> %res.i115, <8 x double> <double 0x3FF0000000000003, double 0x3FF0000000000003, double 0x3FF0000000000003, double 0x3FF0000000000003, double 0x3FF0000000000003, double 0x3FF0000000000003, double 0x3FF0000000000003, double 0x3FF0000000000003>)
  %res.i100 = call reassoc nsz arcp contract afn <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i102, <8 x double> %res.i114, <8 x double> <double 0x3FF0000000000003, double 0x3FF0000000000003, double 0x3FF0000000000003, double 0x3FF0000000000003, double 0x3FF0000000000003, double 0x3FF0000000000003, double 0x3FF0000000000003, double 0x3FF0000000000003>)
  %res.i99 = fmul nsz contract <8 x double> %res.i115, %res.i101
  %res.i98 = fmul nsz contract <8 x double> %res.i114, %res.i100
  %res.i97 = call <8 x double> @llvm.x86.avx512.vpermi2var.pd.512(<8 x double> <double 1.000000e+00, double 0x3FF0B5586CF9890F, double 0x3FF172B83C7D517B, double 0x3FF2387A6E756238, double 0x3FF306FE0A31B715, double 0x3FF3DEA64C123422, double 0x3FF4BFDAD5362A27, double 0x3FF5AB07DD485429>, <8 x i64> %res.i111, <8 x double> <double 0x3FF6A09E667F3BCD, double 0x3FF7A11473EB0187, double 0x3FF8ACE5422AA0DB, double 0x3FF9C49182A3F090, double 0x3FFAE89F995AD3AD, double 0x3FFC199BDD85529C, double 0x3FFD5818DCFBA487, double 0x3FFEA4AFA2A490DA>)
  %res.i96 = call <8 x double> @llvm.x86.avx512.vpermi2var.pd.512(<8 x double> <double 1.000000e+00, double 0x3FF0B5586CF9890F, double 0x3FF172B83C7D517B, double 0x3FF2387A6E756238, double 0x3FF306FE0A31B715, double 0x3FF3DEA64C123422, double 0x3FF4BFDAD5362A27, double 0x3FF5AB07DD485429>, <8 x i64> %res.i110, <8 x double> <double 0x3FF6A09E667F3BCD, double 0x3FF7A11473EB0187, double 0x3FF8ACE5422AA0DB, double 0x3FF9C49182A3F090, double 0x3FFAE89F995AD3AD, double 0x3FFC199BDD85529C, double 0x3FFD5818DCFBA487, double 0x3FFEA4AFA2A490DA>)
  %res.i95 = call nsz contract <8 x double> @llvm.fma.v8f64(<8 x double> %res.i97, <8 x double> %res.i99, <8 x double> %res.i97)
  %res.i94 = call nsz contract <8 x double> @llvm.fma.v8f64(<8 x double> %res.i96, <8 x double> %res.i98, <8 x double> %res.i96)
  %res.i93 = fmul nsz contract <8 x double> %res.i119, <double 6.250000e-02, double 6.250000e-02, double 6.250000e-02, double 6.250000e-02, double 6.250000e-02, double 6.250000e-02, double 6.250000e-02, double 6.250000e-02>
  %res.i92 = fmul nsz contract <8 x double> %res.i118, <double 6.250000e-02, double 6.250000e-02, double 6.250000e-02, double 6.250000e-02, double 6.250000e-02, double 6.250000e-02, double 6.250000e-02, double 6.250000e-02>
  %res.i91 = call <8 x double> @llvm.x86.avx512.mask.scalef.pd.512(<8 x double> %res.i95, <8 x double> %res.i93, <8 x double> undef, i8 -1, i32 8)
  %res.i90 = call <8 x double> @llvm.x86.avx512.mask.scalef.pd.512(<8 x double> %res.i94, <8 x double> %res.i92, <8 x double> undef, i8 -1, i32 8)
  %ptr.1.i88 = getelementptr inbounds double, double* %ptr.0.i87, i64 %value_phi140
  %ptr.1.i86 = bitcast double* %ptr.1.i88 to <8 x double>*
  store <8 x double> %res.i91, <8 x double>* %ptr.1.i86, align 8
  %ptr.1.i84 = getelementptr inbounds double, double* %ptr.1.i88, i64 8
  %ptr.2.i85 = bitcast double* %ptr.1.i84 to <8 x double>*
  store <8 x double> %res.i90, <8 x double>* %ptr.2.i85, align 8
  %res.i82 = add nuw nsw i64 %value_phi140, 16
  %.not127 = icmp eq i64 %res.i82, %40
  br i1 %.not127, label %L92, label %L39

L92:                                              ; preds = %L39, %L28
  %.not128 = icmp ult i64 %40, %18
  br i1 %.not128, label %L94, label %L192

L94:                                              ; preds = %L92
  %41 = trunc i64 %18 to i32
  %42 = and i32 %41, 7
  %.not129 = icmp eq i32 %42, 0
  %43 = select i1 %.not129, i32 8, i32 %42
  %res.i = call i32 @llvm.x86.bmi.bzhi.32(i32 -1, i32 %43)
  %44 = trunc i32 %res.i to i8
  %res.i78 = add nsw i64 %18, -9
  %.not130 = icmp slt i64 %res.i78, %40
  br i1 %.not130, label %L105, label %L139

L105:                                             ; preds = %L94
  %res.i77 = shl nsw i64 %40, 3
  %ptr.1.i73 = getelementptr inbounds i8, i8* %38, i64 %res.i77
  %ptr.2.i74 = bitcast i8* %ptr.1.i73 to <8 x double>*
  %mask.0.i75 = bitcast i8 %44 to <8 x i1>
  %res.i76 = call <8 x double> @llvm.masked.load.v8f64.p0v8f64(<8 x double>* nonnull %ptr.2.i74, i32 8, <8 x i1> %mask.0.i75, <8 x double> zeroinitializer)
  %res.i72 = fmul nsz contract <8 x double> %res.i76, <double 0x40371547652B82FE, double 0x40371547652B82FE, double 0x40371547652B82FE, double 0x40371547652B82FE, double 0x40371547652B82FE, double 0x40371547652B82FE, double 0x40371547652B82FE, double 0x40371547652B82FE>
  %res.i71 = call fast <8 x double> @llvm.nearbyint.v8f64(<8 x double> %res.i72)
  %res.i70 = call nsz contract <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i71, <8 x double> <double 0xBFA62E42FEFA39EF, double 0xBFA62E42FEFA39EF, double 0xBFA62E42FEFA39EF, double 0xBFA62E42FEFA39EF, double 0xBFA62E42FEFA39EF, double 0xBFA62E42FEFA39EF, double 0xBFA62E42FEFA39EF, double 0xBFA62E42FEFA39EF>, <8 x double> %res.i76)
  %res.i69 = call nsz contract <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i71, <8 x double> <double 0xBC3ABC9E3B39803F, double 0xBC3ABC9E3B39803F, double 0xBC3ABC9E3B39803F, double 0xBC3ABC9E3B39803F, double 0xBC3ABC9E3B39803F, double 0xBC3ABC9E3B39803F, double 0xBC3ABC9E3B39803F, double 0xBC3ABC9E3B39803F>, <8 x double> %res.i70)
  %res.i68 = fptosi <8 x double> %res.i71 to <8 x i64>
  %res.i67 = and <8 x i64> %res.i68, <i64 15, i64 15, i64 15, i64 15, i64 15, i64 15, i64 15, i64 15>
  %res.i66 = call reassoc nsz arcp contract afn <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i69, <8 x double> <double 0x3F56C1851427EDAD, double 0x3F56C1851427EDAD, double 0x3F56C1851427EDAD, double 0x3F56C1851427EDAD, double 0x3F56C1851427EDAD, double 0x3F56C1851427EDAD, double 0x3F56C1851427EDAD, double 0x3F56C1851427EDAD>, <8 x double> <double 0x3F811123CF1E04EF, double 0x3F811123CF1E04EF, double 0x3F811123CF1E04EF, double 0x3F811123CF1E04EF, double 0x3F811123CF1E04EF, double 0x3F811123CF1E04EF, double 0x3F811123CF1E04EF, double 0x3F811123CF1E04EF>)
  %res.i65 = call reassoc nsz arcp contract afn <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i66, <8 x double> %res.i69, <8 x double> <double 0x3FA555555547135D, double 0x3FA555555547135D, double 0x3FA555555547135D, double 0x3FA555555547135D, double 0x3FA555555547135D, double 0x3FA555555547135D, double 0x3FA555555547135D, double 0x3FA555555547135D>)
  %res.i64 = call reassoc nsz arcp contract afn <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i65, <8 x double> %res.i69, <8 x double> <double 0x3FC555555547D378, double 0x3FC555555547D378, double 0x3FC555555547D378, double 0x3FC555555547D378, double 0x3FC555555547D378, double 0x3FC555555547D378, double 0x3FC555555547D378, double 0x3FC555555547D378>)
  %res.i63 = call reassoc nsz arcp contract afn <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i64, <8 x double> %res.i69, <8 x double> <double 0x3FE0000000000004, double 0x3FE0000000000004, double 0x3FE0000000000004, double 0x3FE0000000000004, double 0x3FE0000000000004, double 0x3FE0000000000004, double 0x3FE0000000000004, double 0x3FE0000000000004>)
  %res.i62 = call reassoc nsz arcp contract afn <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i63, <8 x double> %res.i69, <8 x double> <double 0x3FF0000000000003, double 0x3FF0000000000003, double 0x3FF0000000000003, double 0x3FF0000000000003, double 0x3FF0000000000003, double 0x3FF0000000000003, double 0x3FF0000000000003, double 0x3FF0000000000003>)
  %res.i61 = fmul nsz contract <8 x double> %res.i69, %res.i62
  %res.i60 = call <8 x double> @llvm.x86.avx512.vpermi2var.pd.512(<8 x double> <double 1.000000e+00, double 0x3FF0B5586CF9890F, double 0x3FF172B83C7D517B, double 0x3FF2387A6E756238, double 0x3FF306FE0A31B715, double 0x3FF3DEA64C123422, double 0x3FF4BFDAD5362A27, double 0x3FF5AB07DD485429>, <8 x i64> %res.i67, <8 x double> <double 0x3FF6A09E667F3BCD, double 0x3FF7A11473EB0187, double 0x3FF8ACE5422AA0DB, double 0x3FF9C49182A3F090, double 0x3FFAE89F995AD3AD, double 0x3FFC199BDD85529C, double 0x3FFD5818DCFBA487, double 0x3FFEA4AFA2A490DA>)
  %res.i59 = call nsz contract <8 x double> @llvm.fma.v8f64(<8 x double> %res.i60, <8 x double> %res.i61, <8 x double> %res.i60)
  %res.i58 = fmul nsz contract <8 x double> %res.i71, <double 6.250000e-02, double 6.250000e-02, double 6.250000e-02, double 6.250000e-02, double 6.250000e-02, double 6.250000e-02, double 6.250000e-02, double 6.250000e-02>
  %res.i57 = call <8 x double> @llvm.x86.avx512.mask.scalef.pd.512(<8 x double> %res.i59, <8 x double> %res.i58, <8 x double> undef, i8 -1, i32 8)
  %ptr.1.i53 = getelementptr inbounds i8, i8* %36, i64 %res.i77
  %ptr.2.i54 = bitcast i8* %ptr.1.i53 to <8 x double>*
  call void @llvm.masked.store.v8f64.p0v8f64(<8 x double> %res.i57, <8 x double>* nonnull %ptr.2.i54, i32 8, <8 x i1> %mask.0.i75)
  br label %L192

L139:                                             ; preds = %L94
  %ptr.0.i49 = bitcast i8* %38 to double*
  %ptr.1.i50 = getelementptr inbounds double, double* %ptr.0.i49, i64 %40
  %ptr.1.i47 = bitcast double* %ptr.1.i50 to <8 x double>*
  %res.i48 = load <8 x double>, <8 x double>* %ptr.1.i47, align 8
  %ptr.1.i43 = getelementptr inbounds double, double* %ptr.1.i50, i64 8
  %ptr.2.i44 = bitcast double* %ptr.1.i43 to <8 x double>*
  %mask.0.i45 = bitcast i8 %44 to <8 x i1>
  %res.i46 = call <8 x double> @llvm.masked.load.v8f64.p0v8f64(<8 x double>* nonnull %ptr.2.i44, i32 8, <8 x i1> %mask.0.i45, <8 x double> zeroinitializer)
  %res.i41 = fmul nsz contract <8 x double> %res.i48, <double 0x40371547652B82FE, double 0x40371547652B82FE, double 0x40371547652B82FE, double 0x40371547652B82FE, double 0x40371547652B82FE, double 0x40371547652B82FE, double 0x40371547652B82FE, double 0x40371547652B82FE>
  %res.i40 = fmul nsz contract <8 x double> %res.i46, <double 0x40371547652B82FE, double 0x40371547652B82FE, double 0x40371547652B82FE, double 0x40371547652B82FE, double 0x40371547652B82FE, double 0x40371547652B82FE, double 0x40371547652B82FE, double 0x40371547652B82FE>
  %res.i39 = call fast <8 x double> @llvm.nearbyint.v8f64(<8 x double> %res.i41)
  %res.i38 = call fast <8 x double> @llvm.nearbyint.v8f64(<8 x double> %res.i40)
  %res.i37 = call nsz contract <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i39, <8 x double> <double 0xBFA62E42FEFA39EF, double 0xBFA62E42FEFA39EF, double 0xBFA62E42FEFA39EF, double 0xBFA62E42FEFA39EF, double 0xBFA62E42FEFA39EF, double 0xBFA62E42FEFA39EF, double 0xBFA62E42FEFA39EF, double 0xBFA62E42FEFA39EF>, <8 x double> %res.i48)
  %res.i36 = call nsz contract <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i38, <8 x double> <double 0xBFA62E42FEFA39EF, double 0xBFA62E42FEFA39EF, double 0xBFA62E42FEFA39EF, double 0xBFA62E42FEFA39EF, double 0xBFA62E42FEFA39EF, double 0xBFA62E42FEFA39EF, double 0xBFA62E42FEFA39EF, double 0xBFA62E42FEFA39EF>, <8 x double> %res.i46)
  %res.i35 = call nsz contract <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i39, <8 x double> <double 0xBC3ABC9E3B39803F, double 0xBC3ABC9E3B39803F, double 0xBC3ABC9E3B39803F, double 0xBC3ABC9E3B39803F, double 0xBC3ABC9E3B39803F, double 0xBC3ABC9E3B39803F, double 0xBC3ABC9E3B39803F, double 0xBC3ABC9E3B39803F>, <8 x double> %res.i37)
  %res.i34 = call nsz contract <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i38, <8 x double> <double 0xBC3ABC9E3B39803F, double 0xBC3ABC9E3B39803F, double 0xBC3ABC9E3B39803F, double 0xBC3ABC9E3B39803F, double 0xBC3ABC9E3B39803F, double 0xBC3ABC9E3B39803F, double 0xBC3ABC9E3B39803F, double 0xBC3ABC9E3B39803F>, <8 x double> %res.i36)
  %res.i33 = fptosi <8 x double> %res.i39 to <8 x i64>
  %res.i32 = fptosi <8 x double> %res.i38 to <8 x i64>
  %res.i31 = and <8 x i64> %res.i33, <i64 15, i64 15, i64 15, i64 15, i64 15, i64 15, i64 15, i64 15>
  %res.i30 = and <8 x i64> %res.i32, <i64 15, i64 15, i64 15, i64 15, i64 15, i64 15, i64 15, i64 15>
  %res.i29 = call reassoc nsz arcp contract afn <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i35, <8 x double> <double 0x3F56C1851427EDAD, double 0x3F56C1851427EDAD, double 0x3F56C1851427EDAD, double 0x3F56C1851427EDAD, double 0x3F56C1851427EDAD, double 0x3F56C1851427EDAD, double 0x3F56C1851427EDAD, double 0x3F56C1851427EDAD>, <8 x double> <double 0x3F811123CF1E04EF, double 0x3F811123CF1E04EF, double 0x3F811123CF1E04EF, double 0x3F811123CF1E04EF, double 0x3F811123CF1E04EF, double 0x3F811123CF1E04EF, double 0x3F811123CF1E04EF, double 0x3F811123CF1E04EF>)
  %res.i28 = call reassoc nsz arcp contract afn <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i34, <8 x double> <double 0x3F56C1851427EDAD, double 0x3F56C1851427EDAD, double 0x3F56C1851427EDAD, double 0x3F56C1851427EDAD, double 0x3F56C1851427EDAD, double 0x3F56C1851427EDAD, double 0x3F56C1851427EDAD, double 0x3F56C1851427EDAD>, <8 x double> <double 0x3F811123CF1E04EF, double 0x3F811123CF1E04EF, double 0x3F811123CF1E04EF, double 0x3F811123CF1E04EF, double 0x3F811123CF1E04EF, double 0x3F811123CF1E04EF, double 0x3F811123CF1E04EF, double 0x3F811123CF1E04EF>)
  %res.i27 = call reassoc nsz arcp contract afn <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i29, <8 x double> %res.i35, <8 x double> <double 0x3FA555555547135D, double 0x3FA555555547135D, double 0x3FA555555547135D, double 0x3FA555555547135D, double 0x3FA555555547135D, double 0x3FA555555547135D, double 0x3FA555555547135D, double 0x3FA555555547135D>)
  %res.i26 = call reassoc nsz arcp contract afn <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i28, <8 x double> %res.i34, <8 x double> <double 0x3FA555555547135D, double 0x3FA555555547135D, double 0x3FA555555547135D, double 0x3FA555555547135D, double 0x3FA555555547135D, double 0x3FA555555547135D, double 0x3FA555555547135D, double 0x3FA555555547135D>)
  %res.i25 = call reassoc nsz arcp contract afn <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i27, <8 x double> %res.i35, <8 x double> <double 0x3FC555555547D378, double 0x3FC555555547D378, double 0x3FC555555547D378, double 0x3FC555555547D378, double 0x3FC555555547D378, double 0x3FC555555547D378, double 0x3FC555555547D378, double 0x3FC555555547D378>)
  %res.i24 = call reassoc nsz arcp contract afn <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i26, <8 x double> %res.i34, <8 x double> <double 0x3FC555555547D378, double 0x3FC555555547D378, double 0x3FC555555547D378, double 0x3FC555555547D378, double 0x3FC555555547D378, double 0x3FC555555547D378, double 0x3FC555555547D378, double 0x3FC555555547D378>)
  %res.i23 = call reassoc nsz arcp contract afn <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i25, <8 x double> %res.i35, <8 x double> <double 0x3FE0000000000004, double 0x3FE0000000000004, double 0x3FE0000000000004, double 0x3FE0000000000004, double 0x3FE0000000000004, double 0x3FE0000000000004, double 0x3FE0000000000004, double 0x3FE0000000000004>)
  %res.i22 = call reassoc nsz arcp contract afn <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i24, <8 x double> %res.i34, <8 x double> <double 0x3FE0000000000004, double 0x3FE0000000000004, double 0x3FE0000000000004, double 0x3FE0000000000004, double 0x3FE0000000000004, double 0x3FE0000000000004, double 0x3FE0000000000004, double 0x3FE0000000000004>)
  %res.i21 = call reassoc nsz arcp contract afn <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i23, <8 x double> %res.i35, <8 x double> <double 0x3FF0000000000003, double 0x3FF0000000000003, double 0x3FF0000000000003, double 0x3FF0000000000003, double 0x3FF0000000000003, double 0x3FF0000000000003, double 0x3FF0000000000003, double 0x3FF0000000000003>)
  %res.i20 = call reassoc nsz arcp contract afn <8 x double> @llvm.fmuladd.v8f64(<8 x double> %res.i22, <8 x double> %res.i34, <8 x double> <double 0x3FF0000000000003, double 0x3FF0000000000003, double 0x3FF0000000000003, double 0x3FF0000000000003, double 0x3FF0000000000003, double 0x3FF0000000000003, double 0x3FF0000000000003, double 0x3FF0000000000003>)
  %res.i19 = fmul nsz contract <8 x double> %res.i35, %res.i21
  %res.i18 = fmul nsz contract <8 x double> %res.i34, %res.i20
  %res.i17 = call <8 x double> @llvm.x86.avx512.vpermi2var.pd.512(<8 x double> <double 1.000000e+00, double 0x3FF0B5586CF9890F, double 0x3FF172B83C7D517B, double 0x3FF2387A6E756238, double 0x3FF306FE0A31B715, double 0x3FF3DEA64C123422, double 0x3FF4BFDAD5362A27, double 0x3FF5AB07DD485429>, <8 x i64> %res.i31, <8 x double> <double 0x3FF6A09E667F3BCD, double 0x3FF7A11473EB0187, double 0x3FF8ACE5422AA0DB, double 0x3FF9C49182A3F090, double 0x3FFAE89F995AD3AD, double 0x3FFC199BDD85529C, double 0x3FFD5818DCFBA487, double 0x3FFEA4AFA2A490DA>)
  %res.i16 = call <8 x double> @llvm.x86.avx512.vpermi2var.pd.512(<8 x double> <double 1.000000e+00, double 0x3FF0B5586CF9890F, double 0x3FF172B83C7D517B, double 0x3FF2387A6E756238, double 0x3FF306FE0A31B715, double 0x3FF3DEA64C123422, double 0x3FF4BFDAD5362A27, double 0x3FF5AB07DD485429>, <8 x i64> %res.i30, <8 x double> <double 0x3FF6A09E667F3BCD, double 0x3FF7A11473EB0187, double 0x3FF8ACE5422AA0DB, double 0x3FF9C49182A3F090, double 0x3FFAE89F995AD3AD, double 0x3FFC199BDD85529C, double 0x3FFD5818DCFBA487, double 0x3FFEA4AFA2A490DA>)
  %res.i15 = call nsz contract <8 x double> @llvm.fma.v8f64(<8 x double> %res.i17, <8 x double> %res.i19, <8 x double> %res.i17)
  %res.i14 = call nsz contract <8 x double> @llvm.fma.v8f64(<8 x double> %res.i16, <8 x double> %res.i18, <8 x double> %res.i16)
  %res.i13 = fmul nsz contract <8 x double> %res.i39, <double 6.250000e-02, double 6.250000e-02, double 6.250000e-02, double 6.250000e-02, double 6.250000e-02, double 6.250000e-02, double 6.250000e-02, double 6.250000e-02>
  %res.i12 = fmul nsz contract <8 x double> %res.i38, <double 6.250000e-02, double 6.250000e-02, double 6.250000e-02, double 6.250000e-02, double 6.250000e-02, double 6.250000e-02, double 6.250000e-02, double 6.250000e-02>
  %res.i11 = call <8 x double> @llvm.x86.avx512.mask.scalef.pd.512(<8 x double> %res.i15, <8 x double> %res.i13, <8 x double> undef, i8 -1, i32 8)
  %res.i10 = call <8 x double> @llvm.x86.avx512.mask.scalef.pd.512(<8 x double> %res.i14, <8 x double> %res.i12, <8 x double> undef, i8 -1, i32 8)
  %ptr.0.i7 = bitcast i8* %36 to double*
  %ptr.1.i8 = getelementptr inbounds double, double* %ptr.0.i7, i64 %40
  %ptr.1.i6 = bitcast double* %ptr.1.i8 to <8 x double>*
  store <8 x double> %res.i11, <8 x double>* %ptr.1.i6, align 8
  %ptr.1.i = getelementptr inbounds double, double* %ptr.1.i8, i64 8
  %ptr.2.i = bitcast double* %ptr.1.i to <8 x double>*
  call void @llvm.masked.store.v8f64.p0v8f64(<8 x double> %res.i10, <8 x double>* nonnull %ptr.2.i, i32 8, <8 x i1> %mask.0.i45)
  br label %L192

L192:                                             ; preds = %L139, %L105, %L92
  %45 = load {}*, {}** %7, align 8
  %46 = bitcast {}*** %pgcstack to {}**
  store {}* %45, {}** %46, align 8
  ret {}* inttoptr (i64 140310641713160 to {}*)
}

which isn't really amenable to autodiff, even though differentiating exp itself is trivial.

The fact that this code computes exp is lost in the IR above, and since this is all LLVM sees, autodiff doesn't stand a chance.
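The constants in the IR above identify the algorithm: `0x40371547652B82FE` is 16/ln 2, `0xBFA62E42FEFA39EF` and `0xBC3ABC9E3B39803F` are a high/low split of -ln 2/16, the 16-entry `vpermi2var` table holds 2^(j/16), the polynomial coefficients approximate 1/720, 1/120, 1/24, 1/6, 1/2 and 1, and `scalef` multiplies by 2^floor(n/16). Below is a scalar Julia sketch of that reduction; `table_exp`, `table_exp_pushforward`, and the constant names are invented for illustration, and the sketch uses plain Taylor coefficients where the IR uses slightly tuned ones. None of these steps matter for the derivative, since d/dx exp(x) = exp(x).

```julia
# Hypothetical scalar version of the table-driven exp that the AVX-512 IR above vectorizes.
const INV_LN2_X16  = reinterpret(Float64, 0x40371547652B82FE)  # 16/ln(2)
const LN2_DIV16_HI = reinterpret(Float64, 0x3FA62E42FEFA39EF)  # ln(2)/16, high part (negated in the IR)
const LN2_DIV16_LO = reinterpret(Float64, 0x3C3ABC9E3B39803F)  # ln(2)/16, low part (negated in the IR)
const EXP2_TABLE = ntuple(j -> exp2((j - 1) / 16), 16)         # 2^(j/16) for j = 0..15

function table_exp(x::Float64)
    n = round(x * INV_LN2_X16)             # nearest multiple of ln(2)/16 (llvm.nearbyint)
    r = muladd(n, -LN2_DIV16_HI, x)        # two-step reduction: r = x - n*ln(2)/16
    r = muladd(n, -LN2_DIV16_LO, r)
    k = Int(n)
    t = EXP2_TABLE[(k & 15) + 1]           # 2^((k mod 16)/16), the vpermi2var lookup
    p = r * (1 + r * (1/2 + r * (1/6 + r * (1/24 + r * (1/120 + r / 720)))))  # ≈ expm1(r)
    return ldexp(muladd(t, p, t), k >> 4)  # scalef step: multiply by 2^floor(k/16)
end

# Once the function is known to be exp, its derivative simply reuses the primal value.
table_exp_pushforward(x, dx) = (y = table_exp(x); (y, y * dx))
```

All the opaque steps (rounding, table lookup, bit-level scaling) collapse into a one-line derivative rule, which is exactly the information the optimized IR no longer carries.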

At JuliaCon we talked about possibly being able to hook in somehow, and also about handling masks, which I haven't gotten around to yet.

BTW, I'm interested in your Polygeist work and the affine dialect.

LoopVectorization is doing a ton of things wrong, so I've been slowly working on rewriting it and figuring out how to address what I see as the major issues. However, IMO, it is doing less wrong than the other libraries I've looked at, in that it at least regularly delivers substantial performance improvements in many benchmarks. I don't have any real C++, LLVM, or MLIR background, but it may be appealing to focus on just the loop modeling and transformation part. I don't think it makes sense to do the loop modeling in a different part of the compilation pipeline from where vectorization takes place. So it was odd to see a comment about a bug Polygeist fixed where MLIR transforms blocked LLVM's vectorizer; you shouldn't be relying on LLVM's vectorizer.

vchuravy commented 2 years ago

Fixed in Enzyme 0.8.