EnzymeAD / Enzyme

High-performance automatic differentiation of LLVM and MLIR.
https://enzyme.mit.edu
Other
1.28k stars 108 forks source link

Quadratic memory usage (mk iii) #132

Closed unrealwill closed 3 years ago

unrealwill commented 3 years ago

Hello,

The following instruction F[indj*d+l] += wjk * parts[indk*d+l];

makes the code need quadratic memory in the backward pass.

#include <stdio.h>
#include <iostream>
#include <stdlib.h>
#include <random>
#include <math.h>
#include <vector>
#include <algorithm>

using namespace std;

// Cell index over particles, filled in by buildIndex(). Particles are grouped
// into cells by the int-truncated first coordinate of each xyz triple. All
// four arrays are allocated with new[] by buildIndex(); this struct does not
// free them.
struct Index
{
    int* cellId;    // cellId[p]: cell id of particle p
    int* start;     // start[c]: position in argsorted of the first particle of cell c
    int* cellSize;  // cellSize[c]: number of particles in cell c
    int size;       // number of cells; only start[0..size) / cellSize[0..size) are valid
    int* argsorted; // particle ids sorted by (cell id, particle id)
    int n;          // number of particles; every array above has n entries
} ;

// Build a cell index over n particles stored as xyz triples in `parts`
// (particle p occupies parts[3*p], parts[3*p+1], parts[3*p+2]).
//
// Each particle is assigned to the cell whose id is its first coordinate
// truncated to int. Particles are then sorted by cell id so that the members
// of each cell are contiguous in index.argsorted; cell c (0 <= c < index.size)
// spans argsorted[start[c] .. start[c] + cellSize[c]).
//
// The four arrays are allocated with new[]; the caller owns them and must
// delete[] them. start/cellSize are sized n because, in the worst case, every
// particle lands in its own cell; only the first index.size entries are used.
void buildIndex( Index& index , double * parts, int n )
{
    const int d = 3;
    index.n = n;
    index.cellId = new int[n];
    index.start = new int[n];
    index.cellSize = new int[n];
    index.argsorted = new int[n];

    // (cell id, particle id) pairs: sorting groups particles by cell and
    // keeps ties ordered by particle id.
    std::vector<std::pair<int,int> > order(n);
    for( int p = 0 ; p < n ; p++)
    {
        // Truncation toward zero, exactly like the implicit conversion.
        index.cellId[p] = static_cast<int>(parts[d*p]);
        order[p].first = index.cellId[p];
        order[p].second = p;
    }

    std::sort( order.begin(), order.end() );

    int cur = -1;        // index of the cell currently being filled
    int curCellId = 0;   // its cell id (only meaningful once cur >= 0)
    for( int i = 0 ; i < n ; i++)
    {
        index.argsorted[i] = order[i].second;
        // Open a new cell on the first particle or whenever the id changes.
        // Testing i == 0 explicitly, instead of relying on a sentinel id,
        // keeps this correct when the smallest cell id equals the sentinel
        // (e.g. -1 for a first coordinate in (-2, -1]); the sentinel version
        // wrote to cellSize[-1].
        if( i == 0 || order[i].first != curCellId )
        {
            curCellId = order[i].first;
            cur++;
            index.cellSize[cur] = 1;
            index.start[cur] = i;
        }
        else
        {
            index.cellSize[cur]++;
        }
    }
    index.size = cur+1;
}

// Minimal reproducer for the quadratic-memory issue in the reverse pass.
//
// parts: n particles as xyz triples (parts[3*p + l]); index: built by
// buildIndex(). For every cell i and every ordered pair (j, k) of its
// particles, accumulates wjk * x_k into F[indj], where
// wjk = 1 + |x_j - x_k|^2.
//
// NOTE(review): every statement that would feed `out` is commented out, so
// this function always returns 0.0. F and W are variable-length arrays (a GNU
// extension in C++) on the stack: 3*n and n doubles respectively.
//
// The code shape is intentionally left alone: per the discussion, Enzyme's
// cache-vs-recompute decisions depend on loop structure and inlining.
double foo( double* __restrict__ parts,int n, Index* __restrict__ index)
{
     double out = 0;
     const int d = 3;

     double F[n*d];

     double W[n]; // unused while the W[indj] update below is commented out

     for( int i = 0 ; i < n ; i++)
     {
         for( int j = 0 ; j < d ; j++)
         {
             F[i*d+j] = 0.0;
         }
         W[i] = 0.0;
     }

     // Cell i, then every ordered pair (j, k) of particles inside it.
     for( int i = 0 ; i < index->size ; i++)
     {
         for( int j = 0 ; j < index->cellSize[i] ; j++ )
         {
             for( int k = 0 ; k < index->cellSize[i] ; k++ )
             {
                 // Global particle ids of the j-th and k-th members of cell i.
                 int indj = index->argsorted[index->start[i]+j];
                 int indk = index->argsorted[index->start[i]+k];

                 // Squared Euclidean distance |x_j - x_k|^2.
                 double djk = 0;
                 for( int l = 0 ; l < d ; l++)
                 {
                     double temp;
                     temp = parts[indj * d +l ]- parts[indk * d +l ];
                     djk += temp*temp;
                 }
                 //out += djk;

                 double wjk = 1.0+djk; // strictly positive

                 // The statement reported in the issue: wjk (hence djk) is
                 // needed in the reverse pass, and Enzyme chose to cache it
                 // once per (i, j, k) iteration instead of recomputing it.
                 for( int l = 0 ; l < d ; l++)
                 {
                     F[indj*d+l] += wjk * parts[indk*d+l];
                 }

                 //W[indj] += wjk;

            }
         }
     }

     // NOTE(review): W has only n entries, so W[i*d+j] in the disabled
     // normalisation below would read out of bounds; it should be W[i].
     /*
    //Normalize the field value
    for( int i = 0 ; i < n ; i++)
    {
        for( int j = 0 ; j < d ; j++)
        {
            F[i*d+j] /= W[i*d+j];
        }
    }
*/
/*
    //Compute the energy
    for( int i = 0 ; i < n ; i++)
    {
        double e = 0.0;
        for( int j = 0 ; j < d ; j++)
        {
            out += F[i*d+j]*F[i*d+j];
        }
    }
*/

     // F and W are stack VLAs, not new[] allocations, so nothing to delete.
     //delete[] F;
     //delete[] W;

     return out;
}

