EnzymeAD / rust

A rust fork to work towards Enzyme integration
https://www.rust-lang.org
Other

Derivatives of function imported from module not working properly #173

Open g-bauer opened 3 months ago

g-bauer commented 3 months ago

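When I import a function (annotated with the autodiff macro) from a module, the derivative is missing: it comes back as 0.0, while the identical function defined in main.rs gives a non-zero value. I created a minimal example here.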

# identical function as below, but defined in lib.rs
[src/main.rs:34:5] enzyme_y1_lib = (
    4.497780053946161,
    0.0, # <---
)
# identical function as above, but defined in main.rs
[src/main.rs:35:5] enzyme_y1f = (
    4.497780053946161,
    4.05342789389862,
)
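A quick way to catch a silently dropped derivative like the 0.0 above is to compare the reported tangent against a central finite difference. The sketch below assumes the function under test is f(x) = exp(x) / sqrt(sin(x)^3 + cos(x)^3), the f1 benchmarked later in this thread, written here for plain f64. The helper name `check_tangent`, the hand-derived `df1`, the step size and the input x = 0.5 are all illustrative assumptions; the input used in the reproduction is not shown.

```rust
// Plain-f64 version of the benchmark function:
// f(x) = exp(x) / sqrt(sin(x)^3 + cos(x)^3)
fn f1(x: f64) -> f64 {
    x.exp() / (x.sin().powi(3) + x.cos().powi(3)).sqrt()
}

// Hand-derived derivative: f'(x) = f(x) * (1 - g'(x) / (2 g(x))),
// with g(x) = sin^3 + cos^3 and g'(x) = 3 sin^2 cos - 3 cos^2 sin.
fn df1(x: f64) -> f64 {
    let (s, c) = (x.sin(), x.cos());
    let g = s.powi(3) + c.powi(3);
    let dg = 3.0 * s * s * c - 3.0 * c * c * s;
    f1(x) * (1.0 - dg / (2.0 * g))
}

// Hypothetical helper: does `tangent` agree with a central difference
// of `f` at `x`, up to a relative tolerance `tol`?
fn check_tangent(f: impl Fn(f64) -> f64, x: f64, tangent: f64, tol: f64) -> bool {
    let h = 1e-6;
    let fd = (f(x + h) - f(x - h)) / (2.0 * h);
    (fd - tangent).abs() <= tol * fd.abs().max(1.0)
}

fn main() {
    let x = 0.5; // arbitrary point where sin^3 + cos^3 > 0
    // A correct tangent passes the check ...
    println!("analytic tangent ok: {}", check_tangent(f1, x, df1(x), 1e-6));
    // ... while a dropped derivative (0.0, as in the lib.rs case) is flagged.
    println!("zero tangent ok:     {}", check_tangent(f1, x, 0.0, 1e-6));
}
```

The same check can be pointed at the tuple returned by the autodiff-generated function, which makes a regression like the lib.rs case fail loudly instead of printing 0.0.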

Meta

rustc --version --verbose:

rustc 1.82.0-nightly (86dedf7dc 2024-08-16)
binary: rustc
commit-hash: 86dedf7dc5b63661998a038c726033ad92c2d40e
commit-date: 2024-08-16
host: x86_64-unknown-linux-gnu
release: 1.82.0-nightly
LLVM version: 19.1.0
ZuseZ4 commented 3 months ago

Even setting codegen-units=1 doesn't fix this, so I'll have to look into how rustc compiles lib.rs + main.rs here. As for the performance, I'm deeply confused. To start, your function lowers to this IR, which looks reasonable:

; enzyme_playground::_f1
; Function Attrs: mustprogress nofree noinline norecurse nosync nounwind nonlazybind sanitize_hwaddress willreturn memory(none) uwtable
define noundef double @_ZN17enzyme_playground3_f117h5f4eabd9c1eee14aE(double noundef %x) unnamed_addr #1 {
start:
  %0 = tail call double @llvm.exp.f64(double %x)
  %1 = tail call double @llvm.sin.f64(double %x)
  %2 = tail call double @llvm.powi.f64.i32(double %1, i32 3)
  %3 = tail call double @llvm.cos.f64(double %x)
  %4 = tail call double @llvm.powi.f64.i32(double %3, i32 3)
  %_4.i = fadd double %2, %4
  %5 = tail call double @llvm.sqrt.f64(double %_4.i)
  %_0.i = fdiv double %0, %5
  ret double %_0.i
}

I also remembered correctly that there is one more layer of indirection; see this:

; Function Attrs: noinline nonlazybind uwtable
define internal { double, double } @_ZN17enzyme_playground4dfff17haf2f79cc33487406E(double noundef %0, double noundef %1) unnamed_addr #2 {
  %3 = call { double, double } @fwddiffe_f1(double %0, double %1)
  ret { double, double } %3
}

However, once you look inside the Enzyme-generated function, it calls one function per original instruction, i.e. 5 or 6 function calls on top of the extra indirection mentioned above. For such simple code, that will completely kill performance. Luckily, Enzyme has some debug flags (described at https://enzyme.mit.edu/index.fcgi/rust/Debugging.html), so now we run

RUSTFLAGS="-Z autodiff=PrintModAfterEnzyme,Inline" cargo +enzyme build --release &> mod.ll

and get something much more sensible:

; Function Attrs: mustprogress noinline nonlazybind willreturn uwtable
define internal { double, double } @fwddiffe_f2(double noundef "enzyme_type"="{[-1]:Float@double}" %0, double "enzyme_type"="{[-1]:Float@double}" %1) unnamed_addr #122 personality ptr @rust_eh_personality {
  call void @llvm.experimental.noalias.scope.decl(metadata !45833) #127
  %3 = call noundef double @llvm.exp.f64(double %0) #127
  %4 = call fast double @llvm.exp.f64(double %0)
  %5 = fmul fast double %1, %4
  call void @llvm.experimental.noalias.scope.decl(metadata !45836) #127
  %6 = call noundef double @llvm.sin.f64(double %0) #127
  %7 = call fast double @llvm.cos.f64(double %0)
  %8 = fmul fast double %1, %7
  call void @llvm.experimental.noalias.scope.decl(metadata !45839) #127
  %9 = call noundef double @llvm.powi.f64.i32(double %6, i32 3) #127
  %10 = fcmp fast oeq double %8, 0.000000e+00
  %11 = and i1 false, %10
  %12 = or i1 false, %11
  %13 = call fast double @llvm.powi.f64.i32(double %6, i32 2)
  %14 = fmul fast double 3.000000e+00, %13
  %15 = fmul fast double %8, %14
  %16 = select fast i1 %12, double 0.000000e+00, double %15
  call void @llvm.experimental.noalias.scope.decl(metadata !45842) #127
  %17 = call noundef double @llvm.cos.f64(double %0) #127
  %18 = call fast double @llvm.sin.f64(double %0)
  %19 = fneg fast double %18
  %20 = fmul fast double %1, %19
  call void @llvm.experimental.noalias.scope.decl(metadata !45845) #127
  %21 = call noundef double @llvm.powi.f64.i32(double %17, i32 3) #127
  %22 = fcmp fast oeq double %20, 0.000000e+00
  %23 = and i1 false, %22
  %24 = or i1 false, %23
  %25 = call fast double @llvm.powi.f64.i32(double %17, i32 2)
  %26 = fmul fast double 3.000000e+00, %25
  %27 = fmul fast double %20, %26
  %28 = select fast i1 %24, double 0.000000e+00, double %27
  %29 = fadd double %9, %21
  %30 = fadd fast double %16, %28
  call void @llvm.experimental.noalias.scope.decl(metadata !45848) #127
  %31 = call noundef double @llvm.sqrt.f64(double %29) #127
  %32 = fcmp fast ueq double %29, 0.000000e+00
  %33 = call fast double @llvm.sqrt.f64(double %29) #128
  %34 = fmul fast double 2.000000e+00, %33
  %35 = fdiv fast double %30, %34
  %36 = select fast i1 %32, double 0.000000e+00, double %35
  %37 = fdiv double %3, %31
  %38 = fmul fast double %5, %31
  %39 = fmul fast double %36, %3
  %40 = fsub fast double %38, %39
  %41 = fmul fast double %31, %31
  %42 = fdiv fast double %40, %41
  %43 = insertvalue { double, double } undef, double %37, 0
  %44 = insertvalue { double, double } %43, double %42, 1
  ret { double, double } %44
}
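To make the IR above easier to read, here is a hand transcription into plain Rust of what fwddiffe_f2 computes (it has the same structure as f1). This is an illustrative sketch, not code Enzyme emits: the name `f1_fwd` is hypothetical, the fast-math flags are dropped, and the powi guards (`and i1 false, ...`) are left out because they always select the computed tangent. The sqrt guard (`fcmp ueq` plus `select`) is kept.

```rust
/// Forward-mode tangent of f1, transcribed by hand from the fwddiffe IR.
/// Returns (f(x), f'(x) * dx).
fn f1_fwd(x: f64, dx: f64) -> (f64, f64) {
    // exp(x) and its tangent (%3, %5)
    let e = x.exp();
    let de = dx * x.exp();
    // sin(x)^3 (%6, %8, %9, %16); the IR's zero-tangent guard is always false
    let s = x.sin();
    let ds = dx * x.cos();
    let s3 = s.powi(3);
    let ds3 = ds * 3.0 * s.powi(2);
    // cos(x)^3 (%17, %20, %21, %28)
    let c = x.cos();
    let dc = dx * -x.sin();
    let c3 = c.powi(3);
    let dc3 = dc * 3.0 * c.powi(2);
    // sum and sqrt (%29 to %36): the tangent is forced to 0.0
    // when the argument compares unordered-or-equal to zero
    let g = s3 + c3;
    let dg = ds3 + dc3;
    let r = g.sqrt();
    let dr = if g == 0.0 || g.is_nan() { 0.0 } else { dg / (2.0 * g.sqrt()) };
    // quotient rule (%37 to %42): (de * r - dr * e) / r^2
    let y = e / r;
    let dy = (de * r - dr * e) / (r * r);
    (y, dy)
}

fn main() {
    let (y, dy) = f1_fwd(0.5, 1.0);
    println!("f1(0.5) = {y}, f1'(0.5) = {dy}");
}
```

Seen this way, each original instruction contributes a primal value and a tangent inline, which is why the inlined module looks so much more reasonable than one call per instruction.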

Benchmark times aren't affected, though. My Enzyme timings were better from the beginning (415 ps instead of the 500 you report in your repo), but they are the same for me with and without the flag.

Enzyme: 1st order/forward
                        time:   [415.55 ps 415.68 ps 415.86 ps]
num-dual: 1st order/Dual64
                        time:   [308.20 ps 308.84 ps 309.84 ps]

Now that we know Enzyme's extra indirection is likely the issue, let's handicap num-dual with the same indirection and benchmark it:

+pub fn indirection<D: DualNum<f64>>(x: D) -> D {
+    f1(x)
+}
+
+#[inline(never)]
 pub fn f1<D: DualNum<f64>>(x: D) -> D {
     x.exp() / (x.sin().powi(3) + x.cos().powi(3)).sqrt()
 }
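For readers without num-dual or the Enzyme toolchain, the same experiment can be approximated with the standard library alone. This is a rough sketch under several assumptions: `Dual` is a minimal hand-rolled first-order dual number standing in for num-dual's Dual64 (not its actual implementation), `f1_inlinable` is a hypothetical baseline the optimizer may inline, and the timing loop is a crude `Instant`-based measurement rather than Criterion.

```rust
use std::hint::black_box;
use std::ops::{Add, Div};
use std::time::Instant;

/// Minimal first-order dual number: `re` is the value, `eps` the tangent.
#[derive(Clone, Copy, Debug)]
struct Dual {
    re: f64,
    eps: f64,
}

impl Dual {
    fn exp(self) -> Dual {
        let e = self.re.exp();
        Dual { re: e, eps: self.eps * e }
    }
    fn sin(self) -> Dual {
        Dual { re: self.re.sin(), eps: self.eps * self.re.cos() }
    }
    fn cos(self) -> Dual {
        Dual { re: self.re.cos(), eps: -self.eps * self.re.sin() }
    }
    fn powi(self, n: i32) -> Dual {
        Dual { re: self.re.powi(n), eps: self.eps * n as f64 * self.re.powi(n - 1) }
    }
    fn sqrt(self) -> Dual {
        let r = self.re.sqrt();
        Dual { re: r, eps: self.eps / (2.0 * r) }
    }
}

impl Add for Dual {
    type Output = Dual;
    fn add(self, o: Dual) -> Dual {
        Dual { re: self.re + o.re, eps: self.eps + o.eps }
    }
}

impl Div for Dual {
    type Output = Dual;
    // Quotient rule: (a / b)' = (a' b - a b') / b^2
    fn div(self, o: Dual) -> Dual {
        Dual { re: self.re / o.re, eps: (self.eps * o.re - self.re * o.eps) / (o.re * o.re) }
    }
}

// Baseline: the optimizer may inline this straight into the benchmark loop.
fn f1_inlinable(x: Dual) -> Dual {
    x.exp() / (x.sin().powi(3) + x.cos().powi(3)).sqrt()
}

// As in the diff above: f1 is never inlined and `indirection` forwards to it,
// so every call through `indirection` pays one extra function call.
#[inline(never)]
fn f1(x: Dual) -> Dual {
    f1_inlinable(x)
}

fn indirection(x: Dual) -> Dual {
    f1(x)
}

/// Crude timing loop: average nanoseconds per call.
fn time_ns(label: &str, f: impl Fn(Dual) -> Dual) {
    const N: u32 = 1_000_000;
    let start = Instant::now();
    for _ in 0..N {
        black_box(f(black_box(Dual { re: 0.5, eps: 1.0 })));
    }
    let per_call = start.elapsed().as_nanos() as f64 / N as f64;
    println!("{label}: {per_call:.2} ns/call");
}

fn main() {
    time_ns("direct", f1_inlinable);
    time_ns("indirection", indirection);
}
```

Build with `--release`; in a debug build both paths are dominated by unoptimized arithmetic and the extra call is not visible. `black_box` keeps the compiler from hoisting the constant input out of the loop.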

Now we get:

num-dual: 1st order/Dual64
                        time:   [408.21 ps 408.69 ps 409.33 ps]
                        change: [+30.526% +31.881% +33.197%] (p = 0.00 < 0.05)
                        Performance has regressed.

And indeed, num-dual is now at Enzyme's level. So, in summary, Enzyme currently always adds one more layer of indirection because I didn't bother cleaning up the LLVM IR enough. I never noticed because the overhead is easily hidden by any slightly more complex operation, but here we only have 5 simple operations, so it actually has an effect. I can't promise to fix it soon, since it likely won't be measurable beyond toy examples, but I'll leave the issue open as a reminder. It would also be a good first issue; it shouldn't be too hard for a new contributor.

ZuseZ4 commented 3 months ago

I pushed https://github.com/EnzymeAD/rust/commit/ab5490415bb28ef6cdd8045d116a9951aa171c8e to remove one layer of indirection, but interestingly enough it had no performance impact. I'll check whether I can also inline the call to the differentiated function itself; that might help. In the meantime, please feel free to post the LLVM IR of your function (cargo has a flag for that). Maybe we can spot the difference that way?

ZuseZ4 commented 3 months ago

@g-bauer Possibly this is the reason why we're slower. There is a whole discussion on correctness in AD (https://github.com/EnzymeAD/Enzyme/issues/1295). Do you have special handling for sqrt(0) in your tool? I tend not to change the default behaviour even if it has a small performance overhead, since for non-toy examples the performance benefits of LLVM-based AD should easily outweigh it. Do you have any larger benchmarks on which we could compare?

g-bauer commented 3 months ago

We don't have special treatment (see here). Taking the derivative of sqrt(0.0) returns NaN. But I don't think that's the issue here: switching to a different operation (pow, ln, ...) doesn't change the benchmark results on my machine.
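The two conventions discussed here can be compared on a single sqrt. The Enzyme IR earlier in the thread guards the sqrt tangent with `fcmp ueq ..., 0.0` and a `select` that returns 0.0, while a plain chain-rule formula divides by 2 * sqrt(0). A minimal sketch: both function names are hypothetical, and the unguarded version is the textbook dual-number rule, not necessarily num-dual's exact code.

```rust
/// Tangent of sqrt with the guard seen in the Enzyme IR:
/// if the argument is 0 (or NaN, since `ueq` is unordered-or-equal),
/// the tangent is forced to 0.0.
fn sqrt_fwd_guarded(x: f64, dx: f64) -> (f64, f64) {
    let r = x.sqrt();
    let dr = if x == 0.0 || x.is_nan() { 0.0 } else { dx / (2.0 * r) };
    (r, dr)
}

/// Plain chain rule without a guard: divides by 2 * sqrt(0) at x = 0.
fn sqrt_fwd_unguarded(x: f64, dx: f64) -> (f64, f64) {
    let r = x.sqrt();
    (r, dx / (2.0 * r))
}

fn main() {
    for dx in [1.0, 0.0] {
        println!(
            "dx = {dx}: guarded {:?}, unguarded {:?}",
            sqrt_fwd_guarded(0.0, dx),
            sqrt_fwd_unguarded(0.0, dx)
        );
    }
}
```

At x = 0 the guarded version returns a tangent of 0.0, while the unguarded one returns inf for dx = 1.0 and NaN for dx = 0.0. Away from zero both agree; the extra compare-and-select is presumably the small overhead referred to above.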

I'll add a longer example.