SciML / SciMLSensitivity.jl

A component of the DiffEq ecosystem for enabling sensitivity analysis for scientific machine learning (SciML). Optimize-then-discretize, discretize-then-optimize, adjoint methods, and more for ODEs, SDEs, DDEs, DAEs, etc.
https://docs.sciml.ai/SciMLSensitivity/stable/

Allow NullParameters in more adjoint method dispatches #433

Open adolfocorreia opened 3 years ago

adolfocorreia commented 3 years ago

We are experimenting with some models/architectures inspired by the NODE model. Given a point (t,x), the idea is to solve an ODE system whose definition uses a neural network (and also its derivative) and whose initial condition is x. Then, we take the ODE's solution z and evaluate z(t). To train the neural net, we sample some points (t,x) and, for each of them, compute a loss function that solves the ODE system and evaluates z(t).
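
Reading the sample code further down in this post, the system solved for each sample (t_i, x_i) and the resulting loss appear to be the following. Here f is the neural network, h = u₀, and N is the number of samples; reading z₂ as the derivative of z₁ with respect to the initial value x is an interpretation and is not stated in the post.

```latex
\begin{aligned}
\dot z_1(s) &= f\big(z_1(s)\big),            & z_1(0) &= x_i, \\
\dot z_2(s) &= f'\big(z_1(s)\big)\, z_2(s),  & z_2(0) &= 1, \\
\mathcal{L} &= \frac{1}{N} \sum_{i=1}^{N}
  \Big( h'\big(z_1(t_i)\big)\, f\big(z_1(t_i)\big)
      + h'\big(z_1(t_i)\big)\, z_2(t_i) \Big)^{2}.
\end{aligned}
```

The two terms inside the square correspond to `u_t` and `u_x` in `loss_1`.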

We managed to get the model working with the Optim library using the Nelder-Mead optimization method which does not require gradients, but the convergence is relatively slow. We are now trying to implement the model with DiffEqFlux to see if we get better performance with the gradient descent method. Another benefit of using DiffEqFlux is that we get automatic differentiation for "free", which is obviously required by SGD.

Unfortunately, we are getting weird errors deep in the AD code and we are having a hard time trying to overcome them. And, at this point, I really don't understand if the problem is in my usage of DiffEqFlux or if there is some bug in the library (or some of its dependencies).

I provided some sample code below for you to be able to reproduce the errors. When I run the Optim version (run_optimization_optim(16)) everything is fine, but I get "MethodError: no method matching fast_materialize(::Vector{Float32})" deep into Zygote code when running the DiffEqFlux version (run_optimization_flux()). I also tried using different sensitivity algorithms, but then I get "MethodError: no method matching push!(::DataStructures.BinaryMinHeap{Float32})".

Maybe you could provide some insight on what might be the issue here. Thanks!

using LinearAlgebra
using DifferentialEquations, DiffEqFlux, DiffEqSensitivity
using Flux, Flux.Data
using Optim

pi_32 = convert(Float32, π)
P = 2 * pi_32
T = 1.0f0

u₀(x::Float32) = sin(2 * pi_32 / P * x)
du₀(x::Float32) = 2 * pi_32 / P * cos(2 * pi_32 / P * x)
h = u₀
dh = du₀

# %% Sample points
using Random
Random.seed!(42)

N1 = 1000
t1 = convert.(Float32, rand(N1) .* T)
x1 = convert.(Float32, rand(N1) .* P)

# %% Model definition
function create_base_model()
    return Chain(Dense(1, 5, σ), Dense(5, 1))
end
function create_destructured_base_model()
    base_model = create_base_model()
    p, re = Flux.destructure(base_model)
    return (p, re)
end
function nn_basis_functions(model::Chain)
    f(z) = model([z])[1]
    f_z(z) = gradient(f, z)[1]
    return (f, f_z)
end

# %% Compute z function given x value (solve ODE)
function compute_z(x::Float32, f, f_z; saveat = [])
    function node_system!(dz, z, _, _)
        z1, z2 = z
        dz[1] = f(z1)
        dz[2] = f_z(z1) * z2
    end
    z0 = [x, 1.0f0]
    tspan = (0.0f0, T)
    prob = ODEProblem(node_system!, z0, tspan; saveat = saveat)

    # return solve(prob; sensealg=ZygoteAdjoint()) # MethodError: no method matching push!(::DataStructures.BinaryMinHeap{Float32})
    # return solve(prob; sensealg=SensitivityADPassThrough()) # MethodError: no method matching push!(::DataStructures.BinaryMinHeap{Float32})
    return solve(prob) # MethodError: no method matching fast_materialize(::Vector{Float32})
end

# %% Loss functions
function loss_1(S, f, f_z)
    rows = size(S, 1)
    loss = 0.0f0
    for i = 1:rows
        t = S[i, 1]
        x = S[i, 2]
        z = compute_z(x, f, f_z; saveat = [t])
        z1, z2 = z(t)
        u_t = dh(z1) * f(z1)
        u_x = dh(z1) * z2
        loss += (u_t + u_x)^2
    end
    return loss / rows
end
function loss_optim(re, θ)
    model = re(θ)
    f, f_z = nn_basis_functions(model)
    S1 = hcat(t1, x1)
    return loss_1(S1, f, f_z)
end
function loss_flux(model, t1, x1)
    f, f_z = nn_basis_functions(model)
    b1 = hcat(t1, x1)
    return loss_1(b1, f, f_z)
end

# %% Optimization loops
function run_optimization_optim(dimensions)
    options = Optim.Options(show_trace = true, show_every = 10, iterations = 10_000, time_limit = 15 * 60)
    p, re = create_destructured_base_model()
    θ_guess = convert.(Float32, randn(dimensions))
    return optimize(θ -> loss_optim(re, θ), θ_guess, NelderMead(), options)
end
function run_optimization_flux()
    model = create_base_model()
    data_loader = DataLoader((t1, x1), batchsize = 10, shuffle = true)
    loss(t1, x1) = loss_flux(model, t1, x1)
    callback = Flux.throttle(() -> @show(loss(t1, x1)), 10)  # anonymous function (->), printed at most every 10 seconds
    Flux.train!(loss, params(model), data_loader, ADAM(0.05); cb = callback)
end

run_optimization_optim(16)
run_optimization_flux()

ChrisRackauckas commented 3 years ago

A few steps here. First of all, since you're doing saveat, the solution already only has the value at t, so you might as well index it:

function loss_1(S, f, f_z)
    rows = size(S, 1)
    loss = 0.0f0
    for i = 1:rows
        t = S[i, 1]
        x = S[i, 2]
        z = compute_z(x, f, f_z; saveat = [t])
        z1,z2 = z[1][1],z[1][2]
        u_t = dh(z1) * f(z1)
        u_x = dh(z1) * z2
        loss += (u_t + u_x)^2
    end
    return loss / rows
end

That solves your issue. However, something else tricky comes up. You're now trying to differentiate w.r.t. parameters and initial condition, but there are not parameters. This seems to hit an untested edge case, but it was easy to fix:

https://github.com/SciML/DiffEqSensitivity.jl/pull/425

With that PR now it all works pretty fast. But you can make it even faster by forcing forward mode.
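
On the indexing suggestion above: a minimal sketch of why the first saved state is already the value at t when `saveat = [t]`. The toy system `decay!` and its values are hypothetical and only stand in for the real model; `sol.u[1]` is used here instead of `z[1]` so the example does not depend on how the solution object's own indexing behaves in a given package version.

```julia
using OrdinaryDiffEq

# Hypothetical toy system for illustration: both states decay exponentially.
decay!(du, u, p, t) = (du .= .-u)

u0 = Float32[1.0, 2.0]
tspan = (0.0f0, 1.0f0)
prob = ODEProblem(decay!, u0, tspan)

t = 0.5f0
sol = solve(prob, Tsit5(); saveat = [t])

# Only the requested time point is stored, so the first (and only)
# saved state is the state at time t; no interpolation call is needed.
@show sol.t
z1, z2 = sol.u[1]
```

Because nothing besides `t` is saved, there is a single entry to read, which is what the rewritten `loss_1` relies on.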

adolfocorreia commented 3 years ago

First of all, thank you very much for the prompt response! As you suggested, I'm now indexing z instead of evaluating z(t). For this specific case that makes more sense, since I only need z at this particular point t. With the new DiffEqSensitivity release, my code is now running fine! Feel free to close this issue if you wish.

If you don't mind, though, I'd like to clarify some of your comments:

  1. As I understand it, I'm only differentiating w.r.t. the neural net parameters in the optimization loop (params(model)), and the derivatives w.r.t. the initial conditions are not obvious to me. On the other hand, in the loss function, I do take a gradient of the model w.r.t. z which, in turn, depends on the initial condition (t,x). Is this the derivative w.r.t. the initial conditions you alluded to?

  2. I don't understand what you mean by "there are not parameters". Did you mean "these are not parameters"?

  3. How would I go about forcing forward mode? Using ForwardSensitivity or ForwardDiffSensitivity? If so, check my last comment below.

  4. I'm not sure if it makes sense (I don't fully understand the adjoint method), but you might want to run the code above (with the change you suggested) with different sensitivity algorithms as indicated below, since you might uncover some other "tricky untested edge cases". Only in the first three cases I was able to run the code without errors:

solve(prob)                                       # OK
solve(prob; sensealg=BacksolveAdjoint())          # OK
solve(prob; sensealg=QuadratureAdjoint())         # OK
solve(prob; sensealg=ZygoteAdjoint())             # MethodError: no method matching push!(::DataStructures.BinaryMinHeap{Float32})
solve(prob; sensealg=SensitivityADPassThrough())  # MethodError: no method matching push!(::DataStructures.BinaryMinHeap{Float32})
solve(prob; sensealg=ForwardSensitivity())        # MethodError: no method matching length(::SciMLBase.NullParameters)
solve(prob; sensealg=ForwardDiffSensitivity())    # MethodError: no method matching seed_duals(::SciMLBase.NullParameters,
solve(prob; sensealg=InterpolatingAdjoint())      # BoundsError: attempt to access 1-element Vector{Float32} at index [0]
solve(prob; sensealg=ReverseDiffAdjoint())        # MethodError: no method matching similar(::SciMLBase.NullParameters)
solve(prob; sensealg=TrackerAdjoint())            # MethodError: no method matching param(::SciMLBase.NullParameters)

ChrisRackauckas commented 3 years ago

I'm going to move this over to DiffEqSensitivity as an issue about that specific point (4).

> As I understand it, I'm only differentiating w.r.t. the neural net parameters in the optimization loop (params(model)), and the derivatives w.r.t. the initial conditions are not obvious to me. On the other hand, in the loss function, I do take a gradient of the model w.r.t. z which, in turn, depends on the initial condition (t,x). Is this the derivative w.r.t. the initial conditions you alluded to?

You can do both parameters and initial condition. It'll work it out automatically. Just do stuff like:

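# Assumes `prob`, `alg`, `data`, and `n` (the number of states) are defined in the enclosing scope.
function f(theta)
  u0 = theta[1:n]        # first n entries: initial condition
  p = theta[n+1:end]     # remaining entries: parameters
  _prob = remake(prob, u0 = u0, p = p)
  sol = solve(_prob, alg)
  sum(abs2, Array(sol) .- data)
end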
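
A self-contained sketch of the same pattern, under stated assumptions: the toy right-hand side `rhs`, the synthetic `data`, the save times `ts`, and the choice of `Tsit5()` are all hypothetical and chosen only for illustration. The package is loaded under its current name, SciMLSensitivity (DiffEqSensitivity at the time of this thread).

```julia
using OrdinaryDiffEq, SciMLSensitivity, Zygote

# Hypothetical toy problem: du/dt = -p[1] * u, two states and one parameter.
rhs(u, p, t) = -p[1] .* u

n = 2                      # number of states
ts = 0.0:0.25:1.0          # time points at which the solution is saved
prob = ODEProblem(rhs, [1.0, 2.0], (0.0, 1.0), [0.5])

# Synthetic "data" generated from known values, for illustration only.
data = Array(solve(prob, Tsit5(); saveat = ts))

function loss(theta)
    u0 = theta[1:n]        # first n entries: initial condition
    p = theta[n+1:end]     # remaining entries: parameters
    _prob = remake(prob; u0 = u0, p = p)
    sol = solve(_prob, Tsit5(); saveat = ts)
    return sum(abs2, Array(sol) .- data)
end

theta = [1.2, 1.8, 0.7]
# One gradient vector covering both the initial condition and the parameter.
grad = Zygote.gradient(loss, theta)[1]
```

The first `n` entries of `grad` are the sensitivities w.r.t. the initial condition and the rest are w.r.t. the parameters, so a single optimizer can update both.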

> I don't understand what you mean by "there are not parameters". Did you mean "these are not parameters"?

Your ODE has no parameters. That's fine; making the methods robust to having none is an issue on our end. Most of the time people do have parameters, so this is a less tested area, though it is tested for many cases:

https://github.com/SciML/DiffEqSensitivity.jl/blob/master/test/null_parameters.jl
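
For context on the `NullParameters` type named in the error messages earlier in the thread, a small sketch of where it comes from. The toy function `growth` is hypothetical, and SciMLBase is assumed to be installed so it can be loaded directly.

```julia
using OrdinaryDiffEq
using SciMLBase

# Hypothetical toy right-hand side that never touches `p`.
growth(u, p, t) = 0.1f0 .* u

u0 = Float32[1.0]
tspan = (0.0f0, 1.0f0)

# No parameter argument: the problem stores a NullParameters placeholder.
prob_nop = ODEProblem(growth, u0, tspan)
@show prob_nop.p isa SciMLBase.NullParameters

# With an explicit parameter vector, the problem stores that vector instead.
prob_p = ODEProblem(growth, u0, tspan, Float32[0.1])
@show prob_p.p
```

A sensitivity method that calls something like `length` or `similar` on the parameter object will only work in the first case if it has a dispatch for the placeholder, which matches the `MethodError`s listed above.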

> How would I go about forcing forward mode? Using ForwardSensitivity or ForwardDiffSensitivity? If so, check my last comment below.

It'll do so automatically now, as of this week. After seeing your code, I realized we should just make a smart polyalgorithm solve your problem so that you don't have to do any work here.

> I'm not sure if it makes sense (I don't fully understand the adjoint method), but you might want to run the code above (with the change you suggested) with different sensitivity algorithms as indicated below, since you might uncover some other "tricky untested edge cases". Only in the first three cases I was able to run the code without errors

Yes. The thing is that only BacksolveAdjoint, ForwardDiffSensitivity, and QuadratureAdjoint make sense when you have no parameters: InterpolatingAdjoint is strictly worse than QuadratureAdjoint for that case. So the default algorithm knows to avoid it, and I guess we never tested what happens if a user specifically asks for it in this case. It turns out that it fails, so we should have better behavior here (likely: automatically switch to QuadratureAdjoint). For ForwardSensitivity, it's a known issue we need to work on. ForwardDiffSensitivity I just fixed this week. ReverseDiffAdjoint and TrackerAdjoint were unknown issues we should clean up, and those would be quick to fix.

adolfocorreia commented 3 years ago

Once again, thanks for your feedback! Things are clearer to me now. I'm glad I could help you identify these unknown problematic scenarios and improve the library. Keep up the good work!