EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

Slow Broadcasting (compared to Zygote) #1434

Open avik-pal opened 1 month ago

avik-pal commented 1 month ago
using BenchmarkTools, NNlib, Enzyme, Zygote

gelu_act(x) = sum(abs2, gelu.(x))

x = randn(Float32, 32, 32, 1, 32)

@btime Enzyme.gradient(Reverse, $gelu_act, $x); # 1.298 ms (65 allocations: 386.95 KiB)

@btime Zygote.gradient($gelu_act, $x); # 735.499 μs (26 allocations: 384.75 KiB)

This comparison might be somewhat unfair, because NNlib defines an `rrule` for `gelu` that Zygote can use.
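For reference, here is a sketch of the kind of closed-form derivative a hand-written `rrule` for `gelu` can supply directly, instead of having the AD tool differentiate through `tanh` and the cubic term step by step. It assumes the tanh approximation of GELU with the usual constants (check them against the installed NNlib version); `gelu_tanh`, `gelu_tanh_deriv` and `gelu_act_grad` are hypothetical names, not NNlib's API.

```julia
# GELU, tanh approximation (assumed to match NNlib's `gelu`):
#   gelu(x) = x/2 * (1 + tanh(λ * (x + α * x^3))),  λ = √(2/π), α = 0.044715
const GELU_λ = sqrt(2 / π)
const GELU_α = 0.044715

gelu_tanh(x::Real) = x / 2 * (1 + tanh(GELU_λ * (x + GELU_α * x^3)))

# Closed-form derivative: d/dx [x/2 * (1 + tanh(u))]
#   = (1 + tanh(u)) / 2 + x/2 * (1 - tanh(u)^2) * du/dx
function gelu_tanh_deriv(x::Real)
    u = GELU_λ * (x + GELU_α * x^3)       # argument of tanh
    t = tanh(u)
    du = GELU_λ * (1 + 3 * GELU_α * x^2)  # du/dx
    return (1 + t) / 2 + x / 2 * (1 - t^2) * du
end

# Gradient of sum(abs2, gelu.(x)) is 2 .* gelu.(x) .* gelu'.(x): one fused broadcast.
gelu_act_grad(x::AbstractArray) = 2 .* gelu_tanh.(x) .* gelu_tanh_deriv.(x)

# Central finite difference, used only to sanity-check the formula.
central_diff(f, x; h = 1e-6) = (f(x + h) - f(x - h)) / (2h)
```

Checking `gelu_tanh_deriv(x)` against `central_diff(gelu_tanh, x)` at a few points in Float64 is a quick way to confirm the formula before trusting it in a custom rule.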

wsmoses commented 1 month ago


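using BenchmarkTools, Enzyme, Zygote

# Reduced reproducer: `gelu` replaced by `sin`, so NNlib is no longer needed
gelu_act(x) = sum(abs2, sin.(x))

x = randn(Float32, 32, 32, 1, 32);

Enzyme.gradient(Reverse, gelu_act, x);  # run once first (compiles the gradient)

@btime Enzyme.gradient(Reverse, $gelu_act, $x);

@btime Zygote.gradient($gelu_act, $x)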
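Because `sin` is elementwise, the gradient of this reduced reproducer has a closed form, `2 .* sin.(x) .* cos.(x)`. A hand-written single-pass loop like the sketch below gives a rough lower bound on what either AD tool could achieve for this shape, and a reference value to check the AD results against. `sin_act_grad!` and `sin_act_grad` are hypothetical names, not from the thread.

```julia
# Gradient of f(x) = sum(abs2, sin.(x)):  ∂f/∂x_i = 2 sin(x_i) cos(x_i)
function sin_act_grad!(g::AbstractArray{T}, x::AbstractArray{T}) where {T<:AbstractFloat}
    axes(g) == axes(x) || throw(DimensionMismatch("g and x must have the same axes"))
    # Iterate linearly: Array is IndexLinear, so no per-dimension index math is needed.
    @inbounds for i in eachindex(g, x)
        s, c = sincos(x[i])
        g[i] = 2 * s * c
    end
    return g
end

# Allocating convenience wrapper.
sin_act_grad(x) = sin_act_grad!(similar(x), x)
```

With a preallocated `g = similar(x)`, timing `@btime sin_act_grad!($g, $x)` measures the loop alone; its result should agree with the gradients from either tool up to floating-point rounding.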
wsmoses commented 1 month ago

So the code generated here by the broadcast, before any Enzyme AD happens, is actually quite awful. I wrote some primitive optimization passes to do a bit of cleanup (which may remove the need for runtime activity), but the indexing pattern is still really bad.

@vchuravy, any ideas what's happening here (besides the array presumably now having more than 3 dimensions, so Julia doesn't specialize)?
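One cheap way to test the dimensionality hypothesis (an assumption, not something tried in this thread) is to broadcast over a flattened view of the same data. `vec(x)` shares memory with `x`, and since `sin` is elementwise the result is unchanged, but the broadcast then runs over a single dimension, so the four-way index arithmetic should largely disappear (not verified here). `gelu_act_flat` is a hypothetical name.

```julia
using Random

# Reproducer from the thread: broadcast over the 4-D array.
gelu_act(x) = sum(abs2, sin.(x))

# Hypothetical variant: same elements, but `vec(x)` is a 1-D array sharing x's
# memory, so the broadcast loop has a single index instead of four.
gelu_act_flat(x) = sum(abs2, sin.(vec(x)))

x = randn(Float32, 32, 32, 1, 32)
println(gelu_act(x), " ≈ ", gelu_act_flat(x))
```

Benchmarking `Enzyme.gradient(Reverse, gelu_act_flat, x)` with the same `@btime` setup as above would show how much of the gap is due to the 4-D indexing.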

After simplification:
; Function Attrs: mustprogress willreturn
define "enzyme_type"="{[-1]:Float@float}" float @preprocess_julia_gelu_act_1468({} addrspace(10)* nocapture noundef nonnull readonly align 16 dereferenceable(40) "enzyme_type"="{[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer}" "enzymejl_parmtype"="131622087272080" "enzymejl_parmtype_ref"="2" %0) local_unnamed_addr #21 !dbg !1226 {
top:
  %newstruct14 = alloca [4 x [1 x i64]], align 8
  %newstruct133 = alloca [4 x [1 x i64]], align 8
  %1 = call {}*** @julia.get_pgcstack() #22
  %current_task1161 = getelementptr inbounds {}**, {}*** %1, i64 -14
  %current_task1 = bitcast {}*** %current_task1161 to {}**
  %ptls_field162 = getelementptr inbounds {}**, {}*** %1, i64 2
  %2 = bitcast {}*** %ptls_field162 to i64***
  %ptls_load163164 = load i64**, i64*** %2, align 8, !tbaa !19
  %3 = getelementptr inbounds i64*, i64** %ptls_load163164, i64 2
  %safepoint = load i64*, i64** %3, align 8, !tbaa !23
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint) #22, !dbg !1227
  fence syncscope("singlethread") seq_cst
  %4 = addrspacecast {} addrspace(10)* %0 to {} addrspace(10)* addrspace(11)*, !dbg !1228
  %arraysize_ptr = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %4, i64 3, !dbg !1228
  %5 = bitcast {} addrspace(10)* addrspace(11)* %arraysize_ptr to i64 addrspace(11)*, !dbg !1228
  %arraysize = load i64, i64 addrspace(11)* %5, align 8, !dbg !1228, !tbaa !23, !range !44, !alias.scope !45, !noalias !48
  %arraysize_ptr2 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %4, i64 4, !dbg !1228
  %6 = bitcast {} addrspace(10)* addrspace(11)* %arraysize_ptr2 to i64 addrspace(11)*, !dbg !1228
  %arraysize3 = load i64, i64 addrspace(11)* %6, align 16, !dbg !1228, !tbaa !23, !range !44, !alias.scope !45, !noalias !48
  %arraysize_ptr4 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %4, i64 5, !dbg !1228
  %7 = bitcast {} addrspace(10)* addrspace(11)* %arraysize_ptr4 to i64 addrspace(11)*, !dbg !1228
  %arraysize5 = load i64, i64 addrspace(11)* %7, align 8, !dbg !1228, !tbaa !23, !range !44, !alias.scope !45, !noalias !48
  %arraysize_ptr6 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %4, i64 6, !dbg !1228
  %8 = bitcast {} addrspace(10)* addrspace(11)* %arraysize_ptr6 to i64 addrspace(11)*, !dbg !1228
  %arraysize7 = load i64, i64 addrspace(11)* %8, align 16, !dbg !1228, !tbaa !23, !range !44, !alias.scope !45, !noalias !48
  %9 = getelementptr inbounds [4 x [1 x i64]], [4 x [1 x i64]]* %newstruct14, i64 0, i64 1, i64 0, !dbg !1237
  store i64 %arraysize3, i64* %9, align 8, !dbg !1237, !tbaa !149, !alias.scope !151, !noalias !1242
  %10 = getelementptr inbounds [4 x [1 x i64]], [4 x [1 x i64]]* %newstruct14, i64 0, i64 2, i64 0, !dbg !1237
  store i64 %arraysize5, i64* %10, align 8, !dbg !1237, !tbaa !149, !alias.scope !151, !noalias !1242
  %11 = getelementptr inbounds [4 x [1 x i64]], [4 x [1 x i64]]* %newstruct14, i64 0, i64 3, i64 0, !dbg !1237
  store i64 %arraysize7, i64* %11, align 8, !dbg !1237, !tbaa !149, !alias.scope !151, !noalias !1242
  %memcpy_refined_dst = getelementptr inbounds [4 x [1 x i64]], [4 x [1 x i64]]* %newstruct14, i64 0, i64 0, i64 0, !dbg !1241
  store i64 %arraysize, i64* %memcpy_refined_dst, align 8, !dbg !1241, !tbaa !149, !alias.scope !151, !noalias !1242
  %box = call noalias nonnull dereferenceable(32) "enzyme_inactive" {} addrspace(10)* @julia.gc_alloc_obj({}** nonnull %current_task1, i64 noundef 32, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 131621933207328 to {}*) to {} addrspace(10)*)) #23, !dbg !1245
  %12 = bitcast {} addrspace(10)* %box to i8 addrspace(10)*, !dbg !1245
  %newstruct15.sroa.0.0..sroa_cast = bitcast {} addrspace(10)* %box to i64 addrspace(10)*, !dbg !1245
  store i64 %arraysize, i64 addrspace(10)* %newstruct15.sroa.0.0..sroa_cast, align 8, !dbg !1245, !tbaa !165, !alias.scope !166, !noalias !1252
  %newstruct15.sroa.2.0..sroa_idx = getelementptr inbounds i8, i8 addrspace(10)* %12, i64 8, !dbg !1245
  %newstruct15.sroa.2.0..sroa_cast = bitcast i8 addrspace(10)* %newstruct15.sroa.2.0..sroa_idx to i64 addrspace(10)*, !dbg !1245
  store i64 %arraysize3, i64 addrspace(10)* %newstruct15.sroa.2.0..sroa_cast, align 8, !dbg !1245, !tbaa !165, !alias.scope !166, !noalias !1252
  %newstruct15.sroa.3.0..sroa_idx = getelementptr inbounds i8, i8 addrspace(10)* %12, i64 16, !dbg !1245
  %newstruct15.sroa.3.0..sroa_cast = bitcast i8 addrspace(10)* %newstruct15.sroa.3.0..sroa_idx to i64 addrspace(10)*, !dbg !1245
  store i64 %arraysize5, i64 addrspace(10)* %newstruct15.sroa.3.0..sroa_cast, align 8, !dbg !1245, !tbaa !165, !alias.scope !166, !noalias !1252
  %newstruct15.sroa.4.0..sroa_idx = getelementptr inbounds i8, i8 addrspace(10)* %12, i64 24, !dbg !1245
  %newstruct15.sroa.4.0..sroa_cast = bitcast i8 addrspace(10)* %newstruct15.sroa.4.0..sroa_idx to i64 addrspace(10)*, !dbg !1245
  store i64 %arraysize7, i64 addrspace(10)* %newstruct15.sroa.4.0..sroa_cast, align 8, !dbg !1245, !tbaa !165, !alias.scope !166, !noalias !1252
  %13 = call noalias nonnull {} addrspace(10)* @ijl_new_array({} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 131622087272080 to {}*) to {} addrspace(10)*), {} addrspace(10)* noundef nonnull %box) #24, !dbg !1245
  %14 = addrspacecast {} addrspace(10)* %13 to {} addrspace(10)* addrspace(11)*, !dbg !1253
  %arraysize_ptr17 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %14, i64 3, !dbg !1253
  %15 = bitcast {} addrspace(10)* addrspace(11)* %arraysize_ptr17 to i64 addrspace(11)*, !dbg !1253
  %arraysize18 = load i64, i64 addrspace(11)* %15, align 8, !dbg !1253, !tbaa !23, !range !44, !alias.scope !45, !noalias !48
  %arraysize_ptr19 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %14, i64 4, !dbg !1253
  %16 = bitcast {} addrspace(10)* addrspace(11)* %arraysize_ptr19 to i64 addrspace(11)*, !dbg !1253
  %arraysize20 = load i64, i64 addrspace(11)* %16, align 8, !dbg !1253, !tbaa !23, !range !44, !alias.scope !45, !noalias !48
  %arraysize_ptr21 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %14, i64 5, !dbg !1253
  %17 = bitcast {} addrspace(10)* addrspace(11)* %arraysize_ptr21 to i64 addrspace(11)*, !dbg !1253
  %arraysize22 = load i64, i64 addrspace(11)* %17, align 8, !dbg !1253, !tbaa !23, !range !44, !alias.scope !45, !noalias !48
  %arraysize_ptr23 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %14, i64 6, !dbg !1253
  %18 = bitcast {} addrspace(10)* addrspace(11)* %arraysize_ptr23 to i64 addrspace(11)*, !dbg !1253
  %arraysize24 = load i64, i64 addrspace(11)* %18, align 8, !dbg !1253, !tbaa !23, !range !44, !alias.scope !45, !noalias !48
  %.not = icmp ne i64 %arraysize18, %arraysize, !dbg !1261
  %.not171 = icmp ne i64 %arraysize20, %arraysize3
  %or.cond = select i1 %.not, i1 true, i1 %.not171, !dbg !1265
  %.not172 = icmp ne i64 %arraysize22, %arraysize5
  %or.cond173 = select i1 %or.cond, i1 true, i1 %.not172, !dbg !1265
  %19 = icmp ne i64 %arraysize24, %arraysize7
  %or.cond176 = select i1 %or.cond173, i1 true, i1 %19, !dbg !1265
  br i1 %or.cond176, label %L250, label %L58, !dbg !1265

L58:                                              ; preds = %top
  %20 = icmp eq i64 %arraysize7, 1, !dbg !1266
  %21 = icmp eq i64 %arraysize5, 1, !dbg !1279
  %22 = icmp eq i64 %arraysize3, 1, !dbg !1282
  %23 = icmp eq i64 %arraysize, 1, !dbg !1285
  %24 = icmp ne i64 %arraysize3, 0, !dbg !1288
  %25 = icmp ne i64 %arraysize5, 0, !dbg !1294
  %26 = icmp ne i64 %arraysize7, 0, !dbg !1294
  %27 = and i1 %24, %25, !dbg !1297
  %28 = and i1 %27, %26, !dbg !1297
  br i1 %28, label %L126.preheader, label %L272, !dbg !1292

L126.preheader:                                   ; preds = %L58
  %.not166 = icmp eq i64 %arraysize, 0
  %29 = addrspacecast {} addrspace(10)* %0 to float addrspace(13)* addrspace(11)*
  %30 = addrspacecast {} addrspace(10)* %13 to float addrspace(13)* addrspace(11)*
  br label %L126.outer, !dbg !1299

L126.outer:                                       ; preds = %L226, %L126.preheader
  %iv = phi i64 [ %iv.next, %L226 ], [ 0, %L126.preheader ]
  %iv.next = add nuw nsw i64 %iv, 1
  %value_phi56.op = add nsw i64 %iv.next, -1
  %31 = select i1 %20, i64 0, i64 %value_phi56.op
  %32 = mul i64 %31, %arraysize5
  %arrayptr168 = load float addrspace(13)*, float addrspace(13)* addrspace(11)* %29, align 16
  %33 = mul i64 %value_phi56.op, %arraysize5
  br label %L126, !dbg !1299

L126:                                             ; preds = %.thread, %L126.outer
  %iv1 = phi i64 [ %iv.next2, %.thread ], [ 0, %L126.outer ]
  %value_phi54 = phi i64 [ %value_phi92.ph, %.thread ], [ 1, %L126.outer ]
  %value_phi55 = phi i64 [ %value_phi93.ph, %.thread ], [ 1, %L126.outer ]
  %iv.next2 = add nuw nsw i64 %iv1, 1, !dbg !1299
  br i1 %.not166, label %L192, label %L141.lr.ph, !dbg !1299

L141.lr.ph:                                       ; preds = %L126
  %value_phi54.op = add nsw i64 %value_phi54, -1
  %34 = select i1 %22, i64 0, i64 %value_phi54.op
  %value_phi55.op = add i64 %value_phi55, -1
  %35 = select i1 %21, i64 0, i64 %value_phi55.op
  %reass.add = add i64 %35, %32
  %reass.mul = mul i64 %reass.add, %arraysize3
  %reass.add201 = add i64 %reass.mul, %34
  %reass.mul202 = mul i64 %reass.add201, %arraysize
  %reass.add199 = add i64 %value_phi55.op, %33
  %reass.mul200 = mul i64 %reass.add199, %arraysize3
  %reass.add203 = add i64 %reass.mul200, %value_phi54.op
  %reass.mul204 = mul i64 %reass.add203, %arraysize
  br label %L141, !dbg !1300

L141:                                             ; preds = %L141, %L141.lr.ph
  %iv3 = phi i64 [ %iv.next4, %L141 ], [ 0, %L141.lr.ph ]
  %iv.next4 = add nuw nsw i64 %iv3, 1, !dbg !1301
  %36 = select i1 %23, i64 0, i64 %iv3, !dbg !1304
  %37 = add i64 %36, %reass.mul202, !dbg !1304
  %38 = getelementptr inbounds float, float addrspace(13)* %arrayptr168, i64 %37, !dbg !1304
  %arrayref = load float, float addrspace(13)* %38, align 4, !dbg !1304, !tbaa !77, !alias.scope !80, !noalias !81
  %39 = call float @julia_sin_1489(float %arrayref) #22, !dbg !1312
  %40 = add i64 %iv3, %reass.mul204, !dbg !1314
  %arrayptr88169 = load float addrspace(13)*, float addrspace(13)* addrspace(11)* %30, align 8, !dbg !1314, !tbaa !23, !alias.scope !1316, !noalias !48, !nonnull !18
  %41 = getelementptr inbounds float, float addrspace(13)* %arrayptr88169, i64 %40, !dbg !1314
  store float %39, float addrspace(13)* %41, align 4, !dbg !1314, !tbaa !77, !alias.scope !80, !noalias !1317
  %exitcond.not = icmp eq i64 %iv.next4, %arraysize, !dbg !1318
  br i1 %exitcond.not, label %L192.loopexit, label %L141, !dbg !1300, !llvm.loop !1319

L192.loopexit:                                    ; preds = %L141
  br label %L192, !dbg !1320

L192:                                             ; preds = %L192.loopexit, %L126
  %42 = add i64 %value_phi54, 1, !dbg !1320
  %43 = icmp ugt i64 %value_phi54, 9223372036854775806, !dbg !1324
  %44 = icmp sgt i64 %42, %arraysize3, !dbg !1324
  %45 = or i1 %43, %44, !dbg !1327
  %46 = icmp eq i64 %value_phi54, %arraysize3
  %or.cond174 = or i1 %46, %45, !dbg !1327
  br i1 %or.cond174, label %L201, label %.thread, !dbg !1327

L201:                                             ; preds = %L192
  %47 = add i64 %value_phi55, 1, !dbg !1328
  %48 = icmp ugt i64 %value_phi55, 9223372036854775806, !dbg !1331
  %49 = icmp sgt i64 %47, %arraysize5, !dbg !1331
  %50 = or i1 %48, %49, !dbg !1334
  %51 = icmp eq i64 %value_phi55, %arraysize5
  %or.cond175 = or i1 %51, %50, !dbg !1334
  br i1 %or.cond175, label %L226, label %.thread, !dbg !1334

.thread:                                          ; preds = %L201, %L192
  %value_phi92.ph = phi i64 [ 1, %L201 ], [ %42, %L192 ]
  %value_phi93.ph = phi i64 [ %47, %L201 ], [ %value_phi55, %L192 ]
  br label %L126, !dbg !1323

L226:                                             ; preds = %L201
  %52 = add nuw nsw i64 %iv.next, 1, !dbg !1335
  %exitcond207.not = icmp eq i64 %iv.next, %arraysize7, !dbg !1338
  br i1 %exitcond207.not, label %L272.loopexit, label %L126.outer, !dbg !1323

L250:                                             ; preds = %top
  %53 = getelementptr inbounds [4 x [1 x i64]], [4 x [1 x i64]]* %newstruct133, i64 0, i64 0, i64 0, !dbg !1339
  store i64 %arraysize18, i64* %53, align 8, !dbg !1339, !tbaa !149, !alias.scope !151, !noalias !1242
  %54 = getelementptr inbounds [4 x [1 x i64]], [4 x [1 x i64]]* %newstruct133, i64 0, i64 1, i64 0, !dbg !1343
  store i64 %arraysize20, i64* %54, align 8, !dbg !1343, !tbaa !149, !alias.scope !151, !noalias !1242
  %55 = getelementptr inbounds [4 x [1 x i64]], [4 x [1 x i64]]* %newstruct133, i64 0, i64 2, i64 0, !dbg !1343
  store i64 %arraysize22, i64* %55, align 8, !dbg !1343, !tbaa !149, !alias.scope !151, !noalias !1242
  %56 = getelementptr inbounds [4 x [1 x i64]], [4 x [1 x i64]]* %newstruct133, i64 0, i64 3, i64 0, !dbg !1343
  store i64 %arraysize24, i64* %56, align 8, !dbg !1343, !tbaa !149, !alias.scope !151, !noalias !1242
  %57 = addrspacecast [4 x [1 x i64]]* %newstruct133 to [4 x [1 x i64]] addrspace(11)*, !dbg !1259
  %58 = addrspacecast [4 x [1 x i64]]* %newstruct14 to [4 x [1 x i64]] addrspace(11)*, !dbg !1259
  call fastcc void @julia_throwdm_1475([4 x [1 x i64]] addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(32) %57, [4 x [1 x i64]] addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(32) %58) #25, !dbg !1259
  unreachable, !dbg !1259

L272.loopexit:                                    ; preds = %L226
  br label %L272, !dbg !1347

L272:                                             ; preds = %L272.loopexit, %L58
  %59 = call fastcc float @julia__mapreduce_1477({} addrspace(10)* noalias nocapture nofree noundef nonnull readonly align 16 dereferenceable(40) %13) #22, !dbg !1347
  ret float %59, !dbg !1347
}
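The per-load index arithmetic in the IR above can be reproduced in plain Julia. For each dimension it computes `size == 1 ? 0 : i - 1` (the `select` instructions in blocks `L126.outer`, `L141.lr.ph` and `L141`) and then folds the four indices into a linear offset by multiplying with the dimension sizes (`%reass.add`/`%reass.mul`). This looks like broadcast's handling of singleton ("extruded") dimensions, emitted at runtime because the sizes are not known statically. `extruded_offset` is a hypothetical helper, not a Base function.

```julia
# Zero-based linear offset into an array of size `sz` for the (one-based) output
# index `I`, mimicking the IR: a dimension of size 1 is "extruded", i.e. always
# read at offset 0, whatever the output index along that dimension is.
function extruded_offset(sz::NTuple{N,Int}, I::NTuple{N,Int}) where {N}
    off = 0
    # Horner-style fold from the last dimension down, as in %reass.add/%reass.mul.
    for d in N:-1:1
        i0 = sz[d] == 1 ? 0 : I[d] - 1   # the `select i1 (size == 1), 0, i - 1`
        off = off * sz[d] + i0
    end
    return off
end
```

When no dimension has size 1 this is just `LinearIndices(sz)[I...] - 1`, which is why the same loop over a flat `Vector` needs none of this per-dimension arithmetic.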
wsmoses commented 1 month ago

The IR above is from after my fix, btw.

Timings after the fix:


julia> @btime Enzyme.gradient(Reverse, $gelu_act, $x);
  927.718 μs (8 allocations: 384.28 KiB)

julia> @btime Zygote.gradient($gelu_act, $x)
  386.687 μs (38 allocations: 641.19 KiB)

The bigger issue right now, imo, is that the loop bounds aren't statically inferable because of that awful index math. As a result, the inner loops cache their true iteration counts inside the outer loops, which means a lot of unnecessary caching. That's still not fixed.
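One reading of the IR (blocks `L126`, `L192`, `L201` and `.thread`) is that dimensions 2 and 3 are not two nested loops but a single loop that advances a pair of counters with an explicit carry, in the style of `CartesianIndices` iteration; only the innermost loop over dimension 1 and the outermost loop over dimension 4 have simple trip counts. The sketch below imitates that control flow and checks that it visits the same index pairs as ordinary nested loops. Its total trip count is `s2 * s3`, but that product never appears as an explicit loop bound, which is consistent with the iteration count having to be cached. `carry_visits` and `nested_visits` are hypothetical helpers.

```julia
# Imitates the IR's middle loop: one loop whose header carries (i2, i3) and
# advances them with an explicit carry, instead of `for i3 in 1:s3, i2 in 1:s2`.
# (The IR's overflow checks are omitted.)
function carry_visits(s2::Int, s3::Int)
    visits = Tuple{Int,Int}[]
    (s2 == 0 || s3 == 0) && return visits  # mirrors the zero-size guard in block L58
    i2, i3 = 1, 1
    while true
        push!(visits, (i2, i3))            # body: the inner dim-1 loop (block L141)
        if i2 != s2                        # block L192: bump i2 if not at its end
            i2 += 1
        elseif i3 != s3                    # block L201: carry into i3, reset i2
            i2, i3 = 1, i3 + 1
        else                               # both exhausted: back to the dim-4 loop
            break
        end
    end
    return visits
end

# Reference order: plain nested loops, dimension 2 fastest (column-major).
nested_visits(s2::Int, s3::Int) = [(i2, i3) for i3 in 1:s3 for i2 in 1:s2]
```

Both produce the same sequence; the difference is only that the nested form exposes `s2` and `s3` directly as loop bounds.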

The pre-fix timings below are slower, so the minor fix does help performance somewhat, but again the index math is likely the root cause. At least others won't have to deal with runtime activity anymore.

julia> Enzyme.gradient(Reverse, gelu_act, x);

julia> @btime Enzyme.gradient(Reverse, $gelu_act, $x);
  967.698 μs (8 allocations: 384.28 KiB)

julia> @btime Zygote.gradient($gelu_act, $x)
  377.638 μs (38 allocations: 641.19 KiB)
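Putting the reported numbers side by side: the fix saves Enzyme about 4% on this reproducer, while Enzyme remains roughly 2.4 to 2.6 times slower than Zygote. The arithmetic below uses only the `@btime` minimum times quoted in this thread. Note that Zygote's own time moved by about 2% between the two runs, so the 4% should be read with some caution.

```julia
# Minimum times reported by @btime in this thread (μs), sin reproducer.
enzyme_pre,  zygote_pre  = 967.698, 377.638   # before the fix
enzyme_post, zygote_post = 927.718, 386.687   # after the fix

slowdown_pre  = enzyme_pre / zygote_pre       # ≈ 2.56
slowdown_post = enzyme_post / zygote_post     # ≈ 2.40
speedup_fix   = 1 - enzyme_post / enzyme_pre  # ≈ 0.041, i.e. about 4%

println("Enzyme/Zygote before fix: ", round(slowdown_pre; digits = 2))
println("Enzyme/Zygote after fix:  ", round(slowdown_post; digits = 2))
println("Enzyme time saved by fix: ", round(100 * speedup_fix; digits = 1), "%")
```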
wsmoses commented 1 month ago

Module before optimization: preopt.ll.txt