niklasschmitz closed this issue 2 years ago.
I wonder if this is to do with kwarg[:x] showing up both in kwarg... and kwarg[:x]...?
That's precisely the case, I believe.
(Reply by email to Frames Catherine White's comment of Sun, Aug 7, 2022, 20:47.)
Interestingly, the issue persists even after eliminating this redundancy:
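using FiniteDiff, ForwardDiff, Zygote

f1(; kwargs...) = kwargs[:x]
f2(; kwargs...) = f1(; x=kwargs[:x])
f3(x) = f2(; x)  # f3 is the identity, so its derivative is 1.0

FiniteDiff.finite_difference_derivative(f3, 0.0) # 1.0
ForwardDiff.derivative(f3, 0.0) # 1.0
Zygote.gradient(f3, 0.0) # (2.0,) on Zygote v0.6.43; expected (1.0,)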
This turned out to be a fun issue. In short, kwargs are represented as Pairs{..., NamedTuple}. This is an immutable type, but because Pairs <: AbstractDict, it triggers the getindex adjoint for mutable dicts. Since objectid works by value rather than by reference for immutable types, any two sets of keyword arguments with the same keys, types, and values accumulate into the same gradient.
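To see the value-based identity at play, here is a minimal stdlib-only sketch (no Zygote involved; `capture` is a hypothetical helper introduced just for this illustration) comparing objectid for keyword-argument Pairs and for a mutable Dict:

```julia
# `capture` is a hypothetical helper that returns the collected kwargs,
# i.e. a Base.Pairs wrapping a NamedTuple.
capture(; kwargs...) = kwargs

a = capture(; x = 1.0)
b = capture(; x = 1.0)  # a separate call with the same key and value

a isa AbstractDict            # true: Pairs <: AbstractDict
ismutable(a)                  # false: Pairs is an immutable struct
objectid(a) == objectid(b)    # true: identity is derived from the contents

# A mutable Dict gets a per-object identity instead:
d1 = Dict(:x => 1.0)
d2 = Dict(:x => 1.0)
objectid(d1) == objectid(d2)  # false
```

So anything keyed on objectid (or on `===`) cannot tell two separate but equal kwargs collections apart, while it can for two Dicts.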
Given all that, I can't help but wonder if this has been causing other mysterious bugs in the wild. Working on a PR that should hopefully be up soon.
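The double count in f3 can be mimicked with a toy gradient cache keyed by objectid. This is a simplified, hypothetical model (`grad_cache` and `accumulate_getindex!` are made up for illustration and are not Zygote internals): each getindex pullback adds its sensitivity into the slot for its dict, and because the kwargs seen by f1 and f2 are equal by value, both contributions land in one slot.

```julia
# Hypothetical, simplified gradient cache keyed by objectid.
# NOT Zygote's code; it only mimics the keying behavior described above.
const grad_cache = Dict{UInt, Dict{Symbol, Float64}}()

# Toy pullback of `kw[key]`: add the incoming sensitivity Δ to the cache
# entry identified by objectid(kw).
function accumulate_getindex!(kw, key::Symbol, Δ::Float64)
    slot = get!(() -> Dict{Symbol, Float64}(), grad_cache, objectid(kw))
    slot[key] = get(slot, key, 0.0) + Δ
    return slot
end

capture(; kwargs...) = kwargs

x = 0.0
kw_f2 = capture(; x)              # the kwargs inside f2
kw_f1 = capture(; x = kw_f2[:x])  # the kwargs inside f1, built by a separate call

# Reverse pass: f1 reads kw_f1[:x], then f2 reads kw_f2[:x]; each passes Δ = 1.0
accumulate_getindex!(kw_f1, :x, 1.0)
accumulate_getindex!(kw_f2, :x, 1.0)

# Both Pairs are equal by value, so they share a single cache slot:
grad_cache[objectid(kw_f2)][:x]   # 2.0, the doubled gradient
```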
This seems to cause downstream failures. See the kwargs handling in DiffEqFlux's multiple_shoot:
function multiple_shoot(
    p::AbstractArray,
    ode_data::AbstractArray,
    tsteps::AbstractArray,
    ensembleprob::EnsembleProblem,
    ensemblealg::SciMLBase.BasicEnsembleAlgorithm,
    loss_function,
    continuity_loss,
    solver::DiffEqBase.AbstractODEAlgorithm,
    group_size::Integer;
    continuity_term::Real=100,
    kwargs...
)
    datasize = size(ode_data, 2)
    prob = ensembleprob.prob
    if group_size < 2 || group_size > datasize
        throw(DomainError(group_size, "group_size can't be < 2 or > number of data points"))
    end
    @assert ndims(ode_data) == 3 "ode_data must have three dimensions: `size(ode_data) = (problem_dimension,length(tsteps),trajectories)`"
    @assert size(ode_data,2) == length(tsteps)
    @show kwargs
    @assert size(ode_data,3) == kwargs[:trajectories]
    # ... (rest of the function body omitted in this excerpt)
end
This is then called as follows:
function loss_multiple_shooting_ens(p)
    return multiple_shoot(p, ode_data_ensemble, tsteps, ensemble_prob, ensemble_alg,
                          loss_function, Tsit5(),
                          group_size; continuity_term,
                          trajectories,
                          abstol=1e-8, reltol=1e-6) # test solver kwargs
end
kwargs = Base.Pairs{Symbol, Real, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:trajectories, :abstol, :reltol), Tuple{Int64, Float64, Float64}}}(:trajectories => 2, :abstol => 1.0e-8, :reltol => 1.0e-6)
ERROR: MethodError: no method matching getindex(::Nothing, ::Int64)
Stacktrace:
[1] (::Zygote.var"#kwargs_literal_getindex_pullback#326"{Zygote.var"#1925#back#218"{Zygote.var"#back#217"{:trajectories, Zygote.Context{false}, NamedTuple{(:trajectories, :abstol, :reltol), Tuple{Int64, Float64, Float64}}, Int64}}})(Δ::Nothing)
@ Zygote C:\Users\accou\.julia\packages\Zygote\qGFGD\src\lib\base.jl:165
[2] Pullback
@ c:\Users\accou\.julia\packages\DiffEqFlux\Em1Aj\src\multiple_shooting.jl:185 [inlined]
The line that errors is:
@assert size(ode_data,3) == kwargs[:trajectories]
My symbol is transformed into an integer, and kwargs into nothing?
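One plausible reading of the stack trace (an assumption, not a diagnosis confirmed in this thread): kwargs[:trajectories] only feeds an integer comparison inside @assert, so no gradient flows back through it, the pullback receives Δ = nothing (as the trace shows), and it then indexes into that nothing with an integer. The hypothetical pullbacks below reproduce the failure mode and the usual guard:

```julia
# Hypothetical pullbacks for illustration only; not Zygote's implementation.
naive_pullback(Δ) = (Δ[1],)  # assumes Δ can be indexed

function guarded_pullback(Δ)
    Δ === nothing && return (nothing,)  # no sensitivity to propagate
    return (Δ[1],)
end

guarded_pullback(nothing)  # (nothing,)
guarded_pullback([3.0])    # (3.0,)

err = try
    naive_pullback(nothing)
catch e
    e
end
err isa MethodError  # true: no method matching getindex(::Nothing, ::Int64)
```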
https://github.com/SciML/SciMLSensitivity.jl/runs/8028243222?check_suite_focus=true
using DiffEqFlux, OrdinaryDiffEq, Test

datasize = 30
u0 = Float32[2.0, 0.0]
tspan = (0.0f0, 5.0f0)
tsteps = range(tspan[1], tspan[2], length=datasize)

function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u .^ 3)'true_A)'
end

prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob_trueode, Tsit5(), saveat=tsteps))

nn = FastChain((x, p) -> x .^ 3,
               FastDense(2, 16, tanh),
               FastDense(16, 2))
p_init = initial_params(nn)

neuralode = NeuralODE(nn, tspan, Tsit5(), saveat=tsteps)
prob_node = ODEProblem((u, p, t) -> nn(u, p), u0, tspan, p_init)

function loss_function(data, pred)
    return sum(abs2, data - pred)
end

u0s = [Float32[2.0, 0.0], Float32[3.0, 1.0]]
function prob_func(prob, i, repeat)
    remake(prob, u0=u0s[i])
end
ensemble_prob = EnsembleProblem(prob_node, prob_func=prob_func)
ensemble_prob_trueODE = EnsembleProblem(prob_trueode, prob_func=prob_func)
ensemble_alg = EnsembleThreads()
trajectories = 2
ode_data_ensemble = Array(solve(ensemble_prob_trueODE, Tsit5(), ensemble_alg, trajectories=trajectories, saveat=tsteps))

group_size = 3
continuity_term = 200
function loss_multiple_shooting_ens(p)
    return multiple_shoot(p, ode_data_ensemble, tsteps, ensemble_prob, ensemble_alg,
                          loss_function, Tsit5(),
                          group_size; continuity_term,
                          trajectories,
                          abstol=1e-8, reltol=1e-6) # test solver kwargs
end

res_ms_ensembles = DiffEqFlux.sciml_train(loss_multiple_shooting_ens, neuralode.p,
                                          ADAM(0.05), maxiters=300)
Try https://github.com/FluxML/Zygote.jl/pull/1295 on for size.
That fixes it.
Zygote (v0.6.43) currently gives a wrong gradient involving kwargs splatting. It seems to double-count a gradient contribution through implicit and explicit kwargs. Here's a small example: