EnzymeAD / Enzyme

High-performance automatic differentiation of LLVM and MLIR.

Syntax for custom forward rule in C? #1991

Closed cgeoga closed 1 month ago

cgeoga commented 1 month ago

Hi all,

Apologies if I missed this somewhere in an existing issue or in the docs, but I'd like to implement a custom forward rule in some C code (although I am willing to switch to C++ if that would materially simplify the solution). When I search for things like "custom forward" in the issue and PR list, I see many PRs with commit titles that seem pretty relevant, but I am not finding docs or examples.

In particular, I have the following C function:

double levin_transform(double* s, double* w) {
  struct doublepair a[16];
  double sc;
  for(int i=0; i<16; i++){
    if(w[i] == 0.0) return s[i];
    a[i].x1 = s[i]/w[i];
    a[i].x2 = 1.0/w[i];
  }
  for(int k=0; k<15; k++){
    for(int i=0; i<(15-k); i++){
      sc   = levin_scale(1.0, (double)(i+1), (double)k);
      a[i] = fmadd(a[i], sc, a[i+1]);
    }
  }
  return a[1].x1/a[1].x2;
}

My issue is that my actual code creates circumstances where the primal evaluation of levin_transform hits the early return branch, but the evaluation of the function with the dual values shouldn't.
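Some context on why the early return causes trouble: forward-mode AD, Enzyme included, propagates tangents only along the branch that the primal values select. The sketch below uses a hand-rolled dual number in plain C, not Enzyme. The names `dual`, `toy`, and `demo` are hypothetical and exist only for this illustration.

```c
#include <assert.h>

/* Hand-rolled dual number: a value and its tangent. Plain C, not Enzyme. */
typedef struct {
  double val; /* primal value */
  double dot; /* tangent (directional derivative) */
} dual;

/* Hypothetical toy with the same shape as the early return in
   levin_transform: mathematically the result is a + t*c, but when t is
   exactly zero the code short-circuits to a. The value is still right,
   but the t*c term, and with it the tangent of t, is dropped. */
dual toy(dual a, dual t, double c) {
  if (t.val == 0.0) {
    return a; /* branch chosen by the primal value of t */
  }
  dual out = { a.val + t.val * c, a.dot + t.dot * c };
  return out;
}

/* Seed dt = 1 at t = 0 with a held constant. The exact derivative of
   a + t*c with respect to t is c = 3, but the early return yields 0. */
void demo(void) {
  dual r = toy((dual){ 2.0, 0.0 }, (dual){ 0.0, 1.0 }, 3.0);
  assert(r.val == 2.0); /* value is exact */
  assert(r.dot == 0.0); /* tangent lost to the early return */
}
```

At t = 0 the value is exact, but the tangent comes out as 0 rather than c. The same thing happens in `levin_transform`. If `w[i] == 0.0` in the primal, the generated derivative is just the tangent of `s[i]`, regardless of the dual weights.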

What I really want to do is write something like this:

double _levin_transform(double* s, double* w) {
  struct doublepair a[16];
  double sc;
  for(int i=0; i<16; i++){
    if(w[i] == 0.0) return s[i];
    a[i].x1 = s[i]/w[i];
    a[i].x2 = 1.0/w[i];
  }
  for(int k=0; k<15; k++){
    for(int i=0; i<(15-k); i++){
      sc   = levin_scale(1.0, (double)(i+1), (double)k);
      a[i] = fmadd(a[i], sc, a[i+1]);
    }
  }
  return a[1].x1/a[1].x2;
}

double levin_transform(double* s, double* w) {
  return _levin_transform(s, w);
}

double __enzyme_fwddiff_levin_transform(double* s, double* ds, double* w, double* dw) {
  double primal_value = _levin_transform(s, w);
  double dual_value   = _levin_transform(ds, dw);
  // return format somehow?
}

This is more or less what we did with this function in the Julia package Bessels.jl, and it works great. But I'm having trouble getting this into the C code.

The closest thing I can find is the code example in #1655, but I'm curious whether there is a simpler or newer solution. I am using Enzyme from Spack, which appears to be version 0.0.81, if that is relevant. I am also happy to upgrade to whatever sufficiently new version offers the best solution.

Thanks so much!

wsmoses commented 1 month ago

Not nearly as pretty or customizable yet, but here ya go: https://github.com/EnzymeAD/Enzyme/blob/main/enzyme/test/Integration/ForwardMode/customfwd.c
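For a function that returns its result through a pointer, the pattern in that test reduces to roughly the sketch below. It assumes the customfwd.c conventions as we read them. The first is a `{primal, derivative}` array whose name starts with `__enzyme_register_derivative`. The second is a rule that takes a (primal, shadow) pair for each pointer argument. `square_ptr`, `derivative_square_ptr`, and `check_square_rule` are hypothetical names. Writing the result through an output pointer also sidesteps the question of what the rule should return.

```c
#include <assert.h>

/* Hypothetical primal: writes (*x)^2 through an output pointer instead of
   returning it, the same shape as levin(out, sw) in the MWE below. */
void square_ptr(double* out, const double* x) {
  *out = (*x) * (*x);
}

/* Custom forward rule in the customfwd.c style: every pointer argument of
   the primal becomes a (primal, shadow) pair, in the same order. The rule
   stands in for the whole call, so it writes both the primal output and
   its tangent. */
void derivative_square_ptr(double* out, double* d_out,
                           const double* x, const double* d_x) {
  *out   = (*x) * (*x);
  *d_out = 2.0 * (*x) * (*d_x); /* d(x^2) = 2 x dx */
}

/* Registration global: {primal, forward derivative}. Our understanding is
   that Enzyme recognizes it by the __enzyme_register_derivative prefix. */
void* __enzyme_register_derivative_square_ptr[] = {
  (void*)square_ptr,
  (void*)derivative_square_ptr,
};

/* Calling the rule directly (no Enzyme involved) checks its math. */
void check_square_rule(void) {
  double x = 3.0, dx = 1.0, y = 0.0, dy = 0.0;
  derivative_square_ptr(&y, &dy, &x, &dx);
  assert(y == 9.0);
  assert(dy == 6.0);
}
```

With the Enzyme plugin loaded, we would expect a call such as `__enzyme_fwddiff((void*)square_ptr, enzyme_dup, &y, &dy, enzyme_dup, &x, &dx)` to use this rule in place of Enzyme's own derivative. Without the plugin, the file still compiles, and the registration array is just an ordinary global.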

cgeoga commented 1 month ago

Super helpful, thanks @wsmoses! I'm trying this syntax and can get the code to compile, but I'm having trouble actually hitting the custom rule. Here is an MWE that I compile with

$(CLANG) mwe.c -fplugin=$(CLANG_ENZYME) -O3 -lm -Wall -pedantic -o mwe

(where CLANG and CLANG_ENZYME are disgustingly long Makefile variables). The MWE code is:


#include <math.h>
#include <stdio.h>

struct doublepair {
  double x1;
  double x2;
};
// Enzyme setup:
int enzyme_const, enzyme_dup, enzyme_dupnoneed;
double __enzyme_fwddiff(void*, ...);

struct doublepair fmadd(struct doublepair a, double b, struct doublepair c){
  double o1 = (a.x1 * b + c.x1);
  double o2 = (a.x2 * b + c.x2);
  struct doublepair out;
  out.x1 = o1;
  out.x2 = o2;
  return out;
}

double levin_scale(double B, int n, int k) {
  return -(B+n+k)*(B+n+k-1)/((B+n+2*k)*(B+n+2*k-1));
}

static void _levin(double* out, double* sw) {
  double* s = sw;
  double* w = sw + 16;
  struct doublepair a[16];
  double sc;
  for(int i=0; i<16; i++){
    if(w[i] == 0.0) {
      printf("Hitting early return...\n");
      *out = s[i];
      return;
    }
    a[i].x1 = s[i]/w[i];
    a[i].x2 = 1.0/w[i];
  }
  for(int k=0; k<15; k++){
    for(int i=0; i<(15-k); i++){
      sc   = levin_scale(1.0, i+1, k);
      a[i] = fmadd(a[i], sc, a[i+1]);
    }
  }
  *out = a[1].x1/a[1].x2;
}

static void levin(double* out, double* sw){
  _levin(out, sw);
}

static void derivative_levin(double* out, double* d_out, 
                      double* sw,  double* d_sw) {
  printf("Primal Levin pass...\n");
  _levin(out, sw);
  printf("Dual Levin pass...\n");
  _levin(d_out, d_sw);
}

void* __enzyme_register_derivative_levin[] = {
  (void*)levin,
  (void*)derivative_levin,
};

double besselkx_levin(double v, double x) {
  double out, s, t;
  s = 0.0; 
  t = 1.0;
  double sequences_weights[32];
  double* sequences = sequences_weights;
  double* weights   = sequences_weights + 16;
  double fvv = 4*v*v;
  double eightx = 8*x;
  for(int k=0; k<16; k++) {
    s += t;
    t *= (fvv - pow(2*k+1, 2))/(eightx*(k+1));
    sequences[k] = s; 
    weights[k]   = t;
  }
  levin(&out, sequences_weights);
  out *= sqrt(M_PI/(2*x));
  return out;
}

double dbesselkx_levin_dv(double v, double x) {
  double dv = 1.0;
  return __enzyme_fwddiff((void*) besselkx_levin, 
                          enzyme_dupnoneed, v, dv,
                          enzyme_const, x);
}

int main(int argc, char** argv) {
  double v = 0.5;
  double x = 1.51;
  printf("%1.5e\n", dbesselkx_levin_dv(v, x));
  return 0;
}

The code compiles without errors or warnings. When I run ./mwe, I see the early-return message (which is expected), but no message from the custom registered derivative of levin, and I get the incorrect derivative 0.00000.

Can you provide any pointers or hints about what I should be doing to make this custom derivative get triggered? I see from other issues there is also a __enzyme_register_gradient that takes three arguments, but my (potentially incorrect) assessment is that that is for reverse mode?

cgeoga commented 1 month ago

As an update: on my Ubuntu LTS server, the latest Spack version I was able to get was 0.0.81. With the AUR package, which provides version 0.0.132, everything works as expected. So I guess these tools simply require a minimum version.

This issue can probably be closed, but I'll let you decide on that. Let me know if you'd like me to submit a PR for documentation somewhere letting people know to try a sufficiently recent version or something.

wsmoses commented 1 month ago

Oh yeah, the Spack packages are often a bit out of date, but glad it was fixed!

In the interim I made a quick PR to spack to bump: https://github.com/spack/spack/pull/45346