Status: Open. g-bauer opened this issue 3 months ago.
Even setting codegen-units=1 doesn't fix this, so I'll have to look into how rustc compiles lib.rs+main.rs here. As for the performance, I'm deeply confused. To start, your function lowers to this IR, which looks reasonable:
; enzyme_playground::_f1
; Function Attrs: mustprogress nofree noinline norecurse nosync nounwind nonlazybind sanitize_hwaddress willreturn memory(none) uwtable
define noundef double @_ZN17enzyme_playground3_f117h5f4eabd9c1eee14aE(double noundef %x) unnamed_addr #1 {
start:
%0 = tail call double @llvm.exp.f64(double %x)
%1 = tail call double @llvm.sin.f64(double %x)
%2 = tail call double @llvm.powi.f64.i32(double %1, i32 3)
%3 = tail call double @llvm.cos.f64(double %x)
%4 = tail call double @llvm.powi.f64.i32(double %3, i32 3)
%_4.i = fadd double %2, %4
%5 = tail call double @llvm.sqrt.f64(double %_4.i)
%_0.i = fdiv double %0, %5
ret double %_0.i
}
I also remembered correctly that there is one more layer of indirection; see this:
; Function Attrs: noinline nonlazybind uwtable
define internal { double, double } @_ZN17enzyme_playground4dfff17haf2f79cc33487406E(double noundef %0, double noundef %1) unnamed_addr #2 {
%3 = call { double, double } @fwddiffe_f1(double %0, double %1)
ret { double, double } %3
}
But once you look inside the Enzyme-generated function, it calls one function per original instruction, i.e. we have 5 or 6 function calls on top of the extra indirection mentioned above. Of course, for such simple code that completely kills performance. Luckily, Enzyme has some debug flags, described here: https://enzyme.mit.edu/index.fcgi/rust/Debugging.html
So now we run RUSTFLAGS="-Z autodiff=PrintModAfterEnzyme,Inline" cargo +enzyme build --release &> mod.ll
and get something much more sensible:
; Function Attrs: mustprogress noinline nonlazybind willreturn uwtable
define internal { double, double } @fwddiffe_f2(double noundef "enzyme_type"="{[-1]:Float@double}" %0, double "enzyme_type"="{[-1]:Float@double}" %1) unnamed_addr #122 personality ptr @rust_eh_personality {
call void @llvm.experimental.noalias.scope.decl(metadata !45833) #127
%3 = call noundef double @llvm.exp.f64(double %0) #127
%4 = call fast double @llvm.exp.f64(double %0)
%5 = fmul fast double %1, %4
call void @llvm.experimental.noalias.scope.decl(metadata !45836) #127
%6 = call noundef double @llvm.sin.f64(double %0) #127
%7 = call fast double @llvm.cos.f64(double %0)
%8 = fmul fast double %1, %7
call void @llvm.experimental.noalias.scope.decl(metadata !45839) #127
%9 = call noundef double @llvm.powi.f64.i32(double %6, i32 3) #127
%10 = fcmp fast oeq double %8, 0.000000e+00
%11 = and i1 false, %10
%12 = or i1 false, %11
%13 = call fast double @llvm.powi.f64.i32(double %6, i32 2)
%14 = fmul fast double 3.000000e+00, %13
%15 = fmul fast double %8, %14
%16 = select fast i1 %12, double 0.000000e+00, double %15
call void @llvm.experimental.noalias.scope.decl(metadata !45842) #127
%17 = call noundef double @llvm.cos.f64(double %0) #127
%18 = call fast double @llvm.sin.f64(double %0)
%19 = fneg fast double %18
%20 = fmul fast double %1, %19
call void @llvm.experimental.noalias.scope.decl(metadata !45845) #127
%21 = call noundef double @llvm.powi.f64.i32(double %17, i32 3) #127
%22 = fcmp fast oeq double %20, 0.000000e+00
%23 = and i1 false, %22
%24 = or i1 false, %23
%25 = call fast double @llvm.powi.f64.i32(double %17, i32 2)
%26 = fmul fast double 3.000000e+00, %25
%27 = fmul fast double %20, %26
%28 = select fast i1 %24, double 0.000000e+00, double %27
%29 = fadd double %9, %21
%30 = fadd fast double %16, %28
call void @llvm.experimental.noalias.scope.decl(metadata !45848) #127
%31 = call noundef double @llvm.sqrt.f64(double %29) #127
%32 = fcmp fast ueq double %29, 0.000000e+00
%33 = call fast double @llvm.sqrt.f64(double %29) #128
%34 = fmul fast double 2.000000e+00, %33
%35 = fdiv fast double %30, %34
%36 = select fast i1 %32, double 0.000000e+00, double %35
%37 = fdiv double %3, %31
%38 = fmul fast double %5, %31
%39 = fmul fast double %36, %3
%40 = fsub fast double %38, %39
%41 = fmul fast double %31, %31
%42 = fdiv fast double %40, %41
%43 = insertvalue { double, double } undef, double %37, 0
%44 = insertvalue { double, double } %43, double %42, 1
ret { double, double } %44
}
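To make the IR above easier to read, here is a hand-written Rust mirror of what `fwddiffe_f2` computes: it returns the pair (f(x), f'(x) * dx) for f(x) = exp(x) / sqrt(sin(x)^3 + cos(x)^3). This is only a sketch for reading the IR, not Enzyme's output; the names `f1` and `fwddiff_f1` are hypothetical, and the comments map each step to the IR values it corresponds to. The one non-obvious part is the sqrt tangent, which the IR forces to 0 when its argument compares `ueq` to 0.0 (equal to zero or NaN).

```rust
/// f(x) = exp(x) / sqrt(sin(x)^3 + cos(x)^3), the function from the issue.
fn f1(x: f64) -> f64 {
    x.exp() / (x.sin().powi(3) + x.cos().powi(3)).sqrt()
}

/// Hypothetical hand-written mirror of the forward-mode IR above.
/// Returns (f(x), f'(x) * dx).
fn fwddiff_f1(x: f64, dx: f64) -> (f64, f64) {
    // Numerator u = exp(x), tangent du = dx * exp(x)            (%3, %5)
    let u = x.exp();
    let du = dx * x.exp();

    // sin(x)^3, tangent (dx * cos(x)) * 3 sin(x)^2              (%6..%16)
    // The IR's zero check on this tangent is dead (`and i1 false, ...`).
    let s = x.sin();
    let s3 = s.powi(3);
    let ds3 = dx * x.cos() * 3.0 * s.powi(2);

    // cos(x)^3, tangent (-dx * sin(x)) * 3 cos(x)^2             (%17..%28)
    let c = x.cos();
    let c3 = c.powi(3);
    let dc3 = -dx * x.sin() * 3.0 * c.powi(2);

    // w = sin^3 + cos^3 and its tangent                         (%29, %30)
    let w = s3 + c3;
    let dw = ds3 + dc3;

    // v = sqrt(w); tangent dw / (2 sqrt(w)), or 0 if w is 0 or NaN (%31..%36)
    let v = w.sqrt();
    let dv = if w == 0.0 || w.is_nan() { 0.0 } else { dw / (2.0 * v) };

    // Quotient rule: (du * v - dv * u) / v^2                     (%37..%42)
    (u / v, (du * v - dv * u) / (v * v))
}

fn main() {
    let x = 0.5;
    let (f, df) = fwddiff_f1(x, 1.0);
    // Central finite difference as an independent check of the tangent.
    let h = 1e-6;
    let fd = (f1(x + h) - f1(x - h)) / (2.0 * h);
    println!("f = {f}, AD tangent = {df}, finite difference = {fd}");
}
```

Seen this way, the derivative is just a handful of extra multiplications per operation, which is why any remaining call overhead shows up so clearly in a benchmark of this size.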
Benchmark times aren't affected, though. Mine were better for Enzyme from the beginning (415 ps instead of the 500 ps you report in your repo), but they are the same for me with and without the flag.
Enzyme: 1st order/forward
time: [415.55 ps 415.68 ps 415.86 ps]
num-dual: 1st order/Dual64
time: [308.20 ps 308.84 ps 309.84 ps]
Now that we know Enzyme's extra indirection is likely the issue, let's handicap num-dual with the same indirection and benchmark it:
+pub fn indirection<D: DualNum<f64>>(x: D) -> D {
+ f1(x)
+}
+
+#[inline(never)]
pub fn f1<D: DualNum<f64>>(x: D) -> D {
x.exp() / (x.sin().powi(3) + x.cos().powi(3)).sqrt()
}
Now we get:
num-dual: 1st order/Dual64
time: [408.21 ps 408.69 ps 409.33 ps]
change: [+30.526% +31.881% +33.197%] (p = 0.00 < 0.05)
Performance has regressed.
And indeed, it's now down to Enzyme's level. So, in summary, Enzyme currently always adds one more layer of indirection, because I didn't bother to clean up the LLVM-IR enough. I never noticed, because the overhead is easily hidden by any slightly more complex operation, but here we only have 5 simple operations, so it actually has an effect. I can't promise to fix it soon, since it likely won't be measurable beyond toy examples, but I'll leave the issue open as a reminder. It would also be an easy way to get started; it shouldn't be too hard for a new contributor.
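For comparison, the `Dual64` side of these benchmarks boils down to first-order dual numbers: every operation carries a (value, derivative) pair, so once `f1` is inlined the compiler sees essentially the same arithmetic as in the Enzyme IR. The sketch below is a minimal, self-contained dual number written for illustration; `Dual`, `Dual::var` and `f1_dual` are hypothetical names, and num-dual's real `Dual64` / `DualNum` implementation differs in detail.

```rust
use std::ops::{Add, Div};

/// Minimal first-order dual number re + eps * e with e^2 = 0.
/// Hypothetical stand-in for num-dual's Dual64, for illustration only.
#[derive(Clone, Copy, Debug)]
struct Dual {
    re: f64,
    eps: f64,
}

impl Dual {
    /// Seed an independent variable (derivative 1 with respect to itself).
    fn var(x: f64) -> Self {
        Dual { re: x, eps: 1.0 }
    }
    fn exp(self) -> Self {
        let e = self.re.exp();
        Dual { re: e, eps: self.eps * e }
    }
    fn sin(self) -> Self {
        Dual { re: self.re.sin(), eps: self.eps * self.re.cos() }
    }
    fn cos(self) -> Self {
        Dual { re: self.re.cos(), eps: -self.eps * self.re.sin() }
    }
    fn powi(self, n: i32) -> Self {
        Dual { re: self.re.powi(n), eps: self.eps * n as f64 * self.re.powi(n - 1) }
    }
    fn sqrt(self) -> Self {
        let s = self.re.sqrt();
        // No special case for re == 0 here: the derivative becomes eps / 0.
        Dual { re: s, eps: self.eps / (2.0 * s) }
    }
}

impl Add for Dual {
    type Output = Dual;
    fn add(self, rhs: Dual) -> Dual {
        Dual { re: self.re + rhs.re, eps: self.eps + rhs.eps }
    }
}

impl Div for Dual {
    type Output = Dual;
    // Quotient rule on the derivative part.
    fn div(self, rhs: Dual) -> Dual {
        Dual {
            re: self.re / rhs.re,
            eps: (self.eps * rhs.re - self.re * rhs.eps) / (rhs.re * rhs.re),
        }
    }
}

/// Same expression as `f1` in the diff above, written for the toy `Dual`.
fn f1_dual(x: Dual) -> Dual {
    x.exp() / (x.sin().powi(3) + x.cos().powi(3)).sqrt()
}

fn main() {
    let y = f1_dual(Dual::var(0.5));
    println!("f(0.5) = {}, f'(0.5) = {}", y.re, y.eps);
}
```

Wrapping `f1_dual` in an extra `#[inline(never)]` call, as in the diff above, adds the same kind of call overhead that Enzyme's wrapper adds.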
I pushed https://github.com/EnzymeAD/rust/commit/ab5490415bb28ef6cdd8045d116a9951aa171c8e to remove one layer of indirection, but interestingly enough it had no performance impact. I'll check whether I can also inline the call to the differentiated function; that might help. In the meantime, please feel free to post the LLVM-IR of your function (e.g. by passing rustc's --emit=llvm-ir through cargo). Maybe we can spot the difference that way?
@g-bauer Possibly this is the reason why we're slower. There is a whole discussion on correctness in AD here: https://github.com/EnzymeAD/Enzyme/issues/1295. Do you have special handling for sqrt(0) in your tool? I tend not to adjust the default behaviour, even if it has a small performance overhead, since for non-toy examples the performance benefits of LLVM-based AD should easily outweigh it. Do you have any larger benchmarks we could compare on?
We don't have special treatment (see here). Taking the derivative of sqrt(0.0) will return NaN. But I don't think that's the issue here: changing to a different operation (pow, ln, ...) doesn't change the benchmark results on my machine.
I'll add a longer example.
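The two sqrt(0) conventions discussed above can be compared directly. The sketch below contrasts the plain chain-rule tangent du / (2 sqrt(u)) with a guarded variant modeled on the `fcmp fast ueq ..., 0.0` plus `select` sequence in the Enzyme IR earlier in this thread, which returns 0 when u is zero (or NaN). Both function names are hypothetical; this is not code from either tool. Depending on the incoming tangent, the unguarded formula yields inf or NaN at u = 0, and the guard is the extra compare and select that the Enzyme version pays for.

```rust
/// Tangent of sqrt(u) by the plain chain rule: du / (2 sqrt(u)).
/// At u == 0 this divides by zero (inf if du != 0, NaN if du == 0).
fn dsqrt_plain(u: f64, du: f64) -> f64 {
    du / (2.0 * u.sqrt())
}

/// Tangent of sqrt(u) with a guard modeled on the Enzyme IR above
/// (`fcmp ueq u, 0.0` + `select`): returns 0 when u is zero or NaN.
fn dsqrt_guarded(u: f64, du: f64) -> f64 {
    if u == 0.0 || u.is_nan() {
        0.0
    } else {
        du / (2.0 * u.sqrt())
    }
}

fn main() {
    for (u, du) in [(0.0, 1.0), (0.0, 0.0), (4.0, 1.0)] {
        println!(
            "u = {u}, du = {du}: plain = {}, guarded = {}",
            dsqrt_plain(u, du),
            dsqrt_guarded(u, du)
        );
    }
}
```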
When importing a function (with the autodiff macro) from a module, the derivatives are missing. I created a minimal example here.

Meta

rustc --version --verbose: