FluxML / Zygote.jl

https://fluxml.ai/Zygote.jl/

Wrong results with higher order pullback -- chained if/else in `accum` #937

Open axsk opened 3 years ago

axsk commented 3 years ago

Here comes a strange one... This is the most minimal working example I could find that reproduces the erroneous result. The loss function computes the gradient of some function (think of it as a mixture of Euclidean distances) at each point given by the columns of x/data. I extract the gradient with the dfdx, ... statement and return it as the result of the loss function.

At a later stage I want to optimize some model over this loss function, so I need the derivative of this spatial gradient with respect to the model parameters. Here the model is just the identity and the loss does not depend on the parameter z.

However, when the loss is evaluated inside the pullback function, it returns a different result.

Strangely, this only happens when the sums involve the abs2 terms and when I subtract both aa and bb; otherwise the results l1 and l2 are the same 😮

using Zygote

function mwe()
    a = ones(1) * 0
    b = ones(1) * .5
    data = ones(1)

    function loss()
        y, pb = Zygote.pullback(data) do x
            aa = sum(abs2, x .- a, dims=1)
            bb = sum(abs2, x .- b, dims=1)
            r = 1 .- aa .- bb
        end
        dfdx,  = pb(data)
        dfdx
    end

    l1 = sum(loss())

    l2, pb = Zygote.pullback() do 
        sum(loss())
    end

    @show l1, l2
    @assert l1 == l2
end

julia> mwe()
(l1, l2) = (-3.0, -2.0)
ERROR: AssertionError: l1 == l2

Any feedback is welcome :)

axsk commented 3 years ago

loss() = d/dx (1 - x^2 - (x-1/2)^2) = -4x + 1, and since x = data = 1 we have loss() = -3. The -2 returned by the outer pullback is wrong.
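
For reference, the derivative can be worked out term by term. This is only a sketch under the values used in the example (a = 0, b = 1/2, x = 1), with f denoting the function inside the inner pullback; the name f is introduced here for illustration:

```latex
\begin{aligned}
f(x) &= 1 - (x - a)^2 - (x - b)^2 \\
f'(x) &= -2(x - a) - 2(x - b) = -4x + 2(a + b) \\
f'(x)\big|_{a = 0,\; b = 1/2} &= -4x + 1 \\
f'(1) &= -3
\end{aligned}
```

Since the pullback is seeded with data = 1, the value returned by pb(data) should coincide with f'(1) = -3.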

DhairyaLGandhi commented 3 years ago

Huh interesting. This is a bit odd.

This might be because of an incorrect abs2 adjoint definition, erroneous forward-pass rewriting, or a global lookup.

cc @willtebbutt @mzgubic @oxinabox

mzgubic commented 3 years ago

I tried commenting out the adjoints for abs2, but that didn't solve it. It also breaks without arrays:

julia> function mwe()
           a = 0
           b = .5
           data = 1.0

           function loss()
               y, pb = Zygote.pullback(data) do x
                   aa = abs2(x - a)
                   bb = abs2(x - b)
                   r = 1 - aa - bb
               end
               dfdx,  = pb(data)
               dfdx
           end

           l1 = sum(loss())

           l2, pb = Zygote.pullback() do 
               sum(loss())
           end

           @show l1, l2
           @assert l1 == l2
       end
mwe (generic function with 1 method)

julia> mwe()
(l1, l2) = (-3.0, -2.0)
ERROR: AssertionError: l1 == l2
Stacktrace:
 [1] mwe()
   @ Main ./REPL[9]:23
 [2] top-level scope
   @ REPL[10]:1

DhairyaLGandhi commented 3 years ago

I tried running a few experiments with explicit parameters to see whether global lookups were the cause, but that didn't help either.

axsk commented 3 years ago

Might this be related to perturbation confusion (cf. https://github.com/JuliaDiff/ForwardDiff.jl/issues/83)?

axsk commented 3 years ago

> Might this be related to perturbation confusion (cf. JuliaDiff/ForwardDiff.jl#83)?

The example from the referenced paper (which describes the problem for forward mode), gradient(x->x * gradient(y->x+y, 1)[1], 1) == 1, works fine, so it's probably not that problem.

axsk commented 3 years ago

The outer gradients are also computed incorrectly (not in this example though, since there is no outer gradient).

DhairyaLGandhi commented 3 years ago

I'm going to have to jump deeper in here, will dig in

mcabbott commented 3 years ago

Following up on @mzgubic's simplification, here is a more minimal example, and one without a second derivative at all. The problem appears to be that Zygote is getting confused by the if statements in accum, and silently giving wrong answers, which is slightly disturbing.

julia> using Zygote

julia> function mmwe()
           f() = gradient(x -> 13*x + x, 17)[1]
           α = f()
           β, _ = pullback(f)
           @show α, β
           nothing
       end;

julia> mmwe()
(α, β) = (14, 2)

julia> let
          α = Zygote.accum(1,2)
          β, _ = pullback(Zygote.accum,1,2)
          @show α, β
       end;
(α, β) = (3, 2)

julia> Zygote.accum(x, y) =
         x === nothing ? y :
         # y === nothing ? x :   # this won't fix mwe(), it needs accum(::Float64, ::Missing)
         x + y

julia> mmwe()
(α, β) = (14, 14)
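
The expected values in this thread can be cross-checked without Zygote. The following is a minimal sketch, not Zygote's own code: accum_ref is a hypothetical stand-in that mirrors the chained conditional discussed above, and central_diff, g and f_mwe are helper names introduced here for illustration. It only confirms what the correct answers should be (3, 14 and -3); it does not reproduce the bug.

```julia
# Hypothetical reference for the chained conditional; an assumption about its
# shape based on this thread, not a copy of Zygote's source.
accum_ref(x, y) = x === nothing ? y : y === nothing ? x : x + y

# Hypothetical helper: central finite difference of f at x with step h.
central_diff(f, x; h=1e-6) = (f(x + h) - f(x - h)) / (2h)

# The function differentiated in mmwe().
g(x) = 13 * x + x

# The scalar function from the simplified mwe(), with a = 0 and b = 0.5.
f_mwe(x) = 1 - abs2(x - 0) - abs2(x - 0.5)

# Plain evaluation of the conditional adds the two values.
@assert accum_ref(1, 2) == 3
@assert accum_ref(nothing, 2) == 2
@assert accum_ref(1, nothing) == 1

# Numerical derivatives agree with the hand-computed ones.
@assert isapprox(central_diff(g, 17.0), 14; atol=1e-4)
@assert isapprox(central_diff(f_mwe, 1.0), -3; atol=1e-4)
```

A finite-difference check like this is independent of any adjoint definitions, so it gives a baseline to compare both the direct call and the value returned through pullback against.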