Closed SobhanMP closed 2 years ago
This might be related specifically to autojacvec=ZygoteVJP(). This example returns zero gradients with both BacksolveAdjoint and InterpolatingAdjoint, while QuadratureAdjoint throws an error.
using DiffEqFlux, Flux, OrdinaryDiffEq, DiffEqSensitivity
function system!(du, u, p, t, controller)
    α, β, γ, δ = 0.5f0, 1.0f0, 1.0f0, 1.0f0
    y1, y2 = u
    c1, c2 = controller(u, p)
    y1_prime = -(c1 + α * c1^2) * y1 + δ * c2
    y2_prime = (β * c1 - γ * c2) * y1
    @inbounds begin
        du[1] = y1_prime
        du[2] = y2_prime
    end
end
function loss(params, prob, tsteps)
    sensealg=QuadratureAdjoint(autojacvec=ZygoteVJP())
    sol = solve(prob, Tsit5(); p=params, saveat=tsteps, sensealg)
    return -Array(sol)[2, end] # second variable, last value, maximize
end
u0 = [1.0f0, 0.0f0]
tspan = (0.0f0, 1.0f0)
tsteps = 0.0f0:0.01f0:1.0f0
# dummy controller / parameters
controller = function (x,p)
    σ.(x .* p)
end
θ = randn(2)
dudt!(du, u, p, t) = system!(du, u, p, t, controller)
prob = ODEProblem(dudt!, u0, tspan, θ)
loss(params) = loss(params, prob, tsteps)
Zygote.gradient(loss, θ)
Error:
julia> Zygote.gradient(loss, θ)
ERROR: MethodError: no method matching dudt!(::Vector{Float32}, ::Vector{Float64}, ::Float32)
Closest candidates are:
dudt!(::Any, ::Any, ::Any, ::Any) at REPL[28]:1
@IlyaOrson your example is just a bad choice of VJP. I added an informative error, and if you fix your example it's fine. https://github.com/SciML/DiffEqSensitivity.jl/pull/547
using DiffEqFlux, Flux, OrdinaryDiffEq, DiffEqSensitivity
function system!(u, p, t, controller)
    α, β, γ, δ = 0.5f0, 1.0f0, 1.0f0, 1.0f0
    y1, y2 = u
    c1, c2 = controller(u, p)
    y1_prime = -(c1 + α * c1^2) * y1 + δ * c2
    y2_prime = (β * c1 - γ * c2) * y1
    [y1_prime,y2_prime]
end
function loss(params, prob, tsteps)
    sensealg=QuadratureAdjoint(autojacvec=ZygoteVJP())
    sol = solve(prob, Tsit5(); p=params, saveat=tsteps, sensealg)
    return -Array(sol)[2, end] # second variable, last value, maximize
end
u0 = [1.0f0, 0.0f0]
tspan = (0.0f0, 1.0f0)
tsteps = 0.0f0:0.01f0:1.0f0
controller = function (x,p)
    σ.(x .* p)
end
θ = randn(2)
dudt!(u, p, t) = system!(u, p, t, controller)
prob = ODEProblem(dudt!, u0, tspan, θ)
loss(params) = loss(params, prob, tsteps)
Zygote.gradient(loss, θ)
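The MethodError above shows the function being called with three arguments (u, p, t), which is the out-of-place form; the corrected example differs from the original only in using that form. A minimal sketch of the two signatures in plain Julia, with no SciML packages involved. The names rhs_inplace! and rhs_outofplace and the toy dynamics are illustrative assumptions, not part of the thread's code:

```julia
# In-place form: writes the derivative into a preallocated buffer du.
function rhs_inplace!(du, u, p, t)
    du[1] = -p[1] * u[1]
    du[2] = p[2] * u[1]
    return nothing
end

# Out-of-place form: returns a freshly allocated derivative vector.
rhs_outofplace(u, p, t) = [-p[1] * u[1], p[2] * u[1]]

u = [1.0f0, 0.0f0]
p = [0.5f0, 2.0f0]
du = similar(u)
rhs_inplace!(du, u, p, 0.0f0)

# Both forms describe the same dynamics.
same_derivative = du == rhs_outofplace(u, p, 0.0f0)

# The in-place method has no three-argument (u, p, t) variant, which is
# the kind of call that produced the MethodError in the report.
three_arg_call_exists = hasmethod(rhs_inplace!,
                                  Tuple{Vector{Float32}, Vector{Float32}, Float32})

println("same derivative: ", same_derivative)
println("three-argument method exists: ", three_arg_call_exists)
```

Whether a given VJP choice accepts the in-place form depends on the package version, so the safest reading is the one the thread gives: with ZygoteVJP, write the right-hand side out-of-place.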
The original issue here has a simple cause: your last saved value is not the end of the time interval. That's rather weird: why solve to 1.5 if you're only going to use values up to 1.4396123815716677? So the code you probably want is:
using DiffEqFlux, DifferentialEquations, Plots, GalacticOptim
u0 = Float32[2.0; 0.0] # Initial condition
datasize = 30 # Number of data points
tspan = (0.0f0, 1.5f0) # Time range
# tsteps = range(tspan[1], tspan[2], length = datasize) # Split time range into equal steps for each data point
tsteps = (rand(datasize) .* (tspan[2] - tspan[1]) .+ tspan[1]) |> sort
append!(tsteps,1.5f0)
# Function that will generate the data we are trying to fit
function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)' # Need transposes to make the matrix multiplication work
end
# Define the problem with the function above
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
# Solve and take just the solution array
ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))
# Make a neural net with a NeuralODE layer
dudt2 = FastChain((x, p) -> x.^3, # Guess a cubic function
                  FastDense(2, 50, elu), # Multilayer perceptron for the part we don't know
                  FastDense(50, 2))
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5();
                           saveat = tsteps, sensealg=InterpolatingAdjoint())
# Array of predictions from NeuralODE with parameters p starting at initial condition u0
function predict_neuralode(p)
    prob_neuralode(u0, p)
end
function loss_neuralode(p)
    pred = predict_neuralode(p)
    pred = if size(pred, 2) == size(ode_data, 2) + 2
        pred[:, 2:end-1]
    else
        pred
    end
    @show pred.t
    loss = sum(abs2, ode_data .- Array(pred)) # Just sum of squared error
    return loss
end
loss_neuralode(prob_neuralode.p)
@show Zygote.gradient(loss_neuralode, prob_neuralode.p)[1] .|> iszero |> all
which works.
That said, this should still get fixed to make it more robust.
Oh no, that's not the issue, it's even weirder.
using DiffEqFlux, DifferentialEquations, Plots, GalacticOptim
u0 = Float32[2.0; 0.0] # Initial condition
datasize = 30 # Number of data points
tspan = (0.0f0, 1.5f0) # Time range
# tsteps = range(tspan[1], tspan[2], length = datasize) # Split time range into equal steps for each data point
tsteps = (rand(datasize) .* (tspan[2] - tspan[1]) .+ tspan[1]) |> sort
tsteps = Float32.(tsteps)
# Function that will generate the data we are trying to fit
function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)' # Need transposes to make the matrix multiplication work
end
# Define the problem with the function above
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
# Solve and take just the solution array
ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))
# Make a neural net with a NeuralODE layer
dudt2 = FastChain((x, p) -> x.^3, # Guess a cubic function
                  FastDense(2, 50, elu), # Multilayer perceptron for the part we don't know
                  FastDense(50, 2))
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5();
                           saveat = tsteps, sensealg=InterpolatingAdjoint())
# Array of predictions from NeuralODE with parameters p starting at initial condition u0
function predict_neuralode(p)
    prob_neuralode(u0, p)
end
function loss_neuralode(p)
    pred = predict_neuralode(p)
    pred = if size(pred, 2) == size(ode_data, 2) + 2
        pred[:, 2:end-1]
    else
        pred
    end
    loss = sum(abs2, ode_data .- Array(pred)) # Just sum of squared error
    return loss
end
loss_neuralode(prob_neuralode.p)
@show Zygote.gradient(loss_neuralode, prob_neuralode.p)[1] .|> iszero |> all
The issue was that tsteps was Float64 but t was Float32, so during the reverse pass it never exactly hit the tstop values because they were represented in different arithmetic. The workaround is tsteps = Float32.(tsteps), but the real way to fix this is to make the ODE solver automatically convert any saveat or tstops chosen into the typeof(t).
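The mismatch can be reproduced without a solver. The following is a minimal sketch in plain Julia, not the solver's actual stepping logic: it compares Float64 save times against the same times as seen in Float32, then applies the conversion described above. The helper name match_time_type is hypothetical and only illustrates what an automatic conversion to typeof(t) might look like.

```julia
# Save times in Float64, as produced by rand(datasize) without an explicit element type.
tsteps64 = [0.1, 0.7, 1.4396123815716677]

# Time span in Float32, so the solver's t is Float32 as well.
tspan = (0.0f0, 1.5f0)

# What a Float32 time variable can represent for each requested save time.
t32 = Float32.(tsteps64)

# Mixed comparisons promote to Float64, so a rounded Float32 value only
# matches if the original Float64 value was exactly representable in Float32.
exact_hits = t32 .== tsteps64

# Hypothetical helper (not a library function): convert save times to the type of t.
match_time_type(tsteps, tspan) = convert.(typeof(first(tspan)), tsteps)

tsteps32 = match_time_type(tsteps64, tspan)

# After conversion both sides are Float32 and compare equal exactly.
converted_hits = t32 .== tsteps32

println("hits before conversion: ", count(exact_hits))
println("hits after conversion:  ", count(converted_hits))
```

Doing the conversion once, before the save times reach the solver, keeps the forward and reverse passes comparing values of a single floating-point type.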
Thanks for the rapid response and for the fast fix.
I was trying to adapt this example to training points that are not evenly spaced, and it seems that with the default sensealg the gradient is zero. I was told on Slack to open an issue, so here we are. This would be a minimal example of what I was trying to do. Either switching the definition of tsteps to the uniformly spaced one or changing the sensealg to ForwardDiffSensitivity, BacksolveAdjoint, QuadratureAdjoint or ReverseDiffAdjoint solves the problem.
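Given the diagnosis elsewhere in this thread (Float64 tsteps against a Float32 tspan), one way to build unevenly spaced save points that already match the time type is to draw them in Float32 from the start. A minimal sketch in plain Julia, assuming the same datasize and tspan as the example; it does not call the solver:

```julia
datasize = 30
tspan = (0.0f0, 1.5f0)

# rand(Float32, n) draws in Float32, so no Float64 values enter the save times.
tsteps = sort(rand(Float32, datasize) .* (tspan[2] - tspan[1]) .+ tspan[1])

println(eltype(tsteps))
```

This has the same effect as applying Float32.(tsteps) afterwards, but avoids creating the Float64 values in the first place.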