FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

Wrong gradient involving splatting of kwargs #1284

Closed. niklasschmitz closed this issue 2 years ago.

niklasschmitz commented 2 years ago

Zygote (v0.6.43) currently gives a wrong gradient involving kwargs splatting. It seems to double-count a gradient contribution through implicit and explicit kwargs. Here's a small example:

using FiniteDiff, ForwardDiff, Zygote

f1(; kwargs...) = kwargs[:x]
f2(; kwargs...) = f1(; kwargs..., x=kwargs[:x])
f3(x) = f2(; x)
FiniteDiff.finite_difference_derivative(f3, 0.0) # 1.0
ForwardDiff.derivative(f3, 0.0) # 1.0
Zygote.gradient(f3, 0.0) # (2.0,)
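As a reference point for the expected value, here is a plain-Julia sketch (no AD packages). It reuses the three functions from the report, checks that f3 is simply the identity on its argument, and takes a hand-written central difference; the helper names `seen`, `g2`, `central_diff` and the step size `h` are illustrative choices, not part of the report. It also shows that after `kwargs..., x=kwargs[:x]` the callee sees a single `:x` entry, because the later keyword overrides the splatted one.

```julia
# Same definitions as in the report; only Base Julia is needed here.
f1(; kwargs...) = kwargs[:x]
f2(; kwargs...) = f1(; kwargs..., x=kwargs[:x])
f3(x) = f2(; x)

# Hypothetical helpers: return what f1 would actually receive from f2.
seen(; kwargs...) = kwargs
g2(; kwargs...) = seen(; kwargs..., x=kwargs[:x])

# Central finite difference with a hand-picked step (an assumption, not tuned).
central_diff(f, x; h=1e-6) = (f(x + h) - f(x - h)) / (2h)

println(f3(0.5))                 # f3 returns its argument unchanged
println(length(g2(; x=0.0)))     # the repeated :x collapses to a single entry
println(central_diff(f3, 0.0))   # close to 1.0, matching FiniteDiff/ForwardDiff
```

So in the primal computation x reaches the output exactly once, and a correct gradient is 1.0, not 2.0.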
oxinabox commented 2 years ago

I wonder if this has to do with kwargs[:x] showing up both in kwargs... and in x=kwargs[:x]?

DhairyaLGandhi commented 2 years ago

I believe that's precisely the case.


niklasschmitz commented 2 years ago

Interestingly, the issue persists even after eliminating this redundancy:

f1(; kwargs...) = kwargs[:x]
f2(; kwargs...) = f1(; x=kwargs[:x])
f3(x) = f2(; x)
FiniteDiff.finite_difference_derivative(f3, 0.0) # 1.0
ForwardDiff.derivative(f3, 0.0) # 1.0
Zygote.gradient(f3, 0.0) # (2.0,)
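A plain-Julia check (no Zygote) hints at what the two variants have in common: in both, the kwargs object that f1 receives is `===` to the one f2 received, because each is an immutable wrapper around the same `(x = 0.0,)` named tuple. The helpers `capture`, `h2_splat` and `h2_plain` are hypothetical names that return the kwargs objects instead of indexing into them.

```julia
# Hypothetical helper that hands back its kwargs object unchanged.
capture(; kwargs...) = kwargs

# Mirrors the first variant: f1(; kwargs..., x=kwargs[:x])
h2_splat(; kwargs...) = (kwargs, capture(; kwargs..., x=kwargs[:x]))

# Mirrors the second variant: f1(; x=kwargs[:x])
h2_plain(; kwargs...) = (kwargs, capture(; x=kwargs[:x]))

outer1, inner1 = h2_splat(; x=0.0)
outer2, inner2 = h2_plain(; x=0.0)

# Built separately, yet indistinguishable by `===` (and therefore by objectid).
println(outer1 === inner1)
println(outer2 === inner2)
```

Removing the redundant splat therefore does not change the fact that the caller's and callee's kwargs are the same value.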
ToucheSir commented 2 years ago

This turned out to be a fun issue 😱. In short, kwargs are represented as Pairs{..., NamedTuple}. This is an immutable type, but because Pairs <: AbstractDict, it hits the getindex adjoint written for mutable dicts. Since objectid works by value rather than by reference for immutable types, any two sets of keyword arguments that are equal by value (same names, types and values) end up accumulating into the same gradient.
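The mechanism can be sketched in plain Julia. The first half checks the properties described above: kwargs arrive as a `Base.Pairs` wrapping a `NamedTuple`, it subtypes `AbstractDict`, it is immutable, and identity is therefore decided by value. The second half uses a hypothetical identity-keyed store (`accum_getindex_grad!`, `grad_cache`, `mut_cache` are illustrative names) loosely modelled on an `IdDict`-based gradient cache; it is not Zygote's actual code, only an illustration of how two value-equal kwargs objects share one gradient slot while two equal mutable `Dict`s do not.

```julia
# Plain Julia, no Zygote. Part 1: properties of keyword-argument collections.
kw(; kwargs...) = kwargs

a = kw(x = 0.0)
b = kw(x = 0.0)                     # built separately, same names and values

@assert a isa Base.Pairs            # kwargs arrive as a Base.Pairs ...
@assert values(a) isa NamedTuple    # ... wrapping a NamedTuple ...
@assert a isa AbstractDict          # ... which subtypes AbstractDict ...
@assert !ismutable(a)               # ... but is immutable,
@assert a === b                     # so identity is decided by value
@assert objectid(a) == objectid(b)

# Equal mutable Dicts, by contrast, are distinct objects.
d1 = Dict(:x => 0.0)
d2 = Dict(:x => 0.0)
@assert d1 !== d2

# Part 2: a hypothetical gradient store keyed by object identity.
# Illustration only, not Zygote's implementation.
function accum_getindex_grad!(cache, d, k, Δ)
    if !haskey(cache, d)
        cache[d] = Dict{Symbol,Float64}()
    end
    g = cache[d]
    g[k] = get(g, k, 0.0) + Δ       # accumulate the sensitivity for key k
    return g
end

# Two lookups on "different" kwargs objects, each contributing 1.0.
grad_cache = IdDict{Any,Dict{Symbol,Float64}}()
accum_getindex_grad!(grad_cache, a, :x, 1.0)
accum_getindex_grad!(grad_cache, b, :x, 1.0)
println(grad_cache[a][:x])          # 2.0: a === b, so both land in one slot

# The same two lookups on equal mutable Dicts stay separate.
mut_cache = IdDict{Any,Dict{Symbol,Float64}}()
accum_getindex_grad!(mut_cache, d1, :x, 1.0)
accum_getindex_grad!(mut_cache, d2, :x, 1.0)
println(mut_cache[d1][:x], " ", mut_cache[d2][:x])   # 1.0 1.0
```

The contrast is the point: per-object bookkeeping that is safe for mutable dicts silently merges contributions once the "dict" is an immutable value type.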

Given all that, I can't help but wonder if this has been causing other mysterious bugs in the wild. Working on a PR that should hopefully be up soon.

ChrisRackauckas commented 2 years ago

This seems to cause downstream failures, for example in DiffEqFlux's multiple_shoot (src/multiple_shooting.jl). See:

function multiple_shoot(
    p::AbstractArray,
    ode_data::AbstractArray,
    tsteps::AbstractArray,
    ensembleprob::EnsembleProblem,
    ensemblealg::SciMLBase.BasicEnsembleAlgorithm,
    loss_function,
    continuity_loss,
    solver::DiffEqBase.AbstractODEAlgorithm,
    group_size::Integer;
    continuity_term::Real=100,
    kwargs...
)
    datasize = size(ode_data, 2)
    prob = ensembleprob.prob

    if group_size < 2 || group_size > datasize
        throw(DomainError(group_size, "group_size can't be < 2 or > number of data points"))
    end

    @assert ndims(ode_data) == 3 "ode_data must have three dimensions: `size(ode_data) = (problem_dimension, length(tsteps), trajectories)`"
    @assert size(ode_data,2) == length(tsteps)
    @show kwargs
    @assert size(ode_data,3) == kwargs[:trajectories]
    # ... (remainder of the function omitted)

This is then called like this:

function loss_multiple_shooting_ens(p)
    return multiple_shoot(p, ode_data_ensemble, tsteps, ensemble_prob, ensemble_alg,
                          loss_function, Tsit5(),
                          group_size; continuity_term,
                          trajectories,
                          abstol=1e-8, reltol=1e-6) # test solver kwargs
end
kwargs = Base.Pairs{Symbol, Real, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:trajectories, :abstol, :reltol), Tuple{Int64, Float64, Float64}}}(:trajectories => 2, :abstol => 1.0e-8, :reltol => 1.0e-6)
ERROR: MethodError: no method matching getindex(::Nothing, ::Int64)
Stacktrace:
  [1] (::Zygote.var"#kwargs_literal_getindex_pullback#326"{Zygote.var"#1925#back#218"{Zygote.var"#back#217"{:trajectories, Zygote.Context{false}, NamedTuple{(:trajectories, :abstol, :reltol), Tuple{Int64, Float64, Float64}}, Int64}}})(Δ::Nothing)
    @ Zygote C:\Users\accou\.julia\packages\Zygote\qGFGD\src\lib\base.jl:165
  [2] Pullback
    @ c:\Users\accou\.julia\packages\DiffEqFlux\Em1Aj\src\multiple_shooting.jl:185 [inlined]

The line that errors is:

@assert size(ode_data,3) == kwargs[:trajectories]

So my symbol key seems to have been turned into an integer, and kwargs into nothing?

https://github.com/SciML/SciMLSensitivity.jl/runs/8028243222?check_suite_focus=true

using DiffEqFlux, OrdinaryDiffEq, Test

datasize = 30
u0 = Float32[2.0, 0.0]
tspan = (0.0f0, 5.0f0)
tsteps = range(tspan[1], tspan[2], length=datasize)

function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u .^ 3)'true_A)'
end
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob_trueode, Tsit5(), saveat=tsteps))

nn = FastChain((x, p) -> x .^ 3,
    FastDense(2, 16, tanh),
    FastDense(16, 2))
p_init = initial_params(nn)

neuralode = NeuralODE(nn, tspan, Tsit5(), saveat=tsteps)
prob_node = ODEProblem((u, p, t) -> nn(u, p), u0, tspan, p_init)

function loss_function(data, pred)
    return sum(abs2, data - pred)
end

u0s = [Float32[2.0, 0.0], Float32[3.0, 1.0]]
function prob_func(prob, i, repeat)
    remake(prob, u0=u0s[i])
end
ensemble_prob = EnsembleProblem(prob_node, prob_func=prob_func)
ensemble_prob_trueODE = EnsembleProblem(prob_trueode, prob_func=prob_func)
ensemble_alg = EnsembleThreads()
trajectories = 2
ode_data_ensemble = Array(solve(ensemble_prob_trueODE, Tsit5(), ensemble_alg, trajectories=trajectories, saveat=tsteps))

group_size = 3
continuity_term = 200
function loss_multiple_shooting_ens(p)
    return multiple_shoot(p, ode_data_ensemble, tsteps, ensemble_prob, ensemble_alg,
        loss_function, Tsit5(),
        group_size; continuity_term,
        trajectories,
        abstol=1e-8, reltol=1e-6) # test solver kwargs
end

res_ms_ensembles = DiffEqFlux.sciml_train(loss_multiple_shooting_ens, neuralode.p,
    ADAM(0.05), maxiters=300)
ToucheSir commented 2 years ago

Try https://github.com/FluxML/Zygote.jl/pull/1295 on for size.

ChrisRackauckas commented 2 years ago

That fixes it.