EnzymeAD / Enzyme

High-performance automatic differentiation of LLVM and MLIR.
https://enzyme.mit.edu

Generating Jacobian matrix from a vector valued subroutine written in Fortran #954

Closed agoodm closed 1 year ago

agoodm commented 1 year ago

Hi there, I have just recently learned about the Enzyme project. I have read through some of the presentations and thought that it could be a very promising tool for my use case. One of the bits that caught my interest was the mention that it could work with any language that has LLVM support. Fortran is of interest in particular to me because my use-case involves running a fairly old Fortran 77 codebase for a radiative transfer model. The inputs are arrays for multiple physical quantities that are defined over a set of atmospheric layers, so the subroutine signature looks something like this:

subroutine calc_rad(T, H2O, CO2, ..., rad)
real, intent(in), dimension(nlayer) :: T
real, intent(in), dimension(nlayer) :: H2O
real, intent(in), dimension(nlayer) :: CO2
...
real, intent(out), dimension(nchan) :: rad
...

I wish to write another subroutine which computes the associated Jacobian matrix (nargs x nlayer rows by nchan columns) for this. Right now we are just manually calculating it via finite differences, but due to the large size of the Jacobian matrix this has a huge computational cost, since it requires running the model hundreds of times. I have read that some examples have been done with Fortran, but all the examples I could find in the documentation were scalar functions in C++, so I am not quite sure where to start. I am currently thinking it should look something like this:

subroutine calc_rad_jac(T, H2O, CO2, ..., rad)
real, intent(in), dimension(nlayer) :: T
real, intent(in), dimension(nlayer) :: H2O
real, intent(in), dimension(nlayer) :: CO2
...
real, intent(out), dimension(nchan) :: rad
real, intent(out), dimension(m, nchan) :: rad_jac ! m = nargs x nlayer
...
call __enzyme_autodiff(calc_rad, T, d_T, H2O, d_H2O, CO2, d_CO2, ..., rad, rad_jac)

Can I get some confirmation that I am taking the correct approach, and if not what changes would I need to make this work?
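
For context, the finite-difference baseline described above can be sketched as follows. This is only an illustration in Python: `model` is a hypothetical three-input stand-in for `calc_rad`, and `make_counted` and `fd_jacobian` are names invented here. The point is that a one-sided finite-difference Jacobian costs one model run per input element, plus one for the unperturbed state.

```python
def make_counted(func):
    """Wrap func so the number of evaluations can be inspected."""
    calls = {"n": 0}

    def wrapped(x):
        calls["n"] += 1
        return func(x)

    return wrapped, calls


def model(x):
    """Hypothetical stand-in for calc_rad: 3 inputs, 3 outputs."""
    return [x[0] * x[1], x[1] + x[2], x[0] * x[2]]


def fd_jacobian(func, x, eps=1e-6):
    """One-sided finite differences; jac[i][k] approximates d out[i] / d x[k]."""
    base = func(x)
    jac = [[0.0] * len(x) for _ in base]
    for k in range(len(x)):
        xp = list(x)
        xp[k] += eps              # perturb a single input element
        pert = func(xp)           # one full model run per input element
        for i in range(len(base)):
            jac[i][k] = (pert[i] - base[i]) / eps
    return jac
```

With nargs x nlayer input elements, this loop runs the model nargs x nlayer + 1 times, which is the cost the question is trying to avoid.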

tgymnich commented 1 year ago

A good starting point would be to calculate the Jacobian row-wise using reverse mode inside a loop. If that works, you can try replacing the loop with batched reverse mode. Your sketch is looking good; just make sure to seed the shadows of your parameters appropriately.

Here is some pseudo-code:

subroutine calc_rad_jac(T, H2O, CO2, ..., rad)
real, intent(in), dimension(nlayer) :: T
real, intent(in), dimension(nlayer) :: H2O
real, intent(in), dimension(nlayer) :: CO2
...
real, intent(out), dimension(nchan) :: rad
real, intent(out), dimension(nchan) :: d_rad
real, intent(out), dimension(m, nchan) :: rad_jac ! m = nargs x nlayer
...

// set parameter shadow arrays to all zeros.
d_T = 0;
d_CO2 = 0;
d_H2O = 0;
d_rad = 0;

for i in range(0,nargs):
     d_rad[i] = 1
     call __enzyme_autodiff(calc_rad, T, d_T, H2O, d_H2O, CO2, d_CO2, ..., rad, d_rad)
     rad_jac[i, :] = [d_T, d_H2O, d_CO2, ...]  // put d_T, d_H2O, d_CO2, ... in the jacobian

     d_rad[i] = 0 // set this back to 0 so we can re-use the array in the next iteration
     d_T = 0;
     d_CO2 = 0;
     d_H2O = 0;
     d_rad = 0;

agoodm commented 1 year ago

Thanks @tgymnich !

Some follow-up comments and questions:

1) In your pseudo-code you loop over nargs, but I actually want to loop nlayers times since each of the elements of the arrays corresponding to nargs represents an element of the state-vector for the purposes of the model, so the Jacobian must therefore have nargs x nlayers rows. I presume then the parameter shadow args (d_T, d_H2O, d_CO2, etc) each should be length nchan (since this is reverse mode AD), while d_rad is length nlayers?

2) In principle I would prefer to do forward mode AD since for my use-case nchan > nlayer x nargs. Can enzyme support this, and if so how should I modify the above example?

tgymnich commented 1 year ago
  1. In your pseudo-code you loop over nargs, but I actually want to loop nlayers times since each of the elements of the arrays corresponding to nargs represents an element of the state-vector for the purposes of the model, so the Jacobian must therefore have nargs x nlayers rows.

You are right. range(0,nargs) was meant to be 0...num_jac_rows.

I presume then the parameter shadow args (d_T, d_H2O, d_CO2, etc) each should be length nchan (since this is reverse mode AD), while d_rad is length nlayers?

In general all shadows should have the same length as the values they are shadowing. After each iteration the shadows of the inputs will contain the derivatives with respect to the one vector element of the output vector which is non-zero.

  2. In principle I would prefer to do forward mode AD since for my use-case nchan > nlayer x nargs. Can enzyme support this, and if so how should I modify the above example?

For forward mode you can call __enzyme_fwddiff instead. The arguments should be the same. The Jacobian will now be calculated column-wise: in your loop, set the shadow entry of the one input element with respect to which you want to differentiate to 1 (leaving all other input shadow entries at 0); the output shadow (d_rad) will then contain the derivative of each output element with respect to that input element.

Once that works it should be pretty straightforward to go to batched forward mode.
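
The two seeding patterns discussed above can be illustrated with a small Python sketch. It does not call Enzyme: `jvp` and `vjp` are hand-written derivatives of a toy function, out[0] = x[0] and out[i] = x[i] + x[i-1]/x[i], standing in for what forward and reverse mode would generate, and all names here are hypothetical. A one-hot input shadow in forward mode yields one Jacobian column; a one-hot output shadow in reverse mode yields one Jacobian row.

```python
def jvp(x, v):
    """Forward mode by hand: derivative of the outputs along input direction v."""
    n = len(x)
    dout = [0.0] * n
    dout[0] = v[0]
    for i in range(1, n):
        dout[i] = v[i] + v[i - 1] / x[i] - x[i - 1] * v[i] / x[i] ** 2
    return dout


def vjp(x, w):
    """Reverse mode by hand: accumulate w-weighted output derivatives into dx."""
    n = len(x)
    dx = [0.0] * n                 # input shadow starts at zero
    dx[0] += w[0]
    for i in range(1, n):
        dx[i] += w[i] * (1.0 - x[i - 1] / x[i] ** 2)
        dx[i - 1] += w[i] / x[i]
    return dx


def one_hot(n, k):
    return [1.0 if j == k else 0.0 for j in range(n)]


def jacobian_forward(x):
    """One pass per input element; each pass gives a column d out[:] / d x[k]."""
    n = len(x)
    cols = [jvp(x, one_hot(n, k)) for k in range(n)]
    return [[cols[k][i] for k in range(n)] for i in range(n)]


def jacobian_reverse(x):
    """One pass per output element; each pass gives a row d out[i] / d x[:]."""
    n = len(x)
    return [vjp(x, one_hot(n, i)) for i in range(n)]
```

Both functions return the same matrix, indexed as jac[output][input]. Which one is cheaper depends on whether there are fewer inputs (forward) or fewer outputs (reverse).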

agoodm commented 1 year ago

Thank you for clarifying and I realize that for reverse AD I would actually loop over nchan, while for forward AD it would be nlayer. To keep things simple I think for my first attempt it would probably be easiest to build the Jacobians for each input variable array separately and stack them together at the end.

One last thing, I have never used an LLVM Fortran compiler and I know of two options (flang, lfortran). Would either one be better than the other for using enzyme, and are there some important pitfalls I should be aware of? I do know that LLVM support is less complete for Fortran than for C/C++.

wsmoses commented 1 year ago

One comment here: often use cases actually only need a Jacobian-vector or vector-Jacobian product, which can be done with a single call to either forward or reverse mode.

There are several options for LLVM-based Fortran, in varying degrees of completion. Unless you already work in an lfortran environment, I usually point folks to either classic flang or the new MLIR-based flang. Either can work, but they come with different setup challenges and differ in how robustly they support the language. The new flang is where most development is being done nowadays by the compiler devs, whereas classic flang is more robust (but has performance and other concerns).
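
To make the point about Jacobian-vector products concrete, here is a hedged Python sketch that never forms the Jacobian. The toy function `f` and the helper `directional_fd` are illustrative names, not part of Enzyme. The sketch estimates J(x)·v with two function evaluations, independent of the number of inputs; a single forward-mode call returns the same quantity without the finite-difference truncation error.

```python
def f(x):
    """Toy function: out[i] = x[i] + prev / x[i], where prev is the previous input."""
    out, prev = [], 0.0
    for xi in x:
        out.append(xi + prev / xi)
        prev = xi
    return out


def directional_fd(func, x, v, eps=1e-6):
    """Central-difference estimate of J(x) @ v using two evaluations of func."""
    xp = [a + eps * b for a, b in zip(x, v)]
    xm = [a - eps * b for a, b in zip(x, v)]
    return [(p - m) / (2.0 * eps) for p, m in zip(func(xp), func(xm))]
```

For example, `directional_fd(f, [1.0, 4.0, 9.0, 16.0], [1.0] * 4)` approximates the sum of each Jacobian row, which is all a solver needs if it only ever multiplies the Jacobian by a vector.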

agoodm commented 1 year ago

@wsmoses Thanks!

ludgerpaehler commented 1 year ago

@agoodm if you feel lost in LLVM Fortran + Enzyme, I am always happy to quickly jump on a call and help you get set up.

agoodm commented 1 year ago

EDIT: I fixed the problem I had below once I noticed that I was setting my temporary variable to the wrong type (int). Setting it to double fixed it.

So an update, to help get myself fully up to speed I am currently trying to implement a relatively simple example in C on my local machine first since I could easily install clang, llvm, and enzyme via homebrew. Here is the example:

// test.c
extern double __enzyme_fwddiff(void*, double[100], double[100], double[100], double[100]);
void f(double x[100], double out[100]) {
    int prev = 0;
    for(int i = 0; i < 100; i++) {
        out[i] = x[i] + prev/x[i];
        prev = x[i];
    }
}
void jvpf(double x[100], double v[100], double out[100], double dout[100]) {
    __enzyme_fwddiff((void*)f, x, v, out, dout);
}

Then compiled as follows:

clang test.c -S -emit-llvm -o input.ll -O2 -fno-vectorize -fno-slp-vectorize -fno-unroll-loops
opt input.ll -load=/opt/homebrew/Cellar/enzyme/0.0.45/lib/LLVMEnzyme-15.dylib -enzyme -o output.ll -S -enable-new-pm=0
opt output.ll -O2 -o output_opt.ll -S
clang output_opt.ll -dynamiclib -o libtest.dylib

Now I am running it in python via ctypes and comparing the output I get from jax:

import ctypes
import jax
import numpy as np

lib = ctypes.CDLL('/path/to/libtest.dylib')

# Enzyme JVP
def jvp(x, v):
    out = np.zeros(100)
    dout = np.zeros(100)
    args = []
    for a in [x, v, out, dout]:
        args.append(a.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))
    lib.jvpf(*args)
    return dout

# Equivalent to f
@jax.jit
def f(x):
    prev = 0
    out = []
    for i in range(100):
        out.append(x[i] + prev/x[i])
        prev = x[i]
    out = jax.numpy.array(out)
    return out

When I run this I find that the enzyme implementation is much faster than jax even with the ctypes overhead included but the output isn't quite right. For example:

In [5]: x = np.arange(1, 101, dtype='float64') ** 2

In [6]: v = np.ones(100)

In [7]: jvp(x, v)
Out[7]: 
array([1.        , 0.9375    , 0.95061728, 0.96484375, 0.9744    ,
       0.98070988, 0.98500625, 0.98803711, 0.99024539, 0.9919    ,
       0.99316987, 0.99416474, 0.99495816, 0.99560079, 0.9961284 ,
       0.99656677, 0.9969349 , 0.99724699, 0.99751383, 0.99774375,
       0.99794324, 0.99811744, 0.99827045, 0.99840555, 0.99852544,
       0.99863231, 0.99872799, 0.99881397, 0.99889153, 0.99896173,
       0.99902547, 0.99908352, 0.99913654, 0.99918509, 0.99922965,
       0.99927067, 0.99930849, 0.99934345, 0.99937582, 0.99940586,
       0.99943378, 0.99945978, 0.99948403, 0.99950668, 0.99952788,
       0.99954773, 0.99956637, 0.99958387, 0.99960033, 0.99961584,
       0.99963046, 0.99964426, 0.99965731, 0.99966965, 0.99968133,
       0.99969241, 0.99970292, 0.9997129 , 0.99972238, 0.9997314 ,
       0.99973999, 0.99974818, 0.99975598, 0.99976343, 0.99977054,
       0.99977734, 0.99978383, 0.99979005, 0.999796  , 0.99980171,
       0.99980718, 0.99981242, 0.99981745, 0.99982229, 0.99982693,
       0.9998314 , 0.99983569, 0.99983982, 0.9998438 , 0.99984763,
       0.99985132, 0.99985488, 0.99985832, 0.99986163, 0.99986483,
       0.99986792, 0.9998709 , 0.99987379, 0.99987657, 0.99987927,
       0.99988188, 0.99988441, 0.99988685, 0.99988922, 0.99989152,
       0.99989374, 0.9998959 , 0.99989799, 0.99990002, 0.99990199])

In [8]: jax.jvp(f, [x], [v])[1]
Out[8]: 
Array([1.       , 1.1875   , 1.0617284, 1.0273438, 1.0144   , 1.0084877,
       1.0054144, 1.0036621, 1.002591 , 1.0019   , 1.0014343, 1.0011091,
       1.0008754, 1.0007029, 1.0005728, 1.000473 , 1.000395 , 1.0003334,
       1.000284 , 1.0002438, 1.0002108, 1.0001836, 1.0001608, 1.0001416,
       1.0001254, 1.0001116, 1.0000998, 1.0000895, 1.0000806, 1.0000728,
       1.000066 , 1.0000601, 1.0000548, 1.0000502, 1.000046 , 1.0000423,
       1.000039 , 1.000036 , 1.0000333, 1.0000309, 1.0000286, 1.0000267,
       1.0000249, 1.0000232, 1.0000217, 1.0000203, 1.0000191, 1.0000179,
       1.0000168, 1.0000159, 1.0000149, 1.0000141, 1.0000134, 1.0000126,
       1.0000119, 1.0000113, 1.0000107, 1.0000101, 1.0000097, 1.0000092,
       1.0000087, 1.0000083, 1.000008 , 1.0000076, 1.0000073, 1.0000069,
       1.0000066, 1.0000063, 1.0000061, 1.0000058, 1.0000056, 1.0000054,
       1.0000051, 1.0000049, 1.0000048, 1.0000045, 1.0000044, 1.0000042,
       1.000004 , 1.0000039, 1.0000037, 1.0000036, 1.0000035, 1.0000033,
       1.0000032, 1.0000031, 1.000003 , 1.0000029, 1.0000029, 1.0000027,
       1.0000026, 1.0000025, 1.0000025, 1.0000024, 1.0000023, 1.0000023,
       1.0000021, 1.0000021, 1.000002 , 1.000002 ], dtype=float32)

It seems jax is able to handle storing the input in temporary variables but enzyme can't. (Important lesson I learned: be careful to ensure any temporaries match the type of the input, or else the AD will fail to propagate properly!) When I rewrite my example to remove the temporary variable it works:

// test.c
extern double __enzyme_fwddiff(void*, double[100], double[100], double[100], double[100]);
void f(double x[100], double out[100]) {
    out[0] = x[0];
    for(int i = 1; i < 100; i++) {
        out[i] = x[i] + x[i-1]/x[i];
    }
}
void jvpf(double x[100], double v[100], double out[100], double dout[100]) {
    __enzyme_fwddiff((void*)f, x, v, out, dout);
}

Are there any good workarounds for this? The model's codebase I will need for my use case makes use of temporary variables in multiple spots and is much larger than this example, so going around and modifying them all would be difficult.
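
The two outputs above can be reproduced by hand, which helps show where they diverge. The Python sketch below is not Enzyme; `jvp_f` is a hypothetical hand-written forward-mode derivative of `f`. It assumes that an integer temporary carries no derivative, so the tangent of `prev` is dropped whenever `prev` is stored as an int. Under that assumption the first few values match the Enzyme output, and with a floating-point temporary they match the JAX output.

```python
def jvp_f(x, v, integer_prev):
    """Hand-written forward-mode derivative of out[i] = x[i] + prev / x[i].

    integer_prev=True models `int prev`: the value is truncated and its
    tangent is treated as zero. integer_prev=False models `double prev`.
    """
    prev, dprev = 0.0, 0.0
    dout = []
    for xi, vi in zip(x, v):
        # d/dt (xi + prev / xi) = vi + dprev / xi - prev * vi / xi**2
        dout.append(vi + dprev / xi - prev * vi / xi ** 2)
        if integer_prev:
            prev, dprev = float(int(xi)), 0.0   # derivative information is lost here
        else:
            prev, dprev = xi, vi                # tangent follows the value
    return dout
```

With x = [1, 4, 9, 16, 25] and v all ones, the integer variant gives 0.9375 for the second element (1 - 1/16), while the floating-point variant gives 1.1875 (1 + 1/4 - 1/16). The missing 1/4 is exactly the contribution that flows through `prev`.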

agoodm commented 1 year ago

So I managed to get forward mode fully working including a function for generating the full Jacobian matrix:

// test.c
int enzyme_dupnoneed;
void __enzyme_fwddiff(void*, ...);
void __enzyme_autodiff(void*, ...);
void f(double x[100], double out[100]) {
    double prev = 0;
    for(int i = 0; i < 100; i++) {
        out[i] = x[i] + prev/x[i];
        prev = x[i];
    }
}

void jvpf(double x[100], double v[100], double out[100], double dout[100]) {
    __enzyme_fwddiff((void*)f, x, v, enzyme_dupnoneed, out, dout);
}

void jacfwdf(double x[100], double out[100], double jac[100][100]) {
    double v[100] = {0};
    for(int i = 0; i < 100; i++) {
        v[i] = 1;
        __enzyme_fwddiff((void*)f, x, v, enzyme_dupnoneed, out, jac[i]);
        v[i] = 0;
    }
}

But when I try doing reverse mode AD, it fails:

void vjpf(double x[100], double dx[100], double out[100], double v[100]) {
    __enzyme_autodiff((void*)f, x, dx, enzyme_dupnoneed, out, v);
}

void jacrevf(double x[100], double out[100], double jac[100][100]) {
    double v[100] = {0};
    for(int i = 0; i < 100; i++) {
        v[i] = 1;
        __enzyme_autodiff((void*)f, x, jac[i], enzyme_dupnoneed, out, v);
        v[i] = 0;
    }
}

The failure happens when running the enzyme plugin on the IR:

Assertion failed: (NumContainedTys && "Attempting to get element type of opaque pointer"), function getNonOpaquePointerElementType, file Type.h, line 390.
PLEASE submit a bug report to https://github.com/Homebrew/homebrew-core/issues and include the crash backtrace.
Stack dump:
0.  Program arguments: opt input.ll -load=/opt/homebrew/Cellar/enzyme/0.0.45/lib/LLVMEnzyme-15.dylib -enzyme -o output.ll -S -enable-new-pm=0
1.  Running pass 'Enzyme Pass' on module 'input.ll'.
Stack dump without symbol names (ensure you have llvm-symbolizer in your PATH or set the environment var `LLVM_SYMBOLIZER_PATH` to point to it):
0  libLLVM.dylib            0x000000010cec24a0 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) + 56
1  libLLVM.dylib            0x0000000110366db8 SignalHandler(int) + 304
2  libsystem_platform.dylib 0x000000019cba74a4 _sigtramp + 56
3  libsystem_pthread.dylib  0x000000019cb8fee0 pthread_kill + 288
4  libsystem_c.dylib        0x000000019caca340 abort + 168
5  libsystem_c.dylib        0x000000019cac9754 err + 0
6  LLVMEnzyme-15.dylib      0x00000001058a1b6c llvm::Type::getNonOpaquePointerElementType() const (.cold.2) + 0
7  LLVMEnzyme-15.dylib      0x00000001056af12c std::__1::vector<int, std::__1::allocator<int>>::reserve(unsigned long) + 0
8  LLVMEnzyme-15.dylib      0x000000010585b50c CreateAllocation(llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter>&, llvm::Type*, llvm::Value*, llvm::Twine, llvm::CallInst**, llvm::Instruction**, bool) + 360
9  LLVMEnzyme-15.dylib      0x00000001056c44b4 CacheUtility::createCacheForScope(CacheUtility::LimitContext, llvm::Type*, llvm::StringRef, bool, bool, llvm::Value*) + 600
10 LLVMEnzyme-15.dylib      0x0000000105825b30 GradientUtils::ensureLookupCached(llvm::Instruction*, bool, llvm::BasicBlock*, llvm::MDNode*) + 204
11 LLVMEnzyme-15.dylib      0x00000001058372e4 GradientUtils::lookupM(llvm::Value*, llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter>&, llvm::ValueMap<llvm::Value const*, llvm::WeakTrackingVH, llvm::ValueMapConfig<llvm::Value const*, llvm::sys::SmartMutex<false>>> const&, bool) + 14868
12 LLVMEnzyme-15.dylib      0x000000010578f79c AdjointGenerator<AugmentedReturn const*>::lookup(llvm::Value*, llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter>&) + 84
13 LLVMEnzyme-15.dylib      0x000000010578c18c AdjointGenerator<AugmentedReturn const*>::createBinaryOperatorAdjoint(llvm::BinaryOperator&) + 4860
14 LLVMEnzyme-15.dylib      0x000000010578ae30 AdjointGenerator<AugmentedReturn const*>::visitBinaryOperator(llvm::BinaryOperator&) + 256
15 LLVMEnzyme-15.dylib      0x0000000105701678 EnzymeLogic::CreatePrimalAndGradient(ReverseCacheKey const&&, TypeAnalysis&, AugmentedReturn const*, bool) + 9544
16 LLVMEnzyme-15.dylib      0x00000001056dc3d4 (anonymous namespace)::Enzyme::HandleAutoDiff(llvm::CallInst*, llvm::TargetLibraryInfo&, DerivativeMode, bool) + 10072
17 LLVMEnzyme-15.dylib      0x00000001056d7bc4 (anonymous namespace)::Enzyme::lowerEnzymeCalls(llvm::Function&, bool&, std::__1::set<llvm::Function*, std::__1::less<llvm::Function*>, std::__1::allocator<llvm::Function*>>&) + 8768
18 LLVMEnzyme-15.dylib      0x00000001056d35b8 (anonymous namespace)::Enzyme::runOnModule(llvm::Module&) + 8528
19 libLLVM.dylib            0x000000010d05e73c llvm::legacy::PassManagerImpl::run(llvm::Module&) + 1244
20 opt                      0x0000000104e0f734 main + 11116
21 dyld                     0x00000001051e108c start + 520

Is there also any information on how to use batch/vector mode? I see that there is a supported signature in the API for __enzyme_batch but it is unclear how to use it.

Thanks!

jdoerfert commented 1 year ago

Use -Xclang -no-opaque-pointers for your clang until Enzyme adapts to the "opaque pointer" changes.

tgymnich commented 1 year ago

Here is your example with vector mode:

void jacfwdvecf(double const* __restrict x, double* __restrict jac) {
    double out[100] = {0};
    double v[100][100] = {0};
    for(int i = 0; i < 100; i++)
        v[i][i] = 1;

    __enzyme_fwddiff((void*)f, enzyme_width, 100, 
                     enzyme_dupv, sizeof(double) * 100, x, v, 
                     enzyme_dupnoneedv, sizeof(double) * 100, out, jac);
}

https://fwd.gymni.ch/y12WMv

For forward vector mode you need to pass the vector width at compile time using enzyme_width. The annotation enzyme_dup becomes enzyme_dupv and enzyme_dupnoneed becomes enzyme_dupnoneedv. Both now require a trailing integer which specifies the byte size of one vector (does not need to be a compile time constant). It is also possible to just pass N buffers for each argument to Enzyme instead of using enzyme_dupv or enzyme_dupnoneedv.

Be mindful that the amount of generated IR/code in Enzyme is currently proportional to the vector width. Thus we recommend setting the vector width to something close to the width of your system architecture.

agoodm commented 1 year ago

I can now run the reverse AD functions with those extra compiler flags added. The forward mode vector example compiles fine and seems to be at least an order of magnitude faster than the looped version; however, the resulting Jacobian is all zeros. Any ideas?

tgymnich commented 1 year ago

Can you share your code? How does it differ from this: https://fwd.gymni.ch/y12WMv?

agoodm commented 1 year ago

My code was basically the same as yours but I didn't have a main function since I was testing it from inside python through ctypes. So I added the main function and found that when running it as an executable, it printed out the correct values of the Jacobian. However when I called the function externally in python through ctypes again I was still getting a zero Jacobian matrix. Then I looked at your version of jacfwdf and noticed the signature was a little different and didn't pass the output array out used for f, instead initializing it from inside the function. Calling your version of jacfwdf without the output array in the signature through ctypes also resulted in incorrect output, so I realized this was probably the culprit. After adding the out array as a dummy argument to jacfwdvecf like so:

void jacfwdvecf(double const* __restrict x, double* __restrict out, double* __restrict jac)

I now get the correct output when calling it through ctypes. Note that ultimately for my use-case I will be calling this code through python so it is essential for me to get this working correctly through ctypes.

The remaining downside is that after comparing the running times of both the vector and loop versions of the Jacobian calculation, I found that they were now identical. For the added complexity of calling __enzyme_fwddiff in vector mode, I was hoping this would give me some additional performance gain, but I am guessing it will probably still be a substantial improvement over what I currently have regardless, and I could also try reworking my use-case to use VJP/JVPs instead of the full Jacobian matrix.

I have also finally finished getting llvm/flang set up from source, so I should be ready to move on to trying this all in Fortran. I will let you know if I encounter any issues. Thanks for all the help!

tgymnich commented 1 year ago

High vector widths can cause a lot of register pressure. Have you tried using vector widths of 2, 4, 8, or 16?

agoodm commented 1 year ago

I realized I made a slight mistake when comparing the running times for each method: with a vector width of 100 it's actually twice as slow as running the looped version (30 us vs 15 us). So I played with the vector width and realized that setting it to less than 100 in this case fills in only as many rows of the Jacobian as the width. Covering all 100 rows therefore takes several calls, one per block of rows. This was tricky but I managed to get it working:

void jacfwdvecf(double x[100], double out[100], double jac[100][100]) {
    double v[100][100] = {0};
    const int nblocks = 25;
    const int block_size = 4;
    for(int block = 0; block < nblocks; block++) {
        int chunk = block*block_size;
        for(int i = 0; i < block_size; i++) {
            v[i][i+chunk] = 1;
        }
        __enzyme_fwddiff((void*)f, enzyme_width, block_size, 
                        enzyme_dupv, sizeof(double) * 100, x, v, 
                        enzyme_dupnoneedv, sizeof(double) * 100, out, jac[chunk]);
        for(int i = 0; i < block_size; i++) {
            v[i][i+chunk] = 0;
        }        
    }
}

Out of all the vector widths I tried, 4 gave me the best results at 10 us which is now a solid enough improvement to the looped version (15 us). Though I wonder how the results would change if my example computed a more dense Jacobian matrix.
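
The blocked seeding scheme above can be checked independently of Enzyme with a short Python sketch. Here `jvp_batch` is a hypothetical hand-written stand-in for a vector-mode forward call on the same `f`: it carries `width` tangent vectors through one sweep. `jacobian_blocked` then seeds one block of identity rows at a time, mirroring the `v[i][i+chunk] = 1` loop. The result uses the same layout as `jac` in the C code, so row k holds the derivatives of all outputs with respect to x[k].

```python
def jvp_batch(x, tangents):
    """One sweep over f carrying len(tangents) tangent vectors at once."""
    n, width = len(x), len(tangents)
    douts = [[0.0] * n for _ in range(width)]
    prev = 0.0
    for i in range(n):
        for b in range(width):
            t = tangents[b]
            dprev = t[i - 1] if i > 0 else 0.0   # tangent of the previous input
            douts[b][i] = t[i] + dprev / x[i] - prev * t[i] / x[i] ** 2
        prev = x[i]
    return douts


def jacobian_blocked(x, width):
    """Assemble jac[k][i] = d out[i] / d x[k], `width` rows per sweep."""
    n = len(x)
    assert n % width == 0, "sketch assumes the width divides the input size"
    jac = []
    for chunk in range(0, n, width):
        # Seed rows chunk .. chunk + width - 1 of the identity matrix.
        seeds = [[1.0 if j == chunk + b else 0.0 for j in range(n)]
                 for b in range(width)]
        jac.extend(jvp_batch(x, seeds))
    return jac
```

Any block width that divides the input size produces the same matrix; the width only changes how many sweeps are needed and how much work each sweep does. On the density question: for this `f` each output depends on at most two inputs, so most of every tangent vector is zero.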