SciML / DiffEqFlux.jl

Pre-built implicit layer architectures with O(1) backprop, GPUs, and stiff+non-stiff DE solvers, demonstrating scientific machine learning (SciML) and physics-informed machine learning methods
https://docs.sciml.ai/DiffEqFlux/stable
MIT License

saveat type not matching time type can result in missed tstops for values not exactly representable in floating point #675

Closed SobhanMP closed 2 years ago

SobhanMP commented 2 years ago

I was trying to adapt this example with training points that are not evenly spaced, and it seems that with the default sensealg the gradient is zero. I was told on Slack to open an issue, so here we are.

using DiffEqFlux, DifferentialEquations, Plots, GalacticOptim, Zygote

u0 = Float32[2.0; 0.0] # Initial condition
datasize = 30 # Number of data points
tspan = (0.0f0, 1.5f0) # Time range
# tsteps = range(tspan[1], tspan[2], length = datasize) # Split time range into equal steps for each data point
tsteps = (rand(datasize) .* (tspan[2] - tspan[1]) .+ tspan[1]) |> sort # unevenly spaced save points (rand returns Float64)
# Function that will generate the data we are trying to fit
function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)' # Need transposes to make the matrix multiplication work
end

# Define the problem with the function above
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
# Solve and take just the solution array
ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))

# Make a neural net with a NeuralODE layer
dudt2 = FastChain((x, p) -> x.^3, # Guess a cubic function
                  FastDense(2, 50, elu), # Multilayer perceptron for the part we don't know
                  FastDense(50, 2))
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(); 
    saveat = tsteps, sensealg=InterpolatingAdjoint())

# Array of predictions from NeuralODE with parameters p starting at initial condition u0
function predict_neuralode(p)
  Array(prob_neuralode(u0, p))
end

function loss_neuralode(p)
    pred = predict_neuralode(p)

    # if the solver also saved the interval endpoints, drop them
    pred = if size(pred, 2) == size(ode_data, 2) + 2
        pred[:, 2:end-1]
    else
        pred
    end

    loss = sum(abs2, ode_data .- pred) # Just sum of squared error
    return loss
end

@show Zygote.gradient(loss_neuralode, prob_neuralode.p)[1] .|> iszero |> all

This is a minimal example of what I was trying to do. Either switching the definition of tsteps to the uniformly spaced one or changing the sensealg to ForwardDiffSensitivity, BacksolveAdjoint, QuadratureAdjoint, or ReverseDiffAdjoint solves the problem.
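
For reference, here is a minimal sketch of the sensealg swap described above (only the NeuralODE construction changes; ForwardDiffSensitivity stands in for any of the listed alternatives):

# Sketch: same model as above, swapping only the sensitivity algorithm.
# Any of the alternatives listed above was reported to give nonzero gradients.
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5();
    saveat = tsteps, sensealg = ForwardDiffSensitivity())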

IlyaOrson commented 2 years ago

This might be related specifically to autojacvec=ZygoteVJP(). This example returns zero gradients with both BacksolveAdjoint and InterpolatingAdjoint, while QuadratureAdjoint throws an error.

using DiffEqFlux, Flux, OrdinaryDiffEq, DiffEqSensitivity, Zygote

function system!(du, u, p, t, controller)

    α, β, γ, δ = 0.5f0, 1.0f0, 1.0f0, 1.0f0

    y1, y2 = u
    c1, c2 = controller(u, p)

    y1_prime = -(c1 + α * c1^2) * y1 + δ * c2
    y2_prime = (β * c1 - γ * c2) * y1

    @inbounds begin
        du[1] = y1_prime
        du[2] = y2_prime
    end
end

function loss(params, prob, tsteps)
    sensealg=QuadratureAdjoint(autojacvec=ZygoteVJP())
    sol = solve(prob, Tsit5(); p=params, saveat=tsteps, sensealg)
    return -Array(sol)[2, end]  # second variable, last value, maximize
end

u0 = [1.0f0, 0.0f0]
tspan = (0.0f0, 1.0f0)
tsteps = 0.0f0:0.01f0:1.0f0

# dummy controller / parameters
controller = function (x, p)
    σ.(x .* p)
end

θ = randn(2)

dudt!(du, u, p, t) = system!(du, u, p, t, controller)
prob = ODEProblem(dudt!, u0, tspan, θ)

loss(params) = loss(params, prob, tsteps)

Zygote.gradient(loss, θ)

Error:

julia> Zygote.gradient(loss, θ)
ERROR: MethodError: no method matching dudt!(::Vector{Float32}, ::Vector{Float64}, ::Float32)
Closest candidates are:
  dudt!(::Any, ::Any, ::Any, ::Any) at REPL[28]:1

ChrisRackauckas commented 2 years ago

@IlyaOrson your example is just a bad choice of VJP: ZygoteVJP() can't differentiate through the mutating system!, so the ODE function has to be written out-of-place. I added an informative error, and if you fix your example it's fine. https://github.com/SciML/DiffEqSensitivity.jl/pull/547

using DiffEqFlux, Flux, OrdinaryDiffEq, DiffEqSensitivity, Zygote

function system!(u, p, t, controller)

    α, β, γ, δ = 0.5f0, 1.0f0, 1.0f0, 1.0f0

    y1, y2 = u
    c1, c2 = controller(u, p)

    y1_prime = -(c1 + α * c1^2) * y1 + δ * c2
    y2_prime = (β * c1 - γ * c2) * y1

    [y1_prime, y2_prime] # return the derivative out-of-place instead of mutating du
end

function loss(params, prob, tsteps)
    sensealg=QuadratureAdjoint(autojacvec=ZygoteVJP())
    sol = solve(prob, Tsit5(); p=params, saveat=tsteps, sensealg)
    return -Array(sol)[2, end]  # second variable, last value, maximize
end

u0 = [1.0f0, 0.0f0]
tspan = (0.0f0, 1.0f0)
tsteps = 0.0f0:0.01f0:1.0f0

controller = function (x,p)
    σ.(x .* p)
end

θ = randn(2)

dudt!(u, p, t) = system!(u, p, t, controller)
prob = ODEProblem(dudt!, u0, tspan, θ)

loss(params) = loss(params, prob, tsteps)

Zygote.gradient(loss, θ)

ChrisRackauckas commented 2 years ago

The original issue has a simple cause: your last saved value is not the end of the time interval. That's rather weird: why solve to 1.5 if you're only going to use values up to 1.4396123815716677? So the code you probably want is:

using DiffEqFlux, DifferentialEquations, Plots, GalacticOptim, Zygote

u0 = Float32[2.0; 0.0] # Initial condition
datasize = 30 # Number of data points
tspan = (0.0f0, 1.5f0) # Time range
# tsteps = range(tspan[1], tspan[2], length = datasize) # Split time range into equal steps for each data point
tsteps = (rand(datasize) .* (tspan[2] - tspan[1]) .+ tspan[1]) |> sort
append!(tsteps, 1.5f0) # make the final save point the end of the time interval

# Function that will generate the data we are trying to fit
function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)' # Need transposes to make the matrix multiplication work
end

# Define the problem with the function above
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
# Solve and take just the solution array
ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))

# Make a neural net with a NeuralODE layer
dudt2 = FastChain((x, p) -> x.^3, # Guess a cubic function
                  FastDense(2, 50, elu), # Multilayer perceptron for the part we don't know
                  FastDense(50, 2))
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5();
    saveat = tsteps, sensealg=InterpolatingAdjoint())

# Array of predictions from NeuralODE with parameters p starting at initial condition u0
function predict_neuralode(p)
  prob_neuralode(u0, p)
end

function loss_neuralode(p)
    pred = predict_neuralode(p)

    pred = if size(pred, 2) == size(ode_data, 2) + 2
        pred[:, 2:end-1]
    else
        pred
    end

    @show pred.t

    loss = sum(abs2, ode_data .- Array(pred)) # Just sum of squared error
    return loss
end

loss_neuralode(prob_neuralode.p)

@show Zygote.gradient(loss_neuralode, prob_neuralode.p)[1] .|> iszero |> all

which works.

That said, this should still get fixed to make it more robust.

ChrisRackauckas commented 2 years ago

Oh no, that's not the issue, it's even weirder.

using DiffEqFlux, DifferentialEquations, Plots, GalacticOptim, Zygote

u0 = Float32[2.0; 0.0] # Initial condition
datasize = 30 # Number of data points
tspan = (0.0f0, 1.5f0) # Time range
# tsteps = range(tspan[1], tspan[2], length = datasize) # Split time range into equal steps for each data point
tsteps = (rand(datasize) .* (tspan[2] - tspan[1]) .+ tspan[1]) |> sort
tsteps = Float32.(tsteps) # convert the Float64 save points to the Float32 time type
# Function that will generate the data we are trying to fit
function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)' # Need transposes to make the matrix multiplication work
end

# Define the problem with the function above
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
# Solve and take just the solution array
ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))

# Make a neural net with a NeuralODE layer
dudt2 = FastChain((x, p) -> x.^3, # Guess a cubic function
                  FastDense(2, 50, elu), # Multilayer perceptron for the part we don't know
                  FastDense(50, 2))
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5();
    saveat = tsteps, sensealg=InterpolatingAdjoint())

# Array of predictions from NeuralODE with parameters p starting at initial condition u0
function predict_neuralode(p)
  prob_neuralode(u0, p)
end

function loss_neuralode(p)
    pred = predict_neuralode(p)

    pred = if size(pred, 2) == size(ode_data, 2) + 2
        pred[:, 2:end-1]
    else
        pred
    end

    loss = sum(abs2, ode_data .- Array(pred)) # Just sum of squared error
    return loss
end

loss_neuralode(prob_neuralode.p)

@show Zygote.gradient(loss_neuralode, prob_neuralode.p)[1] .|> iszero |> all

The issue was that tsteps was Float64 but t was Float32, so during the reverse pass the solver never exactly hit the tstop values: they were represented in different floating point arithmetic. The workaround is tsteps = Float32.(tsteps), but the real fix is to make the ODE solver automatically convert any saveat or tstops values into typeof(t).
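
To see why the exact match fails across precisions, here is a minimal illustration (the value is the last save point quoted above; any value without an exact binary representation behaves the same):

x = 1.4396123815716677 # a Float64 save point
Float32(x) == x        # false: the comparison promotes Float32(x) back to
                       # Float64, and the round trip loses precision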

ChrisRackauckas commented 2 years ago

Fixed in https://github.com/SciML/DiffEqSensitivity.jl/pull/548.
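
For reference, a sketch of the conversion in the spirit described above (illustrative only; match_time_type is a hypothetical helper, not the PR's actual code):

# Hypothetical helper: convert user-supplied save points to the problem's
# time type, so tstop comparisons happen in a single arithmetic.
match_time_type(tspan, saveat) = convert.(typeof(first(tspan)), saveat)

match_time_type((0.0f0, 1.5f0), [0.3, 0.7, 1.1]) # returns a Vector{Float32}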

SobhanMP commented 2 years ago

Thanks for the rapid response and the fast fix.