// Marker globals passed to __enzyme_autodiff to tag the following
// argument(s): enzyme_dup precedes a pointer and its shadow (x, d_x), and
// enzyme_const precedes a value that is not differentiated (n, &index).
// NOTE(review): per Enzyme convention only their identity matters, not their
// value; enzyme_out is unused in this file.
int enzyme_dup;
int enzyme_out;
int enzyme_const;

// Signature of the function being differentiated.
typedef double (*f_ptr)(double *,int,Index*);

// Declaration only: the Enzyme plugin loaded on the compile line
// (-Xclang -load ... ClangEnzyme-*.so) replaces calls to it with the
// generated reverse-mode derivative of the given function.
extern double __enzyme_autodiff(f_ptr,
    int, double *, double *,
    int, int,
    int, Index*);

// Driver: generates n random particles uniformly in [0, 10)^3, builds the
// cell index, differentiates foo with respect to the particle positions with
// Enzyme, and prints the first few cell ids and gradient entries.
int main() {
    std::mt19937 e2(42);
    std::uniform_real_distribution<> dist(0, 10);
    int n = 100000;
    int d = 3;
    double* x = new double[n*d];
    double* d_x = new double[n*d];   // shadow of x; Enzyme accumulates the gradient here
    for( int i = 0 ; i < n*d ; i++)
    {
        x[i] = dist(e2);
        d_x[i] = 0.0;
    }

    Index index;
    buildIndex(index, x, n);

    // Never print past the end of the arrays if n is lowered below 100.
    const int nprint = n < 100 ? n : 100;

    for( int i = 0 ; i < nprint ; i++)
    {
        printf("cellId[%d] = %d\n ",i, index.cellId[i]);
    }

    printf("before autodiff\n");
    __enzyme_autodiff(foo,
        enzyme_dup, x, d_x,
        enzyme_const, n,
        enzyme_const, &index);

    for( int i = 0 ; i < nprint ; i++)
    {
        printf("dx[%d] = [%f, %f, %f]\n ",i, d_x[d*i],d_x[d*i+1],d_x[d*i+2]);
    }

    // Release everything allocated with new[] here and in buildIndex().
    delete[] index.cellId;
    delete[] index.start;
    delete[] index.cellSize;
    delete[] index.argsorted;
    delete[] d_x;
    delete[] x;
    return 0;
}

Compiled with : clang test2.cpp -lstdc++ -lm -Xclang -load -Xclang /usr/local/lib/ClangEnzyme-7.so -O2 -o test2 -fno-vectorize -fno-slp-vectorize -fno-unroll-loops -fno-exceptions

unrealwill commented 3 years ago

The quadratic memory usage seems to be caused by djk being kept in memory during the forward pass so that it is available in the backward pass, instead of being recomputed during the backward pass.

Is there a way to tell enzyme to recompute instead of keeping the memory ?

wsmoses commented 3 years ago

So this actually makes sense for this code. What is happening is that Enzyme decides it needs to cache the final value of djk for use in the reverse pass. Because that value is produced once per (i, j, k) iteration, this leads to a potential cache of size O(Σ_i cellSize[i]^2).

Of course, in this particular case Enzyme doesn't need to cache djk. It chose to do so because djk was computed within a loop, on the assumption that recomputing it may be expensive.

The quickest way to resolve it in this piece of code is to remove the -fno-unroll-loops flag. The flag prevents all loops from being unrolled, which, as you can tell, has various implications for cache decisions. When I test without the flag, the loop that computes djk gets fully unrolled, so djk is recomputed in the reverse pass without additional memory.

This does, however, trigger the GVN bug in LLVM 7 that you saw earlier, so you'll likely need a later version of LLVM for this as well.

Hopefully that resolves your issue for the time being.

I'm going to leave this open for now, however, to add some additional mechanisms for controlling the cache vs recompute decisions.

unrealwill commented 3 years ago

Thanks !

