EnzymeAD / www

🌎 Enzyme project home page
https://enzyme.mit.edu

Document Split Mode #5

Open wsmoses opened 2 years ago

wsmoses commented 2 years ago

Some examples / text from a conversation with @harshithamenon that I'm copying here so we don't lose or forget them and can reincorporate them into the docs: https://fwd.gymni.ch/WfDU2f

#include <math.h>
#include <stdio.h>

#define nullptr ((void*)0)

// f(x, y) -> { sin(x), sin(x)^2 + y }
void mysimulation(double *data) {
    data[0] = sin(data[0]);
    data[1] += data[0] * data[0];
}

void __enzyme_autodiff(void (*)(double *), ...);
void* __enzyme_augmentfwd(void (*)(double *), ...);
void __enzyme_reverse(void (*)(double *), ...);

int enzyme_nofree;
int enzyme_tape;

int main()
{
    double data[2] = {2.0, 3.0};
    mysimulation(data);
    printf("vanilla run %f %f\n", data[0], data[1]);

    data[0] = 2.0; data[1] = 3.0;

    // Take the derivative wrt x, on fresh data.
    double d_data[2] = {1.0, 0.0};
    __enzyme_autodiff(mysimulation, data, d_data);
    printf("result run %f %f | derivative %f %f\n", data[0], data[1], d_data[0], d_data[1]);

    d_data[0] = 1.0; d_data[1] = 0.0;
    __enzyme_autodiff(mysimulation, data, d_data);
    printf("rerun derivative without resetting data result %f %f | derivative %f %f\n", data[0], data[1], d_data[0], d_data[1]);

    data[0] = 2.0; data[1] = 3.0;

    // Compute the original result, and also save any information
    // needed for the derivative. We don't need the shadow (derivative)
    // array here so we can pass nullptr.
    void* tape = __enzyme_augmentfwd(mysimulation, data, nullptr);
    printf("augmented run %f %f\n", data[0], data[1]);

    // Just for extra fun, overwrite the data
    data[0] = 0.0/0.0;
    data[1] = 0.0/0.0;

    // Run the reverse pass using the information saved on the tape.
    // We don't need the primal (original) array here so we can pass
    // nullptr. enzyme_nofree keeps the tape alive so it can be reused.
    d_data[0] = 1.0; d_data[1] = 0.0;
    __enzyme_reverse(mysimulation, enzyme_nofree, nullptr, d_data, tape);
    printf("reverse pass after overwriting data result %f %f | derivative %f %f\n", data[0], data[1], d_data[0], d_data[1]);

    // We can in fact compute the derivative now a second time, using the same cached
    // data and again get the same (correct) result. Now without enzyme_nofree, we'll
    // free the preserved data.
    d_data[0] = 1.0; d_data[1] = 0.0;
    __enzyme_reverse(mysimulation, nullptr, d_data, tape);
    printf("second run derivative without resetting data result %f %f | derivative %f %f\n", data[0], data[1], d_data[0], d_data[1]);

    return 0;
}
I’m not sure this is exactly what you’re thinking, but basically what happens in that snippet is we have a “simulation” that overwrites data in place as f(x, y) -> { sin(x), sin(x)^2 + y }.
If you first run the usual combined mode (__enzyme_autodiff), it computes the original result and the derivative in one call. Because the simulation overwrites data, asking for the derivative again without resetting data gives you the derivative at the second time step, that is, at the already-updated state.
If instead you use the __enzyme_augmentfwd function (which runs the original computation and preserves the values needed to compute the derivative at a later time), you can arbitrarily overwrite the original data (e.g. here setting it to NaN) and still get the correct derivative from __enzyme_reverse. In fact, you can use that same “tape” or “cache” a second time and again get that same correct value.
I’m not sure this is 100% the same thing you’re thinking of, since we explicitly preserve only the values needed to compute the derivative rather than everything needed to restart the simulation (it’s desirable to preserve less).
What would happen if the application exits and restarts from some state stored in a file? Wouldn't you need some hook into the application to specify how to reconstruct the tape?

It’s just data: you could save and load the X bytes it needs to/from a file (there’s a separate function to get the size of the tape). You can also tell Enzyme to store and/or load the tape data at a specific pointer.

Here’s the libCEED code (a forward-mode version of this), which just stores the tape at a memory location provided by libCEED: https://github.com/CEED/libCEED/blob/99e8d5bed2f93906167219e46a292a0381310e8b/examples/solids/qfunctions/finite-strain-neo-hookean-initial-ad.h#L175
CEED_QFUNCTION_HELPER void S_fwd(double *S, double *E, const double lambda,
                                 const double mu, double *tape) {
  int tape_bytes = __enzyme_augmentsize((void *)computeS, enzyme_dup, enzyme_dup,
                                        enzyme_const, enzyme_const);
  __enzyme_augmentfwd((void *)computeS, enzyme_allocated, tape_bytes,
                      enzyme_tape, tape, enzyme_nofree, S, (double *)NULL, E, (double *)NULL,
                      enzyme_const, lambda, enzyme_const, mu);
}

CEED_QFUNCTION_HELPER void grad_S(double *dS, double *dE, const double lambda,
                                  const double mu, const double *tape) {
  int tape_bytes = __enzyme_augmentsize((void *)computeS, enzyme_dup, enzyme_dup,
                                        enzyme_const, enzyme_const);
  __enzyme_fwdsplit((void *)computeS, enzyme_allocated, tape_bytes,
                    enzyme_tape, tape, (double *)NULL, dS, (double *)NULL, dE,
                    enzyme_const, lambda, enzyme_const, mu);
}
wsmoses commented 1 year ago

@samuelpmishLLNL perhaps also useful to add to docs if you have cycles?

samuelpmishLLNL commented 1 year ago

I'm all for improving documentation, but I don't really understand what's going on in this discussion, so I'm probably not the right person to write up how it works!