Closed: SobhanMP closed this issue 2 years ago.
This might be related specifically to `autojacvec=ZygoteVJP()`. This example returns zero gradients with both `BacksolveAdjoint` and `InterpolatingAdjoint`, while `QuadratureAdjoint` throws an error.
```julia
using DiffEqFlux, Flux, OrdinaryDiffEq, DiffEqSensitivity, Zygote

# In-place right-hand side: writes the derivative into du
function system!(du, u, p, t, controller)
    α, β, γ, δ = 0.5f0, 1.0f0, 1.0f0, 1.0f0
    y1, y2 = u
    c1, c2 = controller(u, p)
    y1_prime = -(c1 + α * c1^2) * y1 + δ * c2
    y2_prime = (β * c1 - γ * c2) * y1
    @inbounds begin
        du[1] = y1_prime
        du[2] = y2_prime
    end
    return nothing
end

# The adjoint described above: this one errors as shown below, while
# BacksolveAdjoint / InterpolatingAdjoint with the same autojacvec return zero gradients
sensealg = QuadratureAdjoint(autojacvec=ZygoteVJP())

function loss(params, prob, tsteps)
    sol = solve(prob, Tsit5(); p=params, saveat=tsteps, sensealg)
    return -Array(sol)[2, end]  # second variable, last value, maximize
end

u0 = [1.0f0, 0.0f0]
tspan = (0.0f0, 1.0f0)
tsteps = 0.0f0:0.01f0:1.0f0

# dummy controller / parameters
controller = function (x, p)
    σ.(x .* p)
end
θ = randn(2)

dudt!(du, u, p, t) = system!(du, u, p, t, controller)
prob = ODEProblem(dudt!, u0, tspan, θ)
loss(params) = loss(params, prob, tsteps)

Zygote.gradient(loss, θ)
```
```
julia> Zygote.gradient(loss, θ)
ERROR: MethodError: no method matching dudt!(::Vector{Float32}, ::Vector{Float64}, ::Float32)
Closest candidates are:
  dudt!(::Any, ::Any, ::Any, ::Any) at REPL[28]:1
```
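@IlyaOrson your example is just a bad choice of VJP: `ZygoteVJP` can't differentiate through the mutating, in-place `dudt!`, and the `MethodError` shows it being called with the out-of-place signature `dudt!(u, p, t)`. I added an informative error, and if you fix your example by writing the right-hand side out-of-place, it's fine.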
```julia
using DiffEqFlux, Flux, OrdinaryDiffEq, DiffEqSensitivity, Zygote

# Out-of-place right-hand side: returns the derivative instead of mutating du
function system!(u, p, t, controller)
    α, β, γ, δ = 0.5f0, 1.0f0, 1.0f0, 1.0f0
    y1, y2 = u
    c1, c2 = controller(u, p)
    y1_prime = -(c1 + α * c1^2) * y1 + δ * c2
    y2_prime = (β * c1 - γ * c2) * y1
    return [y1_prime, y2_prime]
end

# Same adjoint setting as in the report
sensealg = QuadratureAdjoint(autojacvec=ZygoteVJP())

function loss(params, prob, tsteps)
    sol = solve(prob, Tsit5(); p=params, saveat=tsteps, sensealg)
    return -Array(sol)[2, end]  # second variable, last value, maximize
end

u0 = [1.0f0, 0.0f0]
tspan = (0.0f0, 1.0f0)
tsteps = 0.0f0:0.01f0:1.0f0

controller = function (x, p)
    σ.(x .* p)
end
θ = randn(2)

dudt!(u, p, t) = system!(u, p, t, controller)
prob = ODEProblem(dudt!, u0, tspan, θ)
loss(params) = loss(params, prob, tsteps)

Zygote.gradient(loss, θ)
```
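A minimal sketch, using only Base Julia, of the distinction behind this fix: the same right-hand side written in-place and out-of-place. The names `sigmoid`, `rhs!` and `rhs` are hypothetical, and `sigmoid` stands in for Flux's `σ`; nothing here calls DiffEqSensitivity. It only shows that the two forms agree numerically and that calling the in-place method with the three-argument signature `(u, p, t)` raises a `MethodError` of the same kind as in the report.

```julia
# Hypothetical stand-in for Flux's σ so the sketch needs only Base
sigmoid(x) = one(x) / (one(x) + exp(-x))
controller(x, p) = sigmoid.(x .* p)

# In-place form: mutates du (Zygote cannot differentiate array mutation)
function rhs!(du, u, p, t)
    α, β, γ, δ = 0.5f0, 1.0f0, 1.0f0, 1.0f0
    y1, y2 = u
    c1, c2 = controller(u, p)
    du[1] = -(c1 + α * c1^2) * y1 + δ * c2
    du[2] = (β * c1 - γ * c2) * y1
    return nothing
end

# Out-of-place form: allocates and returns a new vector, no mutation
function rhs(u, p, t)
    α, β, γ, δ = 0.5f0, 1.0f0, 1.0f0, 1.0f0
    y1, y2 = u
    c1, c2 = controller(u, p)
    return [-(c1 + α * c1^2) * y1 + δ * c2,
            (β * c1 - γ * c2) * y1]
end

u = [1.0f0, 0.0f0]
p = [0.3, -0.7]  # Float64 parameters, as with θ = randn(2)

# du must hold Float64 because p promotes the result
du = similar(u, promote_type(eltype(u), eltype(p)))
rhs!(du, u, p, 0.0f0)
du_oop = rhs(u, p, 0.0f0)
println(du ≈ du_oop)  # prints true

# Calling the in-place method with the out-of-place signature fails
err = try
    rhs!(u, p, 0.0f0)
    nothing
catch e
    e
end
println(err isa MethodError)  # prints true
```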
The original issue here has a simple cause: your last saved value is not the end of the time interval. That's rather odd: why solve to `1.5` if you're only going to use values up to `1.4396123815716677`? So the code you probably want is:
```julia
using DiffEqFlux, DifferentialEquations, Plots, GalacticOptim, Zygote

u0 = Float32[2.0; 0.0]  # Initial condition
datasize = 30           # Number of data points
tspan = (0.0f0, 1.5f0)  # Time range
# tsteps = range(tspan[1], tspan[2], length = datasize)  # Split time range into equal steps for each data point
tsteps = (rand(datasize) .* (tspan[2] - tspan[1]) .+ tspan[1]) |> sort

# Function that will generate the data we are trying to fit
function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'  # Need transposes to make the matrix multiplication work
end

# Define the problem with the function above
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
# Solve and take just the solution array
ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))

# Make a neural net with a NeuralODE layer
dudt2 = FastChain((x, p) -> x.^3,         # Guess a cubic function
                  FastDense(2, 50, elu),  # Multilayer perceptron for the part we don't know
                  FastDense(50, 2))
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5();
                           saveat = tsteps, sensealg = InterpolatingAdjoint())

# Array of predictions from NeuralODE with parameters p starting at initial condition u0
function predict_neuralode(p)
    prob_neuralode(u0, p)
end

function loss_neuralode(p)
    pred = predict_neuralode(p)
    @show pred.t  # time points actually saved by the solve
    # Drop the first and last saved columns when there are two more than in the data
    pred = if size(pred, 2) == size(ode_data, 2) + 2
        pred[:, 2:end-1]
    else
        pred
    end
    loss = sum(abs2, ode_data .- Array(pred))  # Just sum of squared error
    return loss
end

# `true` here would mean every gradient entry is zero
@show Zygote.gradient(loss_neuralode, prob_neuralode.p)[1] .|> iszero |> all
```
which works.
That said, this should still get fixed to make it more robust.
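A quick Base-only check of the observation about the last saved value, reusing the sampling line from the example: `rand()` returns values in `[0, 1)`, so the largest sorted time point always falls strictly below `tspan[2]`, and the solve integrates past the last point it saves. The printed gap varies from run to run.

```julia
# Same sampling scheme as the example: uniform random times in tspan, sorted
tspan = (0.0f0, 1.5f0)
datasize = 30
tsteps = sort(rand(datasize) .* (tspan[2] - tspan[1]) .+ tspan[1])

# rand() is in [0, 1), so no sampled time reaches the end of the interval
println(last(tsteps) < tspan[2])  # prints true
println("integrated past the last save point by: ", tspan[2] - last(tsteps))
```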
Oh no, that's not the issue, it's even weirder.
```julia
using DiffEqFlux, DifferentialEquations, Plots, GalacticOptim, Zygote

u0 = Float32[2.0; 0.0]  # Initial condition
datasize = 30           # Number of data points
tspan = (0.0f0, 1.5f0)  # Time range
# tsteps = range(tspan[1], tspan[2], length = datasize)  # Split time range into equal steps for each data point
tsteps = (rand(datasize) .* (tspan[2] - tspan[1]) .+ tspan[1]) |> sort
tsteps = Float32.(tsteps)  # Match the Float32 time type of tspan

# Function that will generate the data we are trying to fit
function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'  # Need transposes to make the matrix multiplication work
end

# Define the problem with the function above
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
# Solve and take just the solution array
ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))

# Make a neural net with a NeuralODE layer
dudt2 = FastChain((x, p) -> x.^3,         # Guess a cubic function
                  FastDense(2, 50, elu),  # Multilayer perceptron for the part we don't know
                  FastDense(50, 2))
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5();
                           saveat = tsteps, sensealg = InterpolatingAdjoint())

# Array of predictions from NeuralODE with parameters p starting at initial condition u0
function predict_neuralode(p)
    prob_neuralode(u0, p)
end

function loss_neuralode(p)
    pred = predict_neuralode(p)
    # Drop the first and last saved columns when there are two more than in the data
    pred = if size(pred, 2) == size(ode_data, 2) + 2
        pred[:, 2:end-1]
    else
        pred
    end
    loss = sum(abs2, ode_data .- Array(pred))  # Just sum of squared error
    return loss
end

# `true` here would mean every gradient entry is zero
@show Zygote.gradient(loss_neuralode, prob_neuralode.p)[1] .|> iszero |> all
```
The issue was that `tsteps` was `Float64` but `t` was `Float32`, so during the reverse pass the solver never exactly hit the `tstops` values, because they were represented in different floating-point precisions. The workaround is `tsteps = Float32.(tsteps)`, but the real fix is to make the ODE solver automatically convert any `saveat` or `tstops` values into `typeof(t)`.
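A minimal Base-only sketch of why exact hits fail across precisions, and of the conversion proposed as the real fix. The save points are made up, and `match_time_type` is a hypothetical helper, not a DifferentialEquations.jl function; it only does by hand what `Float32.(tsteps)` does, under the assumption that the time type follows `eltype(tspan)`. The solver's internal `tstops` handling is not reproduced here, only the exact comparison it depends on.

```julia
# The time type follows tspan (Float32 here); the save points are Float64
tspan = (0.0f0, 1.5f0)
tsteps64 = [0.1, 0.75, 1.2]  # made-up save points for illustration

t32 = Float32(0.1)               # the Float32 closest to 0.1
println(t32 == 0.1)              # prints false: 0.1f0 widens to 0.10000000149011612
println(any(==(t32), tsteps64))  # prints false: no Float64 save point equals this Float32 time

# Hypothetical helper: convert every save point to the element type of tspan
match_time_type(ts, tspan) = convert.(eltype(tspan), ts)

tsteps32 = match_time_type(tsteps64, tspan)
println(eltype(tsteps32))        # prints Float32
println(any(==(t32), tsteps32))  # prints true
```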
Thanks for the rapid response and the fast fix!
I was trying to adapt this example to training points that are not evenly spaced, and it seems that with the defaults the gradient is zero. I was told on Slack to open an issue, so here we are. This would be a minimal example of what I was trying to do. Either switching the definition to the uniformly spaced one or changing `sensealg` to `ForwardDiffSensitivity` solves the problem.
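The example code contains two definitions of `tsteps`: the commented-out uniform `range` and the sorted random one. A Base-only check of their element types suggests why the uniform grid sidesteps the Float32/Float64 mismatch diagnosed in the thread, assuming the solver's time type follows `tspan` (`Float32` here): the range built from `Float32` endpoints already has that element type, while the random one is `Float64`.

```julia
tspan = (0.0f0, 1.5f0)
datasize = 30

# Uniformly spaced definition (commented out in the example), built from Float32 endpoints
tsteps_uniform = range(tspan[1], tspan[2], length = datasize)
# Unevenly spaced definition: rand() produces Float64, so the result is Float64
tsteps_random = (rand(datasize) .* (tspan[2] - tspan[1]) .+ tspan[1]) |> sort

println(eltype(tsteps_uniform))  # prints Float32
println(eltype(tsteps_random))   # prints Float64
```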