I grabbed a source zip of the LLVM 11 release (https://github.com/llvm/llvm-project/archive/release/11.x.zip). I then followed the same installation procedure as for the 7.0 version and updated my makefile, so the compilation is now:

clang test2.cpp -lstdc++ -lm -Xclang -load -Xclang /usr/local/lib/ClangEnzyme-11.so -O2 -o test2 -fno-exceptions

And it works fine ! :+1:

unrealwill commented 3 years ago

It's back again: when adding the cubic spline kernel, the optimizer does not make the right choice, which results in a memory explosion.

Additional code :

// Cubic-spline smoothing kernel in 3-D with support radius 2h.
// d: distance (expected non-negative); h: smoothing length.
// With q = d / h and sigma3 = 1 / (pi h^3) the result is sigma3 * f(q):
//   f(q) = 1 - 3/2 q^2 (1 - q/2)   for q <= 1
//   f(q) = (2 - q)^3 / 4           for 1 < q <= 2
//   f(q) = 0                       otherwise (including a NaN q)
inline double cubicSpline3d( double d, double h )
{
    const double q = d / h;
    const double sigma3 = 1.0/ (M_PI*h*h*h);

    // Written as !(q <= 2) rather than q > 2 so a NaN q still yields 0.
    if( !(q <= 2) )
        return 0.0;

    if( q <= 1 )
        return sigma3*(1.0-3.0/2.0*q*q*(1.0-q/2.0));

    const double t = 2.0 - q;
    return sigma3/4.0*t*t*t;
}
...
     //inside loops
    djk = sqrt( djk); // works fine without additional memory
    double wjk = cubicSpline3d( djk, h ); // doesn't seem to like the "if" as each individual branches can be optimized without  memory problem
...

test2.cpp :

#include <stdio.h>
#include <iostream>
#include <stdlib.h>
#include <random>
#include <math.h>
#include <vector>
#include <algorithm>

using namespace std;

// Cell index over particles, filled in by buildIndex(). Particles are grouped
// into cells by the int-truncated first coordinate of each xyz triple. All
// four arrays are allocated with new[] by buildIndex(); this struct does not
// free them.
struct Index
{
    int* cellId;    // cellId[p]: cell id of particle p
    int* start;     // start[c]: position in argsorted of the first particle of cell c
    int* cellSize;  // cellSize[c]: number of particles in cell c
    int size;       // number of cells; only start[0..size) / cellSize[0..size) are valid
    int* argsorted; // particle ids sorted by (cell id, particle id)
    int n;          // number of particles; every array above has n entries
} ;

// Build a cell index over n particles stored as xyz triples in `parts`
// (particle p occupies parts[3*p], parts[3*p+1], parts[3*p+2]).
//
// Each particle is assigned to the cell whose id is its first coordinate
// truncated to int. Particles are then sorted by cell id so that the members
// of each cell are contiguous in index.argsorted; cell c (0 <= c < index.size)
// spans argsorted[start[c] .. start[c] + cellSize[c]).
//
// The four arrays are allocated with new[]; the caller owns them and must
// delete[] them. start/cellSize are sized n because, in the worst case, every
// particle lands in its own cell; only the first index.size entries are used.
void buildIndex( Index& index , double * parts, int n )
{
    const int d = 3;
    index.n = n;
    index.cellId = new int[n];
    index.start = new int[n];
    index.cellSize = new int[n];
    index.argsorted = new int[n];

    // (cell id, particle id) pairs: sorting groups particles by cell and
    // keeps ties ordered by particle id.
    std::vector<std::pair<int,int> > order(n);
    for( int p = 0 ; p < n ; p++)
    {
        // Truncation toward zero, exactly like the implicit conversion.
        index.cellId[p] = static_cast<int>(parts[d*p]);
        order[p].first = index.cellId[p];
        order[p].second = p;
    }

    std::sort( order.begin(), order.end() );

    int cur = -1;        // index of the cell currently being filled
    int curCellId = 0;   // its cell id (only meaningful once cur >= 0)
    for( int i = 0 ; i < n ; i++)
    {
        index.argsorted[i] = order[i].second;
        // Open a new cell on the first particle or whenever the id changes.
        // Testing i == 0 explicitly, instead of relying on a sentinel id,
        // keeps this correct when the smallest cell id equals the sentinel
        // (e.g. -1 for a first coordinate in (-2, -1]); the sentinel version
        // wrote to cellSize[-1].
        if( i == 0 || order[i].first != curCellId )
        {
            curCellId = order[i].first;
            cur++;
            index.cellSize[cur] = 1;
            index.start[cur] = i;
        }
        else
        {
            index.cellSize[cur]++;
        }
    }
    index.size = cur+1;
}

// Cubic-spline smoothing kernel in 3-D with support radius 2h.
// d: distance (expected non-negative); h: smoothing length.
// With q = d / h and sigma3 = 1 / (pi h^3) the result is sigma3 * f(q):
//   f(q) = 1 - 3/2 q^2 (1 - q/2)   for q <= 1
//   f(q) = (2 - q)^3 / 4           for 1 < q <= 2
//   f(q) = 0                       otherwise (including a NaN q)
inline double cubicSpline3d( double d, double h )
{
    const double q = d / h;
    const double sigma3 = 1.0/ (M_PI*h*h*h);

    // Written as !(q <= 2) rather than q > 2 so a NaN q still yields 0.
    if( !(q <= 2) )
        return 0.0;

    if( q <= 1 )
        return sigma3*(1.0-3.0/2.0*q*q*(1.0-q/2.0));

    const double t = 2.0 - q;
    return sigma3/4.0*t*t*t;
}

// Energy of the normalised, cell-local cubic-spline field.
//
// parts: n particles as xyz triples (parts[3*p + l]); index: built by
// buildIndex(). For every cell i and every ordered pair (j, k) of its
// particles, with r = |x_j - x_k| and wjk = cubicSpline3d(r, h):
//   F[indj] += wjk * x_k   and   W[indj] += wjk.
// F is then divided by W per particle, and the returned value is the sum of
// the squared components of F.
//
// F (3*n doubles) and W (n doubles) are variable-length arrays (a GNU
// extension in C++) on the stack. The code shape is intentionally left alone:
// per the discussion, Enzyme's cache-vs-recompute decisions are sensitive to
// it (inlining cubicSpline3d turns its result into a phi node).
double foo( double* __restrict__ parts,int n, Index* __restrict__ index)
{
     double out = 0;
     const int d = 3;
     double h = 1.0; // smoothing length passed to cubicSpline3d

     double F[n*d];

     double W[n];

     for( int i = 0 ; i < n ; i++)
     {
         for( int j = 0 ; j < d ; j++)
         {
             F[i*d+j] = 0.0;
         }
         W[i] = 0.0;
     }

     // Cell i, then every ordered pair (j, k) of particles inside it.
     for( int i = 0 ; i < index->size ; i++)
     {
         for( int j = 0 ; j < index->cellSize[i] ; j++ )
         {
             for( int k = 0 ; k < index->cellSize[i] ; k++ )
             {
                 // Global particle ids of the j-th and k-th members of cell i.
                 int indj = index->argsorted[index->start[i]+j];
                 int indk = index->argsorted[index->start[i]+k];

                 // Squared Euclidean distance |x_j - x_k|^2.
                 double djk = 0;
                 for( int l = 0 ; l < d ; l++)
                 {
                     double temp;
                     temp = parts[indj * d +l ]- parts[indk * d +l ];
                     djk += temp*temp;
                 }
                 //out += djk;

                 //double wjk = 1.0+djk; // strictly positive
                 //djk = sqrt( djk );
                 //double wjk = cubicSpline3d( djk, h );
                 // djk becomes the plain distance; the kernel weight is 0
                 // beyond 2h.
                 djk = sqrt( djk);
                 double wjk = cubicSpline3d( djk, h );

                 for( int l = 0 ; l < d ; l++)
                 {
                     F[indj*d+l] += wjk * parts[indk*d+l];
                 }

                 W[indj] += wjk;

            }
         }
     }

    //Normalize the field value
    // W[i] > 0 here: the pair loop includes k == j, which contributes
    // cubicSpline3d(0, h) = 1 / (pi h^3), so the division is well defined.
    for( int i = 0 ; i < n ; i++)
    {
        for( int j = 0 ; j < d ; j++)
        {
            F[i*d+j] /= W[i];
        }
    }

    //Compute the energy
    for( int i = 0 ; i < n ; i++)
    {
        for( int j = 0 ; j < d ; j++)
        {
            out += F[i*d+j]*F[i*d+j];
        }
    }

     // F and W are stack VLAs, not new[] allocations, so nothing to delete.
     //delete[] F;
     //delete[] W;

     return out;
}

// Marker globals passed to __enzyme_autodiff to tag the following
// argument(s): enzyme_dup precedes a pointer and its shadow (x, d_x), and
// enzyme_const precedes a value that is not differentiated (n, &index).
// NOTE(review): per Enzyme convention only their identity matters, not their
// value; enzyme_out is unused in this file.
int enzyme_dup;
int enzyme_out;
int enzyme_const;

// Signature of the function being differentiated.
typedef double (*f_ptr)(double *,int,Index*);

// Declaration only: the Enzyme plugin loaded on the compile line
// (-Xclang -load ... ClangEnzyme-*.so) replaces calls to it with the
// generated reverse-mode derivative of the given function.
extern double __enzyme_autodiff(f_ptr,
    int, double *, double *,
    int, int,
    int, Index*);

// Driver: generates n random particles uniformly in [0, 10)^3, builds the
// cell index, differentiates foo with respect to the particle positions with
// Enzyme, and prints the first few cell ids and gradient entries.
int main() {
    std::mt19937 e2(42);
    std::uniform_real_distribution<> dist(0, 10);
    int n = 100000;
    int d = 3;
    double* x = new double[n*d];
    double* d_x = new double[n*d];   // shadow of x; Enzyme accumulates the gradient here
    for( int i = 0 ; i < n*d ; i++)
    {
        x[i] = dist(e2);
        d_x[i] = 0.0;
    }

    Index index;
    buildIndex(index, x, n);

    // Never print past the end of the arrays if n is lowered below 100.
    const int nprint = n < 100 ? n : 100;

    for( int i = 0 ; i < nprint ; i++)
    {
        printf("cellId[%d] = %d\n ",i, index.cellId[i]);
    }

    printf("before autodiff\n");
    __enzyme_autodiff(foo,
        enzyme_dup, x, d_x,
        enzyme_const, n,
        enzyme_const, &index);

    for( int i = 0 ; i < nprint ; i++)
    {
        printf("dx[%d] = [%f, %f, %f]\n ",i, d_x[d*i],d_x[d*i+1],d_x[d*i+2]);
    }

    // Release everything allocated with new[] here and in buildIndex().
    delete[] index.cellId;
    delete[] index.start;
    delete[] index.cellSize;
    delete[] index.argsorted;
    delete[] d_x;
    delete[] x;
    return 0;
}

clang test2.cpp -lstdc++ -lm -Xclang -load -Xclang /usr/local/lib/ClangEnzyme-11.so -O2 -o test2 -fno-exceptions

wsmoses commented 3 years ago

I haven't tried the code yet, but the default behavior of the caching heuristic is to cache the results of function calls (as they may take more than O(1) time), with exceptions for libm functions.

My guess is that this is the cause here. Per the previous discussion, I'm currently adding mechanisms that let the user specify more of these decisions. In this case, though, your function itself only uses O(1) functions and no loops, so we should also upgrade the default recompute heuristic.

By the way, a quick option that may be helpful for testing: try adding the -enzyme-inline flag, which forces all (non-recursive) called functions to be inlined. This won't necessarily resolve your case for the moment, but I thought it was worth pointing out.

unrealwill commented 3 years ago

Good idea with -enzyme-inline. I tried it by adding -mllvm --enzyme-inline:

clang test2.cpp -lstdc++ -lm -Xclang -load -Xclang /usr/local/lib/ClangEnzyme-11.so -O2 -o test2 -fno-exceptions -mllvm --enzyme-inline

Unfortunately, it doesn't work here.

But it gave me the idea to do the opposite to help the compiler.

I extracted the body of the innermost loop into a function, and the code with the cubic spline is now differentiated without quadratic memory.

void inner(int i,int j, int k, double* __restrict__ parts, double* __restrict__ F, double* __restrict__ W, Index* __restrict__ index )

for( int i = 0 ; i < index->size ; i++)
        for( int j = 0 ; j < index->cellSize[i] ; j++ )
            for( int k = 0 ; k < index->cellSize[i] ; k++ )
                inner(i,j,k,parts,F,W,index);

and compiled it without the forced inlining : clang test2.cpp -lstdc++ -lm -Xclang -load -Xclang /usr/local/lib/ClangEnzyme-11.so -O2 -o test2 -fno-exceptions

It seems to be roughly the opposite of Tapenade's inlining convention. In Tapenade, when you want something to be recomputed you put it inside the loop, and you extract it into a function if you want to keep values from the forward pass on the stack.

wsmoses commented 3 years ago

I'm actually a bit surprised that this worked, and will investigate.

FWIW, yesterday I also started an upgraded version of the recompute heuristic (to deal with PHINodes in LLVM parlance): https://github.com/wsmoses/Enzyme/pull/139

It's certainly not stable yet (and breaks many tests it shouldn't), but it does indeed result in the cubicSpline3d results getting recomputed. I'll work on getting that stable and into master.

