EnzymeAD / Enzyme

High-performance automatic differentiation of LLVM and MLIR.
https://enzyme.mit.edu

sqrt behavior at 0 #1295

Closed: Bike closed this issue 1 year ago

Bike commented 1 year ago

My group is getting started trying out Enzyme, and I noticed some unintuitive behavior on square roots: the derivative of sqrt at zero seems to be specially computed as zero. __enzyme_autodiff(my_sqrt, x) ends up with an fcmp and select to essentially do (x = 0.0) ? 0.0 : 1/(sqrt(2.0*x)), instead of what I expected, 1/(sqrt(2.0*x)).

This is, unless I'm missing something, mathematically incorrect - the derivative at zero should be undefined, as the function is not differentiable there (at least on reals), or positive infinity if you take the limit from above. But I'm less concerned about that and more about the decrease in efficiency from the branch. It seems to slow things down by around 40% in a very quick benchmark, versus the manual 1/sqrt(2.0*x). Here's the explorer showing a ucomisd -> je sequence.
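
For concreteness, a minimal sketch of the setup, assuming my_sqrt computes sqrt(2.0*x) (which is consistent with the derivative expressions above; the actual function isn't reproduced here):

```c
#include <math.h>

// Assumed definition, consistent with the 1/sqrt(2.0*x) derivative quoted above.
double my_sqrt(double x) { return sqrt(2.0 * x); }

// Enzyme's autodiff entry point.
extern double __enzyme_autodiff(void*, double);

double d_my_sqrt(double x) {
    // Enzyme currently emits roughly (x == 0.0) ? 0.0 : 1.0 / sqrt(2.0 * x),
    // i.e. an fcmp + select guarding the division.
    return __enzyme_autodiff((void*)my_sqrt, x);
}

// The hand-written derivative used for comparison, with no branch.
double d_my_sqrt_manual(double x) { return 1.0 / sqrt(2.0 * x); }
```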

I am not sure that this is a bug since it appears to be a deliberate design decision, with similar logic apparent e.g. in this benchmark. But I am at least interested in knowing why it was designed this way, and whether there is any way to configure this away. I couldn't find any mention of this behavior on the website, the arxiv preprint, or the code. We would like to use Enzyme rather than the symbolic differentiation we are doing now but this is a noticeable decrease in performance for us.

ZuseZ4 commented 1 year ago

Hi @Bike, just regarding your question about the documentation: if you are interested in the rules we have, this is the best place to check: https://github.com/EnzymeAD/Enzyme/blob/3c0014a1f99fd573bbb21fb69b6d699a8d47a679/enzyme/Enzyme/InstructionDerivatives.td#L785

For older or more complex rules we also have https://github.com/EnzymeAD/Enzyme/blob/main/enzyme/Enzyme/AdjointGenerator.h, but they might be a lot harder to understand.

wsmoses commented 1 year ago

If I recall correctly that is there for correctness around edge points.

Incidentally, we may now have a better way of handling this (the checkeddiv inst), which is conditionally activated by a flag when that behavior is necessary.

As Manuel mentioned, our derivative rules for things like this are simply specified in that file; if you'd like to get into Enzyme dev, it wouldn't be hard to open a PR and update it!

wsmoses commented 1 year ago

Also, the test you mention is actually the derivative generated by Tapenade, another AD tool, which also has similar handling of these points.

tgymnich commented 1 year ago

@Bike some of the reasoning behind this is outlined in this paper: https://proceedings.neurips.cc/paper/2020/file/4aaa76178f8567e05c8e8295c96171d8-Paper.pdf

Bike commented 1 year ago

@tgymnich Thank you, I will read it. I suspect that the considerations for ML are not so relevant for us (we're doing physics simulation), but I'm glad to know the ideas.

@ZuseZ4 @wsmoses Thanks for the tip. I patched this code to see how things would go without the zero test, and after optimization it ends up just as fast as my manual version. I assume Enzyme would not accept a PR to remove the test as that would result in bad behavior on the edge. As a more configurable solution for us, is it perhaps possible to manually specify the derivative of a function? Then we could have our code use my_sqrt, and tell Enzyme that the derivative of my_sqrt is the untested division.
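
Something like the following is what I have in mind - a rough forward-mode sketch, assuming the __enzyme_register_derivative global-array registration pattern from Enzyme's integration tests applies here (the exact naming and calling convention would need to be checked against the documentation, and the _fwd helper is purely illustrative):

```c
#include <math.h>

double my_sqrt(double x) { return sqrt(2.0 * x); }

// Hand-written tangent rule with no zero check (assumption: a forward-mode custom
// derivative takes the primal argument plus its tangent and returns the output tangent).
double my_sqrt_fwd(double x, double dx) { return dx / sqrt(2.0 * x); }

// Assumed registration convention: a global named __enzyme_register_derivative_<name>
// pairing the primal function with its custom forward-mode derivative.
void* __enzyme_register_derivative_my_sqrt[2] = {(void*)my_sqrt, (void*)my_sqrt_fwd};
```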

tgymnich commented 1 year ago

@Bike This functionality could be introduced behind some kind of flag, given that the changes are not too invasive and there is some notion of completeness (what about FDiv and so on?).

wsmoses commented 1 year ago

@tgymnich we already have a flag within checkedMul/checkedDiv, which I presume we could just use here.

Bike commented 1 year ago

Since the problem with sqrt is also due to division by zero, could the sqrt pattern just be changed to use CheckedDiv, rather than using the flag separately? There is a slight difference in that the existing sqrt pattern uses a ueq comparison rather than CheckedDiv's oeq - I don't know if that's important; it depends on what a NaN input should do. I don't know how to enable or disable this EnzymeStrongZero flag, but I could try patching this and submitting a PR.
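
For reference, a small sketch of the predicate difference in plain C rather than IR (ordered vs. unordered equality in the presence of NaN):

```c
#include <math.h>
#include <stdio.h>

// fcmp oeq: true only if neither operand is NaN and they compare equal.
int oeq(double a, double b) { return a == b; }
// fcmp ueq: true if either operand is NaN, or if they compare equal.
int ueq(double a, double b) { return isunordered(a, b) || a == b; }

int main(void) {
    printf("oeq(NAN, 0.0) = %d\n", oeq(NAN, 0.0)); // 0: a select on this takes the division branch
    printf("ueq(NAN, 0.0) = %d\n", ueq(NAN, 0.0)); // 1: a select on this returns the zero branch
    return 0;
}
```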

tgymnich commented 1 year ago

@Bike you can enable the flag depending on the pass manager you use like so:

  1. clang legacy PM: -fpass-plugin=/opt/compiler-explorer/main/ClangEnzyme-XX.so -Xclang -load -Xclang /opt/compiler-explorer/main/ClangEnzyme-XX.so -mllvm -enzyme-strong-zero
  2. clang new PM: -Xclang -load -Xclang /opt/compiler-explorer/main/ClangEnzyme-XX.so -mllvm -enzyme-strong-zero

jedbrown commented 1 year ago

Just a note that turning on trapping floating point is an important debugging tool, so I'd prefer not to see FP exceptions for code that is considered unexceptional. I don't know if this case qualifies, but the way it's being talked about above suggests it might.

wsmoses commented 1 year ago

As a side note that you may be interested in, @jedbrown: a few weeks back we added support for running custom code on all the intermediate derivative values. For example, performing a NaN check (and then getting a backtrace to what triggered the NaN).

jedbrown commented 1 year ago

Cool, is that for forward or reverse or both?

wsmoses commented 1 year ago

Both! Here's the (undocumented) Julia flag for enabling the nan checker as an example: https://github.com/EnzymeAD/Enzyme.jl/blob/02715a8bbac185342fe427f0090a8893d3a8af1d/src/compiler.jl#L5854

If you have other use cases and/or want to play with it (from whatever input language), lmk!

Bike commented 1 year ago

@tgymnich Oh, thanks for the tip. I'd love to be able to use the new pass manager. But I'm a bit confused - the FAQ linked in that other issue (https://enzyme.mit.edu/getting_started/Faq/#opt-cant-find--enzyme-option) says the opposite, that -fpass-plugin is the new one. But in any case it doesn't seem to work: clang++-15 -fpass-plugin=/path/to/my/LLVMEnzyme-15.so enzyme-test.cc -o enzyme-test gets me a linker error about undefined reference to `__enzyme_autodiff(void*, double)', so presumably Enzyme is not being run. (I can run things through opt manually but being able to do everything in one go would be nice.)

wsmoses commented 1 year ago

@Bike you should use ClangEnzyme, not LLVMEnzyme, for loading into clang

tgymnich commented 1 year ago

@Bike my bad, I mixed up the pass managers. It should be the other way around.

Bike commented 1 year ago

Thank you, I got it working. The -enzyme-strong-zero flag with CheckedDiv does not seem to be working as I would expect, so I'll have to play around with it.

Bike commented 1 year ago

I must not be understanding CheckedDiv. I have InstructionDerivatives set up so that sqrt does a CheckedDiv instead of an FDiv (or manual FCmp). I thought that passing -mllvm -enzyme-strong-zero would make it generate the condition in checkedDiv and I'd get an fcmp, but it's not doing that - as if EnzymeStrongZero were false. Do I have to build Enzyme in a special way, or something? The only other thing set up to use CheckedDiv is atan, but that can only be a divide by zero if the argument is imaginary.

Bike commented 1 year ago

Okay - figured it out - checkedDiv does res = Builder2.CreateSelect(Builder2.CreateFCmpOEQ(idiff, zero), zero, res);. But idiff is the dividend, not the divisor. Am I missing something or is that incorrect? If I change the oeq to compare pres (the divisor) to zero I get the behavior I expect, i.e. dsqrt(0) = 0 instead of inf.
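
To illustrate the two guard placements, using the operand names from the snippet above (idiff = incoming derivative, i.e. the dividend; pres = the primal-derived divisor, e.g. 2*sqrt(x)); these are illustrative helpers, not Enzyme source:

```c
// Current checkedDiv as I read it: a zero incoming derivative stays exactly zero,
// even if the divisor is zero or infinite.
double checked_div_on_dividend(double idiff, double pres) {
    return (idiff == 0.0) ? 0.0 : idiff / pres;
}

// The divisor-guarded variant I expected: dsqrt(0) becomes 0 instead of inf.
double checked_div_on_divisor(double idiff, double pres) {
    return (pres == 0.0) ? 0.0 : idiff / pres;
}
```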

tgymnich commented 1 year ago

@Bike this does seem like a bug to me. I guess the idea here is to preserve the semantics of the primal. The purpose here is not to check for div by zero.

wsmoses commented 1 year ago

So checkedmul/checkeddiv operate as they do for a distinct reason: if a value is inactive (perhaps unable to be proven so by activity analysis), its derivative must be zero. Normally, this is fine for most update rules. However, in the case of a multiply or divide by certain values, the zero of the derivative is not "strong" enough to overcome the computation it's used in (hence the arbitrary naming of EnzymeStrongZero).

For example, res = x * infinity would lead to dres = dx * infinity. If dx was zero because we didn't want its derivative (or for other activity reasons), we would sadly compute a total derivative of NaN instead of the intended derivative of zero.
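
As a minimal sketch of that case (hand-written forward-mode rules, not Enzyme output):

```c
// Naive tangent rule for res = x * c: with dx = 0 and c = infinity this gives 0 * inf = NaN.
double dmul_naive(double dx, double c) { return dx * c; }

// "Strong zero" variant: a zero derivative stays exactly zero regardless of c.
double dmul_strong_zero(double dx, double c) { return (dx == 0.0) ? 0.0 : dx * c; }
```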

sqrt(x) at zero is a weird point, and at this point I don't remember the origin of the current behavior. Tapenade has the same behavior, so I looked at their source code and found:

  // Primal is:  X=SQRT(Y)
  // Tangent is: if (Y==0) {XD=0} else {XD=YD/(2*SQRT(Y))}
  //   Strictly speaking, this is wrong, but we prefer a wrong 0.0 to a right NaN !
  //   TODO: install a runtime warning

The checkedMul/Div have strong motivations to me, and both the numerator and divisor checks resolve the NaN issue (albeit in different ways) at the indeterminate point.

I don't recall the historical justification for the divisor condition, so I'll tag @martinjm97, @jhueckelheim, and @jedbrown in case they have thoughts that may be helpful.

wsmoses commented 1 year ago

Talking with @martinjm97, one case that the old behavior solves is:

res = 0 * sqrt(x) at x = 0

Clearly this is just equivalent to 0 [say via finite diff/etc], but even with the checkedMul (using forward mode):

dx = 1
dsqrt = if dx == 0 ? 0 : dx / (2*sqrt(x)) // evaluates to inf

derivative of mul is
dres = if dsqrt == 0 ? 0 : (literal 0 from the input expression) * dsqrt // evaluates to 0 * inf -> nan
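
A runnable version of that trace (illustrative hand-written rules, not Enzyme output):

```c
#include <math.h>
#include <stdio.h>

int main(void) {
    double x = 0.0, dx = 1.0;
    // sqrt rule guarded on the incoming tangent: dx != 0, so we divide -> inf
    double dsqrt = (dx == 0.0) ? 0.0 : dx / (2.0 * sqrt(x));
    // checkedMul-style rule for res = 0 * sqrt(x): dsqrt != 0, so we compute 0 * inf -> NaN
    double dres = (dsqrt == 0.0) ? 0.0 : 0.0 * dsqrt;
    printf("dsqrt = %f, dres = %f\n", dsqrt, dres); // prints inf and nan
    return 0;
}
```
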
Bike commented 1 year ago

Alright, I misunderstood what y'all meant about the checkedDiv flag then. Ah well. The paper linked earlier mentions sqrt, but only to say that PyTorch and TensorFlow treat it conservatively (meaning the derivative at zero is inf). But I don't see it mention why you'd want it to be zero instead.

wsmoses commented 1 year ago

Yeah, @martinjm97 and I just had a long conversation about this (he'll post shortly), but the tl;dr is that it is required to get the correct behavior when differentiating y * sqrt(x) when both y and x are zero. Otherwise it breaks the ability to decompose into separate chain rules (since it would propagate NaNs).

martinjm97 commented 1 year ago

Hi @Bike,

Happy to hop in and give some thoughts. So I believe the three options under consideration are:
(1) dx/(2*sqrt(x))
(2) if x == 0 then 0, else dx/(2*sqrt(x))
(3) if dx == 0 then 0, else dx/(2*sqrt(x))
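
Written out as plain C tangent rules (illustrative helper names only):

```c
#include <math.h>

// (1) unguarded: matches the math away from 0, but gives inf/NaN at x = 0
double dsqrt_unguarded(double x, double dx)     { return dx / (2.0 * sqrt(x)); }
// (2) guard on the primal input: the derivative is defined to be 0 at x = 0
double dsqrt_primal_guard(double x, double dx)  { return (x == 0.0) ? 0.0 : dx / (2.0 * sqrt(x)); }
// (3) guard on the tangent: a zero tangent stays zero, everything else divides
double dsqrt_tangent_guard(double x, double dx) { return (dx == 0.0) ? 0.0 : dx / (2.0 * sqrt(x)); }
```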

I think different approaches are reasonable depending on programmer intent:
(1) "I care about the whole directional derivative in that direction, even if it is 0 in the dx direction"
(2) "I want to avoid NaNs at all cost, especially because 0 * NaN = NaN, so it might erroneously destroy the whole computation"
(3) "I don't want to include the infinitesimal dx if it is 0"

I think there are cases where each of these makes sense:
(1) sqrt(x) at dx = 0 and x = 0 divides by zero! This matches the math specification, since the left limit does not exist at 0!
(2) I want to differentiate 0 * sqrt(x) at x = 0, and since it is mathematically zero, I want to get 0!
(3) y * sqrt(x) where I calculate the total derivative, but I only care about when dy is 1.

I guess this isn't much of an answer and more of a choose-your-own-adventure, but I don't yet see an ideal resolution. I'm personally partial (pun intended) to (3). I think of dx being 0 as indicating that I don't care about that part of the computation, and this approach returns NaN on cases that I think are important, such as sqrt(x) at dx = 1 and x = 0.

jedbrown commented 1 year ago

I feel a bit uncomfortable about this from the NeurIPS paper shared above:

[screenshot of the relevant passage from the paper]

So what happens if we apply this new rule to sqrt(mult(x, x))? Note that mult(x, 0) and mult(x, x) have the same value and derivative at x=0. Then we get 2*0*(+inf) = NaN with strategy (1) above; with their intensional derivatives, we get 0, but $\sqrt{x^2} = x$ has derivative 1. (Small adjustments can make the correct derivative be any real number, but intensional derivatives will keep reporting 0.) This sort of problem seems unavoidable without non-local reasoning (taking limits, L'Hopital's rule) but it feels disconcerting to recommend a method that is silently wrong as being more principled than NaN.

tgymnich commented 1 year ago

$\sqrt{x}^2$ would be a better example, since the derivative of $\sqrt{x^2}$ is $\frac{x}{\sqrt{x^2}}$ which can be written as $\frac{x}{|x|}$.

The theory appears solid; it's FP math that is lacking. Under floating-point math, e.g. without -ffast-math, neither $\sqrt{x^2}$ nor $\sqrt{x}^2$ is equivalent to $|x|$ or $x$ respectively. Running the whole thing with -ffast-math works just fine.

jedbrown commented 1 year ago

Nothing puts @simonbyrne at ease like code that is only correct with -ffast-math. :joy:

wonyeol commented 1 year ago

@martinjm97 brought me into this thread, and I would like to make small clarifications about the NeurIPS'20 paper discussed above, which I co-authored. I don't fully understand the context of this thread, but I hope my clarifications could be helpful to the entire discussion. Please let me know if you have any other questions or need any further clarifications :)

The main messages of the paper are the following:

The following are clarifications on some of the above discussion.

@jedbrown: I feel a bit uncomfortable about this from the NeurIPS paper shared above. [...] So what happens if we apply this new rule to sqrt(mult(x, x))?

This program represents the mathematical function $\sqrt{x^2} = |x|$ over $\mathbb{R}$, so it is in fact non-differentiable at $x=0$ and the standard derivative of the program does not exist at all for $x=0$. As @jedbrown mentioned, if AD uses $dsqrt(0)=0$, then AD computes the following function when applied to this program: $0$ for $x=0$, $1$ for $x>0$, and $-1$ for $x<0$. This is an intensional derivative of the function $|x|$, and it is identical to the standard derivative of $|x|$ almost everywhere (in this case, except for $x=0$). Again, the standard derivative of $|x|$ does not exist for $x=0$, so this output of AD is not wrong at all.

@jedbrown: This sort of problem seems unavoidable without non-local reasoning (taking limits, L'Hopital's rule) but it feels disconcerting to recommend a method that is silently wrong as being more principled than NaN.

I would like to note that it is impossible to prevent AD from being silently wrong, regardless of the "derivatives" that AD uses for non-differentiable functions (e.g., sqrt or relu). A representative example is $relu(x) - relu(-x)$. This program is mathematically equivalent to the identity function $x$, but using $drelu(0)=0$ (or any constant $c \neq 1/2$) makes AD compute $0$ (or $2c \neq 1$) as an output for $x=0$, and this means AD is silently wrong for this particular input. Since sqrt is non-differentiable at $x=0$ (as its derivative diverges at 0), a similar thing (i.e., AD being silently wrong) could happen as well---the above example, however, is not such a case.
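
A concrete version of that relu example (hand-applied chain rule, not any particular AD tool's output):

```c
#include <stdio.h>

// Tangent rule for relu using drelu(0) = 0.
double drelu(double x, double dx) { return (x > 0.0) ? dx : 0.0; }

int main(void) {
    double x = 0.0, dx = 1.0;
    // d/dx [relu(x) - relu(-x)] by the chain rule on each primitive.
    double d = drelu(x, dx) - drelu(-x, -dx);
    printf("AD result at x = 0: %f (true derivative: 1)\n", d); // prints 0.000000
    return 0;
}
```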

Here I am not arguing that AD should use an intensional derivative for each primitive function (such as $dsqrt(0)=0$, rather than $dsqrt(0)=\infty$ as in TensorFlow and PyTorch). I am saying that if AD uses an intensional derivative (rather than other choices like $dsqrt(0)=\infty$), we can guarantee that AD is correct for almost all inputs (i.e., all but measure-zero inputs)---so using intensional derivatives is in fact one principled way of handling non-differentiable functions. But as @tgymnich mentioned, we are working with floats not reals in practice, so the above result might not hold in practical settings. For instance, measure-zero inputs can become important---we have a recent paper on this aspect (https://arxiv.org/abs/2301.13370). Also, there can be other reasons why it is more desirable to use, for instance, $dsqrt(0)=\infty$ rather than $dsqrt(0)=0$.

jedbrown commented 1 year ago

Thanks @wonyeol for your informative comment.