JuliaDiff / DiffRules.jl

A simple shared suite of common derivative definitions

Remove rules for `conj`, `adjoint`, and `transpose` #67

Closed: devmotion closed this PR 2 years ago

devmotion commented 2 years ago

This PR removes the rules for conj, adjoint, and transpose. They cause problems with ReverseDiff (https://github.com/JuliaDiff/DiffRules.jl/pull/54 broke some tests for adjoint; see https://github.com/JuliaDiff/ReverseDiff.jl/issues/183#issuecomment-921287528 and the other comments there for details), and the default fallbacks in https://github.com/JuliaLang/julia/blob/c5f348726cebbe55e169d4d62225c2b1e587f497/base/number.jl#L211-L213 should be sufficient (similar to the discussion about identity in https://github.com/JuliaDiff/DiffRules.jl/pull/64).
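
For reference, the Base fallbacks linked above are (roughly; the exact lines may differ slightly across Julia versions):

conj(x::Real) = x
transpose(x::Number) = x
adjoint(x::Number) = conj(x)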

I checked locally that the ReverseDiff issues are fixed by this PR.

Edit: I also added ReverseDiff to the integration tests.

codecov-commenter commented 2 years ago

Codecov Report

Merging #67 (8dca5df) into master (d79d2d9) will decrease coverage by 0.14%. The diff coverage is n/a.


@@            Coverage Diff             @@
##           master      #67      +/-   ##
==========================================
- Coverage   93.89%   93.75%   -0.15%     
==========================================
  Files           2        2              
  Lines         131      128       -3     
==========================================
- Hits          123      120       -3     
  Misses          8        8              
Impacted Files    Coverage Δ
src/rules.jl      99.15% <ø> (-0.03%) ↓


devmotion commented 2 years ago

The ForwardDiff test errors seem unrelated; they are the same ones that https://github.com/JuliaDiff/ForwardDiff.jl/pull/544 tries to address.

devmotion commented 2 years ago

The ModelingToolkit errors are the same as on the master branch: https://github.com/SciML/ModelingToolkit.jl/runs/3563293020

mcabbott commented 2 years ago

Is there a short explanation of why these are wrong? Or why they lead ReverseDiff to the wrong answer?

I tried to read the thread but don't follow the details of ReverseDiff's internals. There are no complex numbers involved, and the rules do not apply to arrays, hence not to Adjoint etc. Why does it stumble on a function acting on real numbers whose derivative is 1? It doesn't in isolation:

julia> ReverseDiff.gradient(x -> transpose(x[1])^2, [3.0])
1-element Vector{Float64}:
 6.0

yet you say that "the [rule] for transpose lead[s] to a TrackedReal with a derivative of zero!" But why?

ForwardDiff used to define rules for these (or at least for conj), which I presume would need to be put back in a companion PR if they are removed here.

devmotion commented 2 years ago

The short answer is https://github.com/JuliaDiff/ReverseDiff.jl/issues/183#issuecomment-921287528 😛 I guess my explanations were not completely clear, though. I was not referring to derivatives of functions that involve transpose(::Real) (or one of the other functions). I was referring to transpose(::TrackedReal) (and the other functions), which, as shown in the linked comment, is called when you retrieve or set values of a Transpose or Adjoint of tracked numbers. These methods are defined in https://github.com/JuliaDiff/ReverseDiff.jl/blob/01041c8e8237ed42f6414c6fe0f6e6b12162b6ac/src/derivatives/scalars.jl#L7, so these calls end up in https://github.com/JuliaDiff/ReverseDiff.jl/blob/16b35963234c398fc3a1eb42efab8516eac466e1/src/macros.jl#L84, which returns a TrackedReal with derivative 0.

I don't think this PR is problematic for ForwardDiff: for all these functions you want f(x::Dual) = x, which is exactly what the default definitions in Julia Base provide. Also, the integration tests did not reveal any test failures (apart from the known random issue).
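
For illustration, a quick check (the second call assumes ForwardDiff ends up using the Base fallback or an equivalent identity rule, so the result should be the same either way):

julia> transpose(3.0) === 3.0   # Base fallback: transpose(x::Number) = x
true

julia> using ForwardDiff

julia> ForwardDiff.derivative(x -> transpose(x)^2, 3.0)
6.0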

So in my opinion these rules are not helpful but actually cause problems in downstream packages, and hence it would be good to remove them (basically like the already removed rule for identity).

ChrisRackauckas commented 2 years ago

@shashi you really need to prioritize fixing MTK master. This has gone on for way too long.

mcabbott commented 2 years ago

I wanted to refer to transpose(::TrackedReal)

But how does my example not involve this? Presumably within a gradient call you never get transpose(::Real), since the input is tracked? Like I said, I don't know the internals of the package. But it seems concerning that a scalar function with derivative 1 can lead to a gradient of zero. Why doesn't this occur more widely?

devmotion commented 2 years ago

When you compute ReverseDiff.gradient, there is an additional reverse pass that accumulates the partial derivatives correctly: https://github.com/JuliaDiff/ReverseDiff.jl/blob/01041c8e8237ed42f6414c6fe0f6e6b12162b6ac/src/derivatives/scalars.jl#L47

mcabbott commented 2 years ago

This still seems pretty weird to me. Is it clear to you when this is and isn't going to be triggered? Is it clear what properties other functions would have to have to trigger it?

julia> ReverseDiff.gradient(x -> (x' .+ x)[1], [3.0])  # wrong
1-element Vector{Float64}:
 1.0

julia> ReverseDiff.gradient(x -> (x' + x)[1], [3.0])  # maybe no broadcasting?
1-element Vector{Float64}:
 1.0

julia> ReverseDiff.gradient(x -> (x' + x')[1], [3.0])  # fine
1-element Vector{Float64}:
 2.0

julia> ReverseDiff.gradient(x -> (x' + x')[1] + x[1], [3.0])  # fine, mixing adjoint & not
1-element Vector{Float64}:
 3.0

julia> ReverseDiff.gradient(x -> (x')[1] + x[1], [3.0])  # fine
1-element Vector{Float64}:
 2.0

devmotion commented 2 years ago

I tried to explain this in the linked issue 🤷‍♂️ The problem occurs when we call increment_deriv! or decrement_deriv! in the reverse pass with an Adjoint or Transpose of TrackedReal: in this case the default indexing methods in Base call adjoint/transpose on the elements of the wrapped array, and hence the accumulated derivatives become incorrect if we don't use the default adjoint(::TrackedReal) = x from Base but instead the DiffRules-based method that returns a TrackedReal with x.deriv = 0. In some cases, e.g. the example f(x) = sum(x' * x) from the issue, the gradients are still correct even though Adjoint or Transpose are involved, since these functions take a special path (e.g. ReverseDiff contains special implementations for *(::Adjoint{<:TrackedReal}, ::AbstractMatrix) etc.).
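
For context, here is a minimal illustration of that indexing path using plain complex numbers (the effect is invisible for ordinary reals): indexing an Adjoint applies adjoint to the wrapped element on every read, which is exactly the call that hits adjoint(::TrackedReal) above.

julia> using LinearAlgebra

julia> A = adjoint([1.0 + 2.0im  3.0 + 4.0im]);

julia> A[2, 1]   # roughly adjoint(parent(A)[1, 2]): adjoint is applied to the element itself
3.0 - 4.0im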

Fixing indexing of Adjoint{<:TrackedReal} etc. as discussed in the issue is not sufficient though: e.g. matrix division of TrackedArray still fails. It falls back to a division involving Adjoint{<:TrackedReal}, and I assume at some point it calls adjoint etc., which is still not correct for TrackedReal. Since I think this PR is the more general and cleaner fix, and I did not have more time, I did not continue debugging this error in more detail.

mcabbott commented 2 years ago

and hence the accumulated derivatives become incorrect

Then it seems like the claim is that any array indexing which calls a function for which a gradient rule is defined may lead to problems? I tried a bit to trigger this with MappedArrays.jl but haven't managed to. Is this the zero gradient you refer to?

julia> sqrt(ReverseDiff.TrackedReal(2.0, 3.0))
TrackedReal<3vn>(1.4142135623730951, 0.0, ---, ---)

julia> transpose(ReverseDiff.TrackedReal(2.0, 3.0))
TrackedReal<2P5>(2.0, 0.0, ---, ---)

I don't oppose the quick fix; I'm just slightly disturbed that a rule which is not mathematically wrong can silently break something deep inside how gradients are accumulated, and I wonder where else that can surface.

devmotion commented 2 years ago

I haven't actively looked for any similar issues; I can try MappedArrays as well when I'm back at my computer. I'm far from a ReverseDiff expert, but of course there are bugs and issues about incorrect gradients (as in every AD system, I assume), so I wouldn't be surprised if it can be reproduced in a similar setting 🤷‍♂️

I disagree though, I don't think this is a quick fix. In my opinion this is the correct thing to do, regardless of ReverseDiff. The default definitions in Base already do the correct thing for dual and tracked numbers, so the rules are not needed - and as we see (and already saw with identity), not only do they not help downstream packages, they actually break stuff.
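
Concretely, a rough check (assuming a ReverseDiff version that generates its scalar methods from DiffRules, so that these methods disappear together with this PR and the Base fallback applies to tracked numbers as well):

julia> using ReverseDiff

julia> x = ReverseDiff.TrackedReal(2.0, 3.0);

julia> transpose(x) === x   # Base fallback transpose(x::Number) = x; no fresh TrackedReal with deriv 0
true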

shashi commented 2 years ago

@ChrisRackauckas I'm fixing it.

mcabbott commented 2 years ago

base already do the correct thing for dual and tracked numbers, so

This seems to be the question. Is the sole purpose of this package to provide rules for scalar operator-overloading AD? Is a rule for max or ifelse then forbidden because these should work without one?

Or does "Many differentiation methods" in the readme include other things which may never see the Base definitions? Such as symbolic differentiation. Or generating rules for broadcasting.

devmotion commented 2 years ago

I don't know; I based my opinion on the integration tests and the removal of identity 😄 Maybe the rules are useful at some point, and maybe it would be useful to be able to distinguish and pick different subsets/types of rules 🤷‍♂️

But I guess this is something that should be discussed in a different issue. Since you said you're not opposed to the PR, I guess you're fine with me going ahead and merging it?

devmotion commented 2 years ago

FYI, I checked, and with this PR all your examples above return the correct gradient.