Also as an FYI, a couple of other minor things that I think you may want to do (but won't change the recompute heuristic here):

Also, I think I now understand the core of the issue a bit better. It may alternatively be fixed by doing the opposite of my previous suggestion: marking cubicSpline3d as noinline.

In essence, this actually has nothing to do with loops. Rather, the result of that function becomes a phi node, which selects a value depending on the predecessor block. We currently don't recompute phi nodes for fear that they are loop induction variables; removing that restriction is what my WIP PR does.

As a result, when your code is compiled normally, the function gets inlined and you end up with a value that needs to be cached. In contrast, when it's not inlined, the value still needs to be cached in the current system. However, that cache only lives within a single call to the reverse pass of cubicSpline3d, rather than caching every forward pass before running the reverse passes.

Again, the real fix here is clearly to be more aggressive about recomputing PHINodes, but I thought I'd share this tidbit of learned wisdom.

unrealwill commented 3 years ago

Here are some additional things I'd like to do; hopefully they will work with your new recompute heuristic:

Inside the loops, I'd like to compute the gradient of the kernel function (i.e. the gradient of a cubic spline). It works without the quadratic memory issue when I write this gradient manually. But if I use Enzyme in a nested manner, it works for small n but requires additional memory and therefore crashes for large n.

Extract of the inner function :

// Squared Euclidean length of the 3-vector v: v[0]^2 + v[1]^2 + v[2]^2,
// summed left to right.
double dist3d2(  double* __restrict__ v)
{
    double acc = 0.0;
    for( const double* p = v ; p != v + 3 ; ++p )
        acc += (*p) * (*p);
    return acc;
}

// Kernel value for the separation vector v. Currently this is simply the
// squared length |v|^2; the smoothing length h is accepted but not used.
double kernel( double* __restrict__ v , double h)
{
    return dist3d2(v);
}

// Hand-written gradient of kernel() = |v|^2 with respect to v: d/dv = 2v.
// Overwrites dv[0..2]. Enzyme instead adds into its shadow argument.
// h is unused, matching kernel().
void gradKernel( double* v, double* dv, double h)
{
    for( int i = 0 ; i < 3 ; i++)
    {
        dv[i] = 2.0* v[i];
    }
}

// Enzyme-generated gradient of kernel() with respect to v, with h treated as
// a constant. Enzyme ADDS the gradient into the shadow dv rather than
// overwriting it, so the caller must zero dv[0..2] beforehand.
void autodiffgradKernel( double*__restrict__ v, double* __restrict__ dv, double h)
{
    __enzyme_autodiff(kernel,enzyme_dup,v,dv,enzyme_const, h);
}

// Body of foo()'s (cell, j, k) pair loop, extracted into its own function.
// The reporter observed that this extraction avoided the quadratic cache.
//
// For the j-th and k-th particles of cell i, forms the separation vector
// v = x_j - x_k, evaluates kernel(v, h) and its gradient dv, and accumulates
// into F[indj] (3 components) and W[indj].
void inner(int i,int j, int k, double* __restrict__ parts, double* __restrict__ F, double* __restrict__ W, Index* __restrict__ index )
{
    const int d = 3;
    double h = 1.0;

    // Global particle ids of the j-th and k-th members of cell i.
    int indj = index->argsorted[index->start[i]+j];
    int indk = index->argsorted[index->start[i]+k];

    double v[d];
    for( int l = 0 ; l < d ; l++)
        v[l] = parts[indj * d +l ] - parts[indk * d +l ];

    double wjk = kernel( v , h);

    // Enzyme adds the gradient into its shadow argument instead of
    // overwriting it, so dv must start at zero. Left uninitialised, the
    // result would depend on whatever was on the stack.
    double dv[d] = {0.0, 0.0, 0.0};

    //gradKernel( v, dv,h); //The manual version works
    autodiffgradKernel(v,dv,h); //The autodiff one need extra memory therefore crash

    for( int l = 0 ; l < d ; l++)
    {
        F[indj*d+l] += wjk * parts[indk*d+l];
    }

    // Adds wjk * (dv[0] + dv[1] + dv[2]) * x_k to F[indj], one m at a time.
    for( int m = 0 ; m < d ; m++)
    for( int l = 0 ; l < d ; l++)
    {
        F[indj*d+l] +=  wjk * dv[m]* parts[indk*d+l];
    }
    W[indj] += wjk;
}

test2.cpp

#include <stdio.h>
#include <iostream>
#include <stdlib.h>
#include <random>
#include <math.h>
#include <vector>
#include <algorithm>

using namespace std;

// Cell index over particles, filled in by buildIndex(). Particles are grouped
// into cells by the int-truncated first coordinate of each xyz triple. All
// four arrays are allocated with new[] by buildIndex(); this struct does not
// free them.
// NOTE(review): __restrict__ on members is a GCC/Clang extension, presumably
// added to help alias analysis of the differentiated code; the arrays must
// not overlap.
struct Index
{
    int* __restrict__ cellId;    // cellId[p]: cell id of particle p
    int* __restrict__ start;     // start[c]: position in argsorted of the first particle of cell c
    int* __restrict__ cellSize;  // cellSize[c]: number of particles in cell c
    int size;                    // number of cells; only start[0..size) / cellSize[0..size) are valid
    int* __restrict__ argsorted; // particle ids sorted by (cell id, particle id)
    int n;                       // number of particles; every array above has n entries
} ;

// Marker globals passed to __enzyme_autodiff to tag the following
// argument(s): enzyme_dup precedes a value and its shadow (e.g. x, d_x or
// &e, &de), and enzyme_const precedes a value that is not differentiated.
// NOTE(review): per Enzyme convention only their identity matters, not their
// value; enzyme_out is unused in this file.
int enzyme_dup;
int enzyme_out;
int enzyme_const;

// Variadic declaration so one prototype serves every call site (foo and
// kernel take different argument lists). The Enzyme plugin replaces each
// call with the generated reverse-mode derivative.
void __enzyme_autodiff(...);

// Build a cell index over n particles stored as xyz triples in `parts`
// (particle p occupies parts[3*p], parts[3*p+1], parts[3*p+2]).
//
// Each particle is assigned to the cell whose id is its first coordinate
// truncated to int. Particles are then sorted by cell id so that the members
// of each cell are contiguous in index.argsorted; cell c (0 <= c < index.size)
// spans argsorted[start[c] .. start[c] + cellSize[c]).
//
// The four arrays are allocated with new[]; the caller owns them and must
// delete[] them. start/cellSize are sized n because, in the worst case, every
// particle lands in its own cell; only the first index.size entries are used.
void buildIndex( Index& index , double * parts, int n )
{
    const int d = 3;
    index.n = n;
    index.cellId = new int[n];
    index.start = new int[n];
    index.cellSize = new int[n];
    index.argsorted = new int[n];

    // (cell id, particle id) pairs: sorting groups particles by cell and
    // keeps ties ordered by particle id.
    std::vector<std::pair<int,int> > order(n);
    for( int p = 0 ; p < n ; p++)
    {
        // Truncation toward zero, exactly like the implicit conversion.
        index.cellId[p] = static_cast<int>(parts[d*p]);
        order[p].first = index.cellId[p];
        order[p].second = p;
    }

    std::sort( order.begin(), order.end() );

    int cur = -1;        // index of the cell currently being filled
    int curCellId = 0;   // its cell id (only meaningful once cur >= 0)
    for( int i = 0 ; i < n ; i++)
    {
        index.argsorted[i] = order[i].second;
        // Open a new cell on the first particle or whenever the id changes.
        // Testing i == 0 explicitly, instead of relying on a sentinel id,
        // keeps this correct when the smallest cell id equals the sentinel
        // (e.g. -1 for a first coordinate in (-2, -1]); the sentinel version
        // wrote to cellSize[-1].
        if( i == 0 || order[i].first != curCellId )
        {
            curCellId = order[i].first;
            cur++;
            index.cellSize[cur] = 1;
            index.start[cur] = i;
        }
        else
        {
            index.cellSize[cur]++;
        }
    }
    index.size = cur+1;
}

// Squared Euclidean length of the 3-vector v: v[0]^2 + v[1]^2 + v[2]^2,
// summed left to right.
double dist3d2(  double* __restrict__ v)
{
    double acc = 0.0;
    for( const double* p = v ; p != v + 3 ; ++p )
        acc += (*p) * (*p);
    return acc;
}

// Cubic-spline smoothing kernel in 3-D with support radius 2h.
// d: distance (expected non-negative); h: smoothing length.
// With q = d / h and sigma3 = 1 / (pi h^3) the result is sigma3 * f(q):
//   f(q) = 1 - 3/2 q^2 (1 - q/2)   for q <= 1
//   f(q) = (2 - q)^3 / 4           for 1 < q <= 2
//   f(q) = 0                       otherwise (including a NaN q)
double cubicSpline3d( double d, double h )
{
    const double q = d / h;
    const double sigma3 = 1.0/ (M_PI*h*h*h);

    // Written as !(q <= 2) rather than q > 2 so a NaN q still yields 0.
    if( !(q <= 2) )
        return 0.0;

    if( q <= 1 )
        return sigma3*(1.0-3.0/2.0*q*q*(1.0-q/2.0));

    const double t = 2.0 - q;
    return sigma3/4.0*t*t*t;
}

// Kernel value for the separation vector v. Currently this is simply the
// squared length |v|^2; the smoothing length h is accepted but not used.
double kernel( double* __restrict__ v , double h)
{
    return dist3d2(v);
}

// Hand-written gradient of kernel() = |v|^2 with respect to v: d/dv = 2v.
// Overwrites dv[0..2]. Enzyme instead adds into its shadow argument.
// h is unused, matching kernel().
void gradKernel( double* v, double* dv, double h)
{
    for( int i = 0 ; i < 3 ; i++)
    {
        dv[i] = 2.0* v[i];
    }
}

// Enzyme-generated gradient of kernel() with respect to v, with h treated as
// a constant. Enzyme ADDS the gradient into the shadow dv rather than
// overwriting it, so the caller must zero dv[0..2] beforehand.
void autodiffgradKernel( double*__restrict__ v, double* __restrict__ dv, double h)
{
    __enzyme_autodiff(kernel,enzyme_dup,v,dv,enzyme_const, h);
}

// Body of foo()'s (cell, j, k) pair loop, extracted into its own function.
//
// For the j-th and k-th particles of cell i, forms the separation vector
// v = x_j - x_k, evaluates kernel(v, h) and its gradient dv, and accumulates
// into F[indj] (3 components) and W[indj].
void inner(int i,int j, int k, double* __restrict__ parts, double* __restrict__ F, double* __restrict__ W, Index* __restrict__ index )
{
    const int d = 3;
    double h = 1.0;

    // Global particle ids of the j-th and k-th members of cell i.
    int indj = index->argsorted[index->start[i]+j];
    int indk = index->argsorted[index->start[i]+k];

    double v[d];
    for( int l = 0 ; l < d ; l++)
        v[l] = parts[indj * d +l ] - parts[indk * d +l ];

    double wjk = kernel( v , h);

    // gradKernel() overwrites dv, but the Enzyme alternative below adds into
    // it. Starting from zero keeps both variants correct when switching.
    double dv[d] = {0.0, 0.0, 0.0};

    gradKernel( v, dv,h); //The manual version works
    //autodiffgradKernel(v,dv,h); //The autodiff one need extra memory therefore crash

    for( int l = 0 ; l < d ; l++)
    {
        F[indj*d+l] += wjk * parts[indk*d+l];
    }

    // Adds wjk * (dv[0] + dv[1] + dv[2]) * x_k to F[indj], one m at a time.
    for( int m = 0 ; m < d ; m++)
    for( int l = 0 ; l < d ; l++)
    {
        F[indj*d+l] +=  wjk * dv[m]* parts[indk*d+l];
    }

    W[indj] += wjk;
}

// Energy of the normalised, cell-local kernel field, written to *out.
//
// parts: n particles as xyz triples (parts[3*p + l]); index: built by
// buildIndex(). For every cell i and ordered pair (j, k) of its particles,
// inner() accumulates kernel-weighted contributions into F[indj] and the
// weight into W[indj]. F is then divided by W per particle, and the squared
// components are summed into *out.
//
// F (3*n doubles) and W (n doubles) are variable-length arrays (a GNU
// extension in C++) and therefore live on the stack; large n needs a
// correspondingly large stack, so no other per-particle scratch is reserved.
void foo( double* __restrict__ parts,int n, Index* __restrict__ index, double* __restrict__ out)
{
     *out = 0;
     const int d = 3;

     double F[n*d];
     double W[n];

     for( int i = 0 ; i < n ; i++)
         for( int j = 0 ; j < d ; j++)
             F[i*d+j] = 0.0;

     for( int i = 0 ; i < n ; i++)
         W[i] = 0.0;

     for( int i = 0 ; i < index->size ; i++)
     {
         int cellsizei = index->cellSize[i];
         for( int j = 0 ; j < cellsizei; j++ )
             for( int k = 0 ; k < cellsizei ; k++ )
                 inner(i,j,k,parts,F,W,index);
     }

     // Normalise the field. W[i] can be exactly zero: with the current
     // kernel (|v|^2, non-negative) the self-pair k == j contributes 0, so a
     // particle alone in its cell ends with W = 0 and F = 0. Skip it rather
     // than computing 0/0 = NaN, which would poison *out and its gradient.
     for( int i = 0 ; i < n ; i++)
     {
         const double wi = W[i];
         if( wi == 0.0 )
             continue;
         for( int j = 0 ; j < d ; j++)
             F[i*d+j] /= wi;
     }

     // Energy: sum of squared field components.
     for( int i = 0 ; i < n ; i++)
         for( int j = 0 ; j < d ; j++)
             *out += F[i*d+j]*F[i*d+j];
}

// Driver: generates n random particles uniformly in [0, 10)^3, builds the
// cell index, differentiates foo (output e, seed de = 1) with respect to the
// particle positions with Enzyme, and prints the first few cell ids and
// gradient entries.
int main() {
    std::mt19937 e2(42);
    std::uniform_real_distribution<> dist(0, 10);
    int n = 100000;
    int d = 3;
    double* x = new double[n*d];
    double* d_x = new double[n*d];   // shadow of x; Enzyme accumulates the gradient here
    for( int i = 0 ; i < n*d ; i++)
    {
        x[i] = dist(e2);
        d_x[i] = 0.0;
    }

    Index index;
    buildIndex(index, x, n);

    // Never print past the end of the arrays if n is lowered below 100.
    const int nprint = n < 100 ? n : 100;

    for( int i = 0 ; i < nprint ; i++)
    {
        printf("cellId[%d] = %d\n ",i, index.cellId[i]);
    }
    double e = 0.0;
    double de = 1.0;

    printf("before autodiff\n");
    __enzyme_autodiff(foo,
        enzyme_dup, x, d_x,
        enzyme_const, n,
        enzyme_const, &index,
        enzyme_dup,&e,&de);

    for( int i = 0 ; i < nprint ; i++)
    {
        printf("dx[%d] = [%f, %f, %f]\n ",i, d_x[d*i],d_x[d*i+1],d_x[d*i+2]);
    }

    // Release everything allocated with new[] here and in buildIndex().
    delete[] index.cellId;
    delete[] index.start;
    delete[] index.cellSize;
    delete[] index.argsorted;
    delete[] d_x;
    delete[] x;
    return 0;
}

Compilation with : clang test2.cpp -lstdc++ -lm -Xclang -load -Xclang /usr/local/lib/ClangEnzyme-11.so -O2 -o test2 -fno-exceptions

wsmoses commented 3 years ago

Will definitely look at it. I think the discrepancy here comes from a minor semantic difference.

Enzyme technically computes adjoints, so the code it produces is actually:

        dv[i] += 2.0* v[i];

rather than setting it. If the manual version uses +=, does it still use the small amount of memory? If so, that means there is some information in nested AD that we want to propagate and presumably currently don't.

If it uses the large amount of memory, then there are some vanilla optimizations we would want LLVM to perform that it currently does not, which would merit building them there.

unrealwill commented 3 years ago

It also works with += for the manual version. I've also added the zero initialization that I had forgotten, but it still crashes from memory allocation with the automatic gradient.

// Hand-written gradient of kernel() = |v|^2 with respect to v, in adjoint
// form: adds 2 * v[i] into dv[i], the same way Enzyme updates a shadow
// argument. The caller must zero dv first; h is unused.
void gradKernel( double* v, double* dv, double h)
{
    for( int i = 0 ; i < 3 ; i++)
    {
        dv[i] += 2.0* v[i];
    }
}
for( int i = 0 ; i < 3 ; i++)
    {
        dv[i] = 0.0;
    }
    //gradKernel( v, dv,h); //The manual version works
    autodiffgradKernel(v,dv,h);//The autodiff one need extra memory therefore crash
wsmoses commented 3 years ago

Ok, a couple of updates on independent parts of this GitHub issue:

1) The cubicSpline3d allocation will be eliminated as soon as https://github.com/wsmoses/Enzyme/pull/139 merges. That PR allows Enzyme to recompute values from non-loop-based phi nodes. It does not change the heuristic for when to do so; it simply tries to recompute whenever possible. It may be wise to change the heuristic in the future, but that is much easier once Enzyme knows how to recompute those values.

2) The nested AD issue is indeed caused by a lack of derived function attributes. In essence, after performing the innermost AD, Enzyme does not run a "detect which function attributes apply to this newly generated function" pass. Such a pass would derive, for example, that the function doesn't read/write global memory and thus won't conflict with a read in the calling function. Without it, when Enzyme performs AD on the caller, it has to be conservative: it assumes the function it generated for the inner call could read/write external memory, and so it caches values unnecessarily. This should have a relatively simple fix: running attribute inference between the levels of Enzyme AD.

unrealwill commented 3 years ago

OK Thanks !

wsmoses commented 3 years ago

OK, https://github.com/wsmoses/Enzyme/pull/141 should resolve 2) in your case.

After both PRs go in, Enzyme will have the desired memory behavior in your case, including with the nested AD call. As a bonus, you won't need to pull the inner loop body out into a separate function.