LuxDL / Lux.jl

Elegant & Performant Scientific Machine Learning in Julia
https://lux.csail.mit.edu/
MIT License

[Nested AD] Incorrect gradient when taking a gradient over a gradient using StatefulLuxLayer #630

Closed: MatsudaHaruki closed this issue 2 months ago

MatsudaHaruki commented 2 months ago

Hello,

I needed to compute a gradient over a gradient. Using StatefulLuxLayer as described in the "Nested Automatic Differentiation" chapter of the documentation, I was able to run the following code without error. However, when I verified the result with FiniteDiff.jl, as the documentation does, I found that the two results do not match.

Code Example:

using Lux, Random, LinearAlgebra, Zygote, FiniteDiff

rng = Xoshiro(0)
model = Dense(3 => 3)    # very simple model
ps, st = Lux.setup(rng, model)

x = ones32(rng, 3, 1)

grads_1 = Zygote.gradient(x) do x
    smodel = StatefulLuxLayer(model, ps, st)
    v = randn!(rng, similar(x))
    w = Zygote.gradient((z -> dot(v, z)) ∘ smodel, x) |> only
    return sum(w)
end |> only    # shows [0.0; 0.0; 0.0;;]

grads_2 = FiniteDiff.finite_difference_gradient(x) do x
    smodel = StatefulLuxLayer(model, ps, st)
    v = randn!(rng, similar(x))
    w = Zygote.gradient((z -> dot(v, z)) ∘ smodel, x) |> only
    return sum(w)
end            # shows [-260.231; 114.75144; 21.052755;;]

Is this the expected behavior of this package? I would be glad to know if there is anything in the code that needs to be corrected.

Thank you.

MatsudaHaruki commented 2 months ago

The purpose of the above calculation was to compute the gradient of the VJP. I also tried the following implementation using the recently released vector_jacobian_product function, but it gave the same mismatch with FiniteDiff:

using Lux, Random, LinearAlgebra, Zygote, FiniteDiff

rng = Xoshiro(0)
model = Dense(3 => 3)    # very simple model
ps, st = Lux.setup(rng, model)

x = ones32(rng, 3, 1)

grads_1 = Zygote.gradient(x) do x
    smodel = StatefulLuxLayer(model, ps, st)
    v = randn!(rng, similar(x))
    w = vector_jacobian_product(smodel, Lux.AutoZygote(), x, v)
    return sum(w)
end |> only    # shows [0.0; 0.0; 0.0;;]

grads_2 = FiniteDiff.finite_difference_gradient(x) do x
    smodel = StatefulLuxLayer(model, ps, st)
    v = randn!(rng, similar(x))
    w = vector_jacobian_product(smodel, Lux.AutoZygote(), x, v)
    return sum(w)
end            # shows [-260.231; 114.75144; 21.052755;;]

avik-pal commented 2 months ago

You are computing the second derivative of a linear (affine) operation, which is trivially zero.
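Spelled out for `Dense(3 => 3)` with its default identity activation, writing $W$ for `ps.weight`, $b$ for `ps.bias`, holding $v$ fixed, and using the fact that `vector_jacobian_product` returns $J^\top v$:

```latex
\begin{aligned}
f(x) &= W x + b, & J_f(x) &= W,\\
u(x) &= J_f(x)^\top v = W^\top v, & s(x) &= \mathbf{1}^\top u(x) = \mathbf{1}^\top W^\top v,\\
\nabla_x s(x) &= 0.
\end{aligned}
```

Here $s(x)$ is the `sum(w)` being differentiated; since $x$ never appears in it, its gradient is identically zero.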

using ForwardDiff

W = ps.weight
b = ps.bias

ForwardDiff.gradient(x) do x
    J = ForwardDiff.jacobian(x) do x
        return W * x .+ b
    end
    v = randn!(rng, similar(x))
    return sum(J' * v)
end
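The large FiniteDiff numbers in the original report come from a separate effect: `v` is drawn inside the function being differentiated, so every evaluation uses a fresh `v` and a finite-difference quotient only measures that random noise divided by the step size. A stdlib-only sketch of this, assuming a random `W` as a stand-in for `ps.weight`; `central_grad`, `s_fixed`, and `s_resampled` are hypothetical names introduced here, and the step `h = 1e-6` is not FiniteDiff's default.

```julia
using Random

rng = Xoshiro(0)
W = randn(rng, 3, 3)   # hypothetical stand-in for ps.weight of Dense(3 => 3)

# Central finite-difference gradient of a scalar function f (hypothetical helper)
function central_grad(f, x; h = 1e-6)
    g = zero(x)
    for k in eachindex(x)
        e = zero(x)
        e[k] = h
        g[k] = (f(x .+ e) - f(x .- e)) / (2h)
    end
    return g
end

x = ones(3)
v = randn(rng, 3)   # drawn once

# sum(J' * v) for f(x) = W * x .+ b: J == W, so x never enters the result
s_fixed(x) = sum(W' * v)

# Same quantity, but v is re-drawn on every call, as in the original snippets
s_resampled(x) = sum(W' * randn(rng, 3))

g_fixed = central_grad(s_fixed, x)          # exactly zero
g_resampled = central_grad(s_resampled, x)  # large and meaningless
```

Drawing `v` once, outside the closure, is exactly what the follow-up gelu example does.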

avik-pal commented 2 months ago

Try adding a nonlinearity, so that the gradient becomes non-zero:

using Lux, Random, LinearAlgebra, Zygote, FiniteDiff

rng = Xoshiro(0)
model = Dense(3 => 3, gelu)    # single dense layer, now with a gelu nonlinearity
ps, st = Lux.setup(rng, model)

x = ones32(rng, 3, 1)

v = randn!(rng, similar(x))

grads_1 = Zygote.gradient(x) do x
    smodel = StatefulLuxLayer(model, ps, st)
    # v = randn!(rng, similar(x))  # moved outside: v must stay fixed across evaluations
    w = vector_jacobian_product(smodel, Lux.AutoZygote(), x, v)
    return sum(w)
end |> only

grads_2 = FiniteDiff.finite_difference_gradient(x) do x
    smodel = StatefulLuxLayer(model, ps, st)
    # v = randn!(rng, similar(x))  # moved outside: v must stay fixed across evaluations
    w = vector_jacobian_product(smodel, Lux.AutoZygote(), x, v)
    return sum(w)
end

# Simple ForwardDiff.jl test
using ForwardDiff

W = ps.weight
b = ps.bias

ForwardDiff.gradient(x) do x
    J = ForwardDiff.jacobian(x) do x
        return gelu.(W * x .+ b)
    end
    # v = randn!(rng, similar(x))  # moved outside: v must stay fixed across evaluations
    return sum(J' * v)
end
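The same check can be mirrored without any AD package by comparing finite differences (with `v` fixed) against a closed-form gradient. A minimal sketch, assuming `tanh` in place of `gelu` only because its first and second derivatives have short closed forms; `W` and `b` are random stand-ins for `ps.weight` and `ps.bias`, and `s`, `grad_s`, `central_grad` are hypothetical names. With z = W * x .+ b, the sum of the VJP is `sum(W' * (tanh'.(z) .* v))`, whose gradient is `W' * ((W * ones(3)) .* v .* tanh''.(z))`.

```julia
using Random

rng = Xoshiro(0)
W = randn(rng, 3, 3)   # hypothetical stand-in for ps.weight
b = randn(rng, 3)      # hypothetical stand-in for ps.bias
v = randn(rng, 3)      # drawn once, outside every function below
x = ones(3)

dtanh(z) = 1 - tanh(z)^2             # first derivative of tanh
d2tanh(z) = -2 * tanh(z) * dtanh(z)  # second derivative of tanh

# s(x) = sum(J(x)' * v) for f(x) = tanh.(W * x .+ b), with J(x) = Diagonal(dtanh.(z)) * W
function s(x)
    z = W * x .+ b
    return sum(W' * (dtanh.(z) .* v))
end

# Closed-form gradient of s: W' * ((W * 1) .* v .* d2tanh.(z))
function grad_s(x)
    z = W * x .+ b
    return W' * ((W * ones(length(x))) .* v .* d2tanh.(z))
end

# Central finite differences (hypothetical helper); v stays fixed across calls
function central_grad(f, x; h = 1e-5)
    g = zero(x)
    for k in eachindex(x)
        e = zero(x)
        e[k] = h
        g[k] = (f(x .+ e) - f(x .- e)) / (2h)
    end
    return g
end

g_exact = grad_s(x)
g_fd = central_grad(s, x)
```

Once the activation is nonlinear and `v` is held fixed, the two gradients agree to finite-difference precision and are no longer zero.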

MatsudaHaruki commented 2 months ago

@avik-pal -san, thanks for your quick reply! Ah yes, I didn't realize that the model I used as an example was too simple to reproduce the problem at hand; sorry for bothering you. I still have the same problem with my more complicated nonlinear model, but I will take a closer look at it. It was a new insight for me that FiniteDiff and AD results can sometimes be significantly inconsistent. Anyway, thank you very much for your kind reply!

avik-pal commented 2 months ago

No worries. As a rule of thumb, check against ForwardDiff, which works in most cases (it might be slow, but that's fine).
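ForwardDiff makes a good reference because forward-mode AD carries exact derivative values alongside the primal values, so there is no step size to tune. A toy sketch of that idea, using a hypothetical `ToyDual` type; ForwardDiff's own `Dual` is far more general (multiple partials, tags for nested differentiation, many more methods).

```julia
# Toy dual number: a value and one derivative component (hypothetical type)
struct ToyDual
    val::Float64
    der::Float64
end

# Propagate derivatives with the sum, product, and chain rules
Base.:+(a::ToyDual, b::ToyDual) = ToyDual(a.val + b.val, a.der + b.der)
Base.:*(a::ToyDual, b::ToyDual) = ToyDual(a.val * b.val, a.der * b.val + a.val * b.der)
Base.:*(c::Real, a::ToyDual) = ToyDual(c * a.val, c * a.der)
function Base.tanh(a::ToyDual)
    t = tanh(a.val)
    return ToyDual(t, (1 - t^2) * a.der)
end

# Seed the input with derivative 1 and read the derivative off the output
toy_derivative(g, x::Real) = g(ToyDual(x, 1.0)).der

g(x) = tanh(2.0 * x) * x
dg_exact(x) = 2 * (1 - tanh(2x)^2) * x + tanh(2x)   # hand-derived derivative of g

d = toy_derivative(g, 0.7)
```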

avik-pal commented 2 months ago

Also test with FiniteDifferences.jl just to be sure; it offers a wider selection of methods, which helps with the accuracy of the finite differences.
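The effect of method choice is easy to see on a scalar function with a known derivative. This sketch hand-rolls forward, central, and five-point central stencils (FiniteDifferences.jl builds stencils like these through constructors such as `central_fdm`; the code below does not use that package) and compares their errors on `sin` at a fixed step.

```julia
# Hand-rolled finite-difference stencils for a scalar function f at x with step h
forward_diff(f, x, h) = (f(x + h) - f(x)) / h                                   # O(h) error
central_diff(f, x, h) = (f(x + h) - f(x - h)) / (2h)                            # O(h^2) error
five_point(f, x, h) = (-f(x + 2h) + 8f(x + h) - 8f(x - h) + f(x - 2h)) / (12h)  # O(h^4) error

x0, h = 1.0, 1e-3
exact = cos(x0)   # derivative of sin at x0

err_fwd = abs(forward_diff(sin, x0, h) - exact)
err_central = abs(central_diff(sin, x0, h) - exact)
err_five = abs(five_point(sin, x0, h) - exact)
```

At the same step size, each higher-order stencil cuts the error by orders of magnitude, which is why a library with several stencil choices is a more trustworthy reference than a single default.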

MatsudaHaruki commented 2 months ago

Thank you for sharing useful practices! It's very educational! I will try FiniteDifferences.jl (which has a very similar name, but is a different package... 😂).