SciML / SciMLSensitivity.jl

A component of the DiffEq ecosystem for enabling sensitivity analysis for scientific machine learning (SciML). Optimize-then-discretize, discretize-then-optimize, adjoint methods, and more for ODEs, SDEs, DDEs, DAEs, etc.
https://docs.sciml.ai/SciMLSensitivity/stable/

Fast (reverse-mode AD) hybrid ODE sensitivities - a collection of requirements (with MWE) #863

Open ThummeTo opened 1 year ago

ThummeTo commented 1 year ago

Dear @frankschae,

as promised, I tried to summarize the requirements (in the form of MWEs) that are needed to train over arbitrary FMUs using FMI.jl (or, probably soon, FMISensitivity.jl). Both examples are very simple NeuralODEs; we don't need to train over FMUs for the MWEs.

The requirements / MWEs are (in order of priority, most important first):

Both MWEs run into the problem that the determined gradient contains NaNs, which would lead to NaNs in the parameters and later NaNs during ANN inference.

Some additional info:

Please don't hesitate to involve me if there is anything I can do to support. For example, we could open a PR with tests based on the MWEs and/or examples for the documentation. If anything is unclear, I can post more information or code.

If we get this working, it would be a significant improvement for training ML models that include FMUs (and hybrid ODEs in general).

Thank you very much & best regards, ThummeTo

PS: "Unfortunately" I am on vacation for the next three weeks :-)

--------------- MWE ---------------

using SciMLSensitivity
using Flux
using DifferentialEquations
using DiffEqCallbacks
import SciMLSensitivity.SciMLBase: RightRootFind
import SciMLSensitivity: ReverseDiff, ForwardDiff, FakeIntegrator
import Random

Random.seed!(1234)

net = Chain(Dense(2, 16, tanh),
            Dense(16, 2, tanh))

x0 = [1.0f0, 1.0f0]
tspan = (0.0f0, 3.0f0)
saveat = tspan[1]:0.1:tspan[end]
data = sin.(saveat)

params, re = Flux.destructure(net)
initial_params = copy(params)

function fx(dx, x, p, t)
    dx[:] = re(p)(x)
end

ff = ODEFunction{true}(fx, tgrad=nothing)
prob = ODEProblem{true}(ff, x0, tspan, params)

function condition(out, x, t, integrator)
    out[1] = cos(x[1])
    out[2] = sin(x[1])
end

function affect!(integrator, idx)
    u_new = x0
    integrator.u .= u_new
end

eventCb = VectorContinuousCallback(condition,
                                   affect!,
                                   2;
                                   rootfind=RightRootFind, save_positions=(false, false))

function loss(p; sensealg=nothing)
    sol = solve(prob; p=p, callback=CallbackSet(eventCb), sensealg=sensealg, saveat=saveat)

    # ReverseDiff over the solution returns an Array instead of an ODESolution object!
    vals = sol[1,:] 

    solsize = size(sol)
    if solsize != (length(x0), length(saveat))
        @error "Step failed with solsize = $(solsize)!"
        return Inf
    end

    return Flux.Losses.mse(data, vals)
end

# loss function for Discretize-then-Optimize (DtO) and Optimize-then-Discretize (OtD)
loss_DtO = (p) -> loss(p; sensealg=ReverseDiffAdjoint())
loss_OtD = (p) -> loss(p; sensealg=InterpolatingAdjoint(;autojacvec=ReverseDiffVJP()))

# check simple gradients for both loss functions
for loss in (loss_DtO, loss_OtD)
    grad_fd = ForwardDiff.gradient(loss, params, 
        ForwardDiff.GradientConfig(loss, params, ForwardDiff.Chunk{32}()))
    grad_rd = ReverseDiff.gradient(loss, params)          

    # small deviations are ok, so this is good for now!
    @info "$(loss) max deviation between ForwardDiff and ReverseDiff: $(maximum(abs.(grad_fd .- grad_rd)))"
end

#### 
optim = Adam(1e-5)

# do some training steps
for loss in (loss_DtO, loss_OtD)

    # reset params (so every sensealg has the same "chance")
    params[:] = initial_params[:]

    # a very simple custom train loop, that checks the gradient before applying it
    for i in 1:500

        # get the gradient
        g = ReverseDiff.gradient(loss, params)

        # check if NaNs are in there
        if any(isnan.(g)) 
            @error "\tGradient NaN at step $(i) for loss $(loss), exiting!"
            break
        end

        # apply optimization step, update parameters
        step = Flux.Optimise.apply!(optim, params, g)
        params .-= step
    end
end

----------- MWE OUTPUT -------------

[ Info: #13 max deviation between ForwardDiff and ReverseDiff: 5.048493233239526e-6
[ Info: #15 max deviation between ForwardDiff and ReverseDiff: 3.90012033929521e-6

┌ Error:        Gradient NaN at step 16 for loss #13, exiting!
└ @ Main c:\Users\...:90
┌ Error:        Gradient NaN at step 13 for loss #15, exiting!
└ @ Main c:\Users\....jl:90

ChrisRackauckas commented 1 year ago

I think this is the kind of thing we just want to be working on getting Enzyme ready for.

sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP())

Why not sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true)) here?

Note that without true, the in-place form

function fx(dx, x, p, t)
    dx[:] = re(p)(x)
end

will be slower than the out-of-place form fx(x, p, t) = re(p)(x), of course, because of the scalarizing.
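
To make the two function forms concrete, here is a minimal sketch. `model` is a hypothetical stand-in for `re(p)(x)` so that the snippet runs without Flux; it is not part of the MWE, and the snippet only shows that both forms describe the same right-hand side, not their relative speed under ReverseDiff.

```julia
# Hypothetical stand-in for `re(p)(x)` from the MWE: a tiny 2x2 tanh layer
model(x, p) = tanh.(reshape(p, 2, 2) * x)

# In-place form (as in the MWE): the result is written into `dx`
function fx_inplace!(dx, x, p, t)
    dx .= model(x, p)
    return nothing
end

# Out-of-place form: the derivative is returned as a new array
fx_oop(x, p, t) = model(x, p)

x = [1.0, 1.0]
p = [0.1, 0.2, 0.3, 0.4]
dx = similar(x)
fx_inplace!(dx, x, p, 0.0)

# Both forms describe the same right-hand side
same_rhs = dx ≈ fx_oop(x, p, 0.0)
```

With the out-of-place form, the problem would be built with `ODEProblem{false}(fx_oop, x0, tspan, params)` rather than `ODEProblem{true}(...)`.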

Also, one major improvement would be to use Lux instead of Flux or, for small neural networks, SimpleChains (with static arrays).

ThummeTo commented 1 year ago

Thanks for the reply! Yep, ReverseDiffVJP(true) is a good point. To be honest, I wasn't sure whether I was allowed to use it, because of the "no-branching" requirement for pre-compiling tapes.
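
The "no-branching" concern can be illustrated with ReverseDiff directly. The sketch below uses a hypothetical function `f` with a value-dependent branch (it is not taken from the MWE) and shows that a compiled tape replays the branch that was taken while recording, even when the new input would take the other one.

```julia
import ReverseDiff

# Hypothetical function with a value-dependent branch (for illustration only)
f(x) = x[1] > 0 ? x[1] * x[1] + x[2] * x[2] : x[1] + x[2]

# Record the tape at a point on the `x[1] > 0` branch, then compile it
x_rec = [1.0, 2.0]
tape = ReverseDiff.compile(ReverseDiff.GradientTape(f, x_rec))

# Replaying at the recording point gives the correct gradient
g_same = ReverseDiff.gradient!(similar(x_rec), tape, x_rec)

# Replaying at a point on the other branch still follows the recorded branch
x_other = [-1.0, 2.0]
g_replayed = ReverseDiff.gradient!(similar(x_other), tape, x_other)

# Re-recording at the new point follows the branch that is actually taken
g_fresh = ReverseDiff.gradient(f, x_other)
```

Here `g_replayed` and `g_fresh` differ, which is why a right-hand side whose control flow depends on the state or parameters is not a safe candidate for a compiled tape.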

Migration to Lux is also on the to-do list :-)

And I am super-curious what progress Enzyme is making (after the big steps in the last months/weeks). I will keep checking for that.

ThummeTo commented 11 months ago

Very good news: DtO works in the current release(s) if you specify a solver by hand. Sensitivities are determined correctly and without numerical instabilities/NaNs. Thank you very much @ChrisRackauckas and @frankschae. However, the provided MWE as it is (without a solver specified) still fails because of the linked DiffEqBase issue.
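
As a minimal sketch of what "specifying a solver by hand" looks like, the snippet below passes an explicit algorithm as the second argument to `solve`. The toy problem (`growth!`, `condition`, `affect!`) is hypothetical and much simpler than the MWE, and `Tsit5()` is an assumption: the comment above does not say which solver was used.

```julia
using OrdinaryDiffEq

# Hypothetical toy problem (not the MWE): exponential growth that is reset at u = 2
growth!(du, u, p, t) = (du[1] = p[1] * u[1]; nothing)
condition(u, t, integrator) = u[1] - 2.0          # zero crossing at u = 2
affect!(integrator) = (integrator.u[1] = 1.0)     # reset the state to 1
cb = ContinuousCallback(condition, affect!)

prob = ODEProblem(growth!, [1.0], (0.0, 3.0), [1.0])

# Solver chosen by hand (Tsit5 is an assumption) instead of relying on the default
sol = solve(prob, Tsit5(); callback=cb, saveat=0.1)
```

In the MWE, the corresponding change would be to call `solve(prob, Tsit5(); p=p, callback=CallbackSet(eventCb), sensealg=sensealg, saveat=saveat)` inside `loss`, again with the choice of solver being an assumption.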

Current progress:

Single event at the same time instant:

Multiple events (multiple zero-crossing event conditions) at the same time instant:

So the only thing remaining is the adjoint sensitivity problem for multiple zero-crossing event conditions. In my application in particular, this is not that important, because solving FMUs backwards in time is not supported by design and causes additional overhead ...

So again, thank you very much!

PS: Are there plans to implement this last feature in the near future? If not, we could close this issue from my side, and I can offer to open another issue to keep track of that last feature (in case someone searches for it).

ChrisRackauckas commented 11 months ago

We plan to just keep going until everything is supported.