EnzymeAD / Enzyme

High-performance automatic differentiation of LLVM and MLIR.
https://enzyme.mit.edu

Syntax for custom forward rule in C? #1991

Closed cgeoga closed 1 month ago

cgeoga commented 1 month ago

Hi all,

Apologies if I missed this somewhere in an existing issue or in the docs, but I'd like to implement a custom forward rule in some C code (although I am willing to switch to C++ if that would materially simplify the solution). When I search for things like "custom forward" in the issue and PR list, I see many PRs with commit titles that seem pretty relevant, but I am not finding docs or examples.

In particular, I have the following C function:

double levin_transform(double* s, double* w) {
  struct doublepair a[16];
  double sc;
  for(int i=0; i<16; i++){
    if(w[i] == 0.0) return s[i];
    a[i].x1 = s[i]/w[i];
    a[i].x2 = 1.0/w[i];
  }
  for(int k=0; k<15; k++){
    for(int i=0; i<(15-k); i++){
      sc   = levin_scale(1.0, (double)(i+1), (double)k);
      a[i] = fmadd(a[i], sc, a[i+1]);
    }
  }
  return a[1].x1/a[1].x2;
}

My issue is that, in my actual code, the primal evaluation of levin_transform hits the early-return branch, but the evaluation with the dual values shouldn't.
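In ordinary forward-mode AD the tangent is carried alongside the primal, and every branch is decided by the primal values, so an early return taken by the primal is also taken by the derivative. A hand-rolled dual-number sketch of just the early-return scan illustrates this. The `dual` type and `early_return_scan` are hypothetical names for illustration, not Enzyme's API.

```c
#include <stddef.h>

/* Hypothetical dual number: primal value v and tangent d. */
typedef struct { double v, d; } dual;

/* Dual-number version of only the early-return scan from levin_transform.
   Returns 1 and writes *out if some primal weight is zero, else returns 0. */
int early_return_scan(const dual* s, const dual* w, size_t n, dual* out) {
  for (size_t i = 0; i < n; i++) {
    if (w[i].v == 0.0) {  /* branch decided by the primal value alone */
      *out = s[i];        /* tangent is s[i].d; w[i].d is never read */
      return 1;
    }
  }
  return 0;
}
```

With `w[0].v == 0` but `w[0].d != 0`, both parts go down the early-return path and the result's tangent is simply `s[0].d`. Avoiding that requires a rule that evaluates the shadow values on their own, which is what the pseudocode below is after.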

What I really want to do is write something like this:

double _levin_transform(double* s, double* w) {
  struct doublepair a[16];
  double sc;
  for(int i=0; i<16; i++){
    if(w[i] == 0.0) return s[i];
    a[i].x1 = s[i]/w[i];
    a[i].x2 = 1.0/w[i];
  }
  for(int k=0; k<15; k++){
    for(int i=0; i<(15-k); i++){
      sc   = levin_scale(1.0, (double)(i+1), (double)k);
      a[i] = fmadd(a[i], sc, a[i+1]);
    }
  }
  return a[1].x1/a[1].x2;
}

double levin_transform(double* s, double* w) {
  return _levin_transform(s, w);
}

double __enzyme_fwddiff_levin_transform(double* s, double* ds, double* w, double* dw) {
  double primal_value = _levin_transform(s, w);
  double dual_value   = _levin_transform(ds, dw);
  // return format somehow?    
}

This is more or less what we did with this function in the Julia package Bessels.jl, and it works great, but I'm having trouble translating it to C.

The closest thing I can find is the code example in #1655, but I'm curious whether there is a simpler or newer solution. I am using Enzyme from Spack, which appears to be version 0.0.81, if that is relevant. I am also happy to upgrade to whatever sufficiently recent version offers the best solution.

Thanks so much!

wsmoses commented 1 month ago

Not nearly as pretty or customizable yet, but here ya go: https://github.com/EnzymeAD/Enzyme/blob/main/enzyme/test/Integration/ForwardMode/customfwd.c
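The linked test attaches a custom forward derivative by listing the primal function and its derivative, in that order, in a global array named like `__enzyme_register_derivative_square`, with outputs passed through pointers rather than returned. A sketch of how the pseudocode in the question might be reshaped into that form follows. `stand_in_kernel` is a hypothetical placeholder for `_levin_transform`, `SEQ_LEN` is a hypothetical length, `noinline` is a defensive choice rather than a documented requirement, and the argument order (each pointer argument followed by its shadow) is an assumption based on the test and on the MWE below.

```c
#define SEQ_LEN 4  /* hypothetical length; the question uses 16 */

/* Hypothetical stand-in for _levin_transform with the same shape: an early
   return when a weight is zero, otherwise a weighted combination. The result
   is written through an out-pointer instead of being returned. */
void stand_in_kernel(double* out, const double* s, const double* w) {
  double num = 0.0, den = 0.0;
  for (int i = 0; i < SEQ_LEN; i++) {
    if (w[i] == 0.0) { *out = s[i]; return; }
    num += s[i] / w[i];
    den += 1.0 / w[i];
  }
  *out = num / den;
}

/* Primal entry point the custom rule is attached to. noinline is an
   assumption: it keeps the call from being folded into its callers. */
__attribute__((noinline))
void stand_in_transform(double* out, const double* s, const double* w) {
  stand_in_kernel(out, s, w);
}

/* Custom forward rule, assuming each pointer argument is followed by its
   shadow. The primal pass fills *out and the pass on the shadows fills
   *d_out, so nothing needs to be returned. */
void derivative_stand_in(double* out, double* d_out,
                         const double* s, const double* ds,
                         const double* w, const double* dw) {
  stand_in_kernel(out, s, w);
  stand_in_kernel(d_out, ds, dw);
}

/* Registration table: primal first, then the forward derivative. */
void* __enzyme_register_derivative_stand_in[] = {
  (void*)stand_in_transform,
  (void*)derivative_stand_in,
};
```

Compiled without the Enzyme plugin this is ordinary C, so `derivative_stand_in` can be called directly to test the rule in isolation before relying on Enzyme to dispatch to it.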

cgeoga commented 1 month ago

Super helpful, thanks @wsmoses! I'm trying this syntax, and I can get the code to compile, but I'm having trouble actually hitting the custom rule. Here is an MWE that I compile with

$(CLANG) mwe.c -fplugin=$(CLANG_ENZYME) -O3 -lm -Wall -pedantic -o mwe

(where CLANG and CLANG_ENZYME are disgustingly long Makefile variables). The MWE code is

#include<stdlib.h>
#include<stdio.h>
#include<math.h>

struct doublepair {
  double x1;
  double x2;
};

// Enzyme setup:
int enzyme_const, enzyme_dup, enzyme_dupnoneed;
double __enzyme_fwddiff(void*, ...);

struct doublepair fmadd(struct doublepair a, double b, struct doublepair c){
  double o1 = (a.x1 * b + c.x1);
  double o2 = (a.x2 * b + c.x2);
  struct doublepair out;
  out.x1 = o1;
  out.x2 = o2;
  return out;
}

double levin_scale(double B, int n, int k) {
  return -(B+n+k)*(B+n+k-1)/((B+n+2*k)*(B+n+2*k-1));
}

static void _levin(double* out, double* sw) {
  double* s = sw;
  double* w = sw + 16;
  struct doublepair a[16];
  double sc;
  for(int i=0; i<16; i++){
    if(w[i] == 0.0) {
      printf("Hitting early return...\n");
      *out = s[i]; 
      return;
    }
    a[i].x1 = s[i]/w[i];
    a[i].x2 = 1.0/w[i];
  }
  for(int k=0; k<15; k++){
    for(int i=0; i<(15-k); i++){
      sc   = levin_scale(1.0, i+1, k);
      a[i] = fmadd(a[i], sc, a[i+1]);
    }
  }
  *out = a[1].x1/a[1].x2;
}

static void levin(double* out, double* sw){
  _levin(out, sw);
}

static void derivative_levin(double* out, double* d_out, 
                      double* sw,  double* d_sw) {
  printf("Primal Levin pass...\n");
  _levin(out, sw);
  printf("Dual Levin pass...\n");
  _levin(d_out, d_sw);
}

void* __enzyme_register_derivative_levin[] = {
  (void*)levin,
  (void*)derivative_levin,
};

double besselkx_levin(double v, double x) {
  double out, s, t;
  s = 0.0; 
  t = 1.0;
  double sequences_weights[32];
  double* sequences = sequences_weights;
  double* weights   = sequences_weights + 16;
  double fvv = 4*v*v;
  double eightx = 8*x;
  for(int k=0; k<16; k++) {
    s += t;
    t *= (fvv - pow(2*k+1, 2))/(eightx*(k+1));
    sequences[k] = s; 
    weights[k]   = t; 
  }
  levin(&out, sequences_weights);
  out *= sqrt(M_PI/(2*x));
  return out;
}

double dbesselkx_levin_dv(double v, double x) {
  double dv = 1.0;
  return __enzyme_fwddiff((void*) besselkx_levin, 
                          enzyme_dupnoneed, v, dv,
                          enzyme_const, x);
}

int main(int argc, char** argv) {
  double v = 0.5;
  double x = 1.51;
  printf("%1.5e\n", dbesselkx_levin_dv(v, x));
  return 0;
}

When I compile this code (no errors or warnings) and run ./mwe, I see the print for the early return (which is expected), but none of the prints from the custom registered derivative of levin, and the program reports a derivative of zero, which is incorrect.
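The early return in the MWE is expected at v = 0.5: after the first loop step the weight is t = (4v^2 - 1)/(8x), which is exactly zero there, while its derivative in v, 8v/(8x) = v/x, is not. A hand-written forward-mode propagation of the loop in besselkx_levin makes this concrete. `levin_inputs_with_tangents` is a hypothetical helper written for illustration, not Enzyme output.

```c
/* Fills s[k], w[k] (partial sums and weights, as in besselkx_levin) and their
   derivatives ds[k], dw[k] with respect to v, for k = 0..n-1. Tangents are
   propagated by hand with the product rule; x is held constant. */
void levin_inputs_with_tangents(double v, double x, int n,
                                double* s, double* ds, double* w, double* dw) {
  double sum = 0.0, dsum = 0.0;    /* running partial sum and its tangent */
  double t = 1.0, dt = 0.0;        /* current term and its tangent */
  double fvv = 4*v*v, dfvv = 8*v;  /* 4v^2 and its derivative in v */
  double eightx = 8*x;
  for (int k = 0; k < n; k++) {
    sum  += t;
    dsum += dt;
    double num = fvv - (double)((2*k+1)*(2*k+1));
    double den = eightx*(k+1);
    /* t *= num/den, where only num depends on v; update dt before t */
    dt = dt*(num/den) + t*(dfvv/den);
    t  = t*(num/den);
    s[k] = sum; ds[k] = dsum;
    w[k] = t;   dw[k] = dt;
  }
}
```

With v = 0.5 and x = 1.51 (the values in main), w[0] is exactly zero while dw[0] = 0.5/1.51 is about 0.331. Because t stays zero after the first step, every later w[k] is zero too, and s[k] stays at 1 with a zero tangent. Differentiating straight through the primal branch therefore returns s[0] with tangent 0, which matches the zero derivative reported above.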

Can you provide any pointers about what I need to do to trigger this custom derivative? I see from other issues that there is also a __enzyme_register_gradient that takes three arguments, but my (potentially incorrect) assessment is that it is for reverse mode?

cgeoga commented 1 month ago

As an update: on my Ubuntu LTS server, the latest Spack version I seemed to be able to download was 0.0.81. With the AUR package, which provides version 0.0.132, everything works as expected. So I guess there is just a minimum version requirement for these tools.

This issue can probably be closed, but I'll let you decide on that. Let me know if you'd like me to submit a PR for documentation somewhere letting people know to try a sufficiently recent version or something.

wsmoses commented 1 month ago

Oh yeah, the Spack packages are often a bit out of date, but glad it was fixed!

In the interim I made a quick PR to spack to bump: https://github.com/spack/spack/pull/45346