SciML / SciMLSensitivity.jl

A component of the DiffEq ecosystem for enabling sensitivity analysis for scientific machine learning (SciML). Optimize-then-discretize, discretize-then-optimize, adjoint methods, and more for ODEs, SDEs, DDEs, DAEs, etc.
https://docs.sciml.ai/SciMLSensitivity/stable/

Overhead in 'Training a Neural Ordinary Differential Equation with Mini-Batching' example #936

Open wallscheid opened 1 year ago

wallscheid commented 1 year ago

In the provided example on mini-batching (https://docs.sciml.ai/SciMLSensitivity/dev/examples/neural_ode/minibatch/), the ODE function, the prediction function, the loss function, and the initial condition are defined as:

function dudt_(u, p, t)
    re(p)(u)[1] .* u
end

function predict_adjoint(time_batch)
    _prob = remake(prob, u0 = u0, p = θ)
    Array(solve(_prob, Tsit5(), saveat = time_batch))
end

function loss_adjoint(batch, time_batch)
    pred = predict_adjoint(time_batch)
    sum(abs2, batch - pred)#, pred
end

u0 = Float32[200.0]
datasize = 30
tspan = (0.0f0, 3.0f0)

Because the data is split into multiple mini-batches, the initial condition and the starting time of each mini-batch are different. Compare:

for (x, y) in train_loader
    @show x
    @show y
end 

x = Float32[200.0 199.1077 198.22334 197.3469 196.47826 195.61739 194.76418 193.9186 193.08055 192.25]
y = 0.0f0:0.10344828f0:0.9310345f0
x = Float32[191.42683 190.61102 189.8025 189.00119 188.20702 187.41994 186.6399 185.8668 185.1006 184.34125]
y = 1.0344827f0:0.10344828f0:1.9655173f0
x = Float32[183.58865 182.8428 182.10359 181.37097 180.64491 179.9253 179.21213 178.50533 177.8048 177.11053]
y = 2.0689654f0:0.10344828f0:3.0f0

If I got it right, the ODE solver inside predict_adjoint() always starts from t=0 and has to compute intermediate solutions until the actual mini-batch timeframe begins. For example, for the third mini-batch, which starts at roughly 2 seconds, the ODE solver needs to process everything from t=0...2 seconds as overhead, since the actual loss is calculated only for the timeframe t=2...3 seconds.

Wouldn't it be much more efficient to initialize predict_adjoint() for each mini-batch from the initial condition and start time of that specific mini-batch and, therefore, avoid the ODE solver's overhead?
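
A minimal sketch of what this could look like, not the documented example itself. The names `decay` and `predict_from_batch_start` are hypothetical; `decay` is a one-parameter stand-in for the neural-network right-hand side `dudt_`, and the data is generated synthetically so that the snippet runs on its own. It assumes that the first sample of each mini-batch is an acceptable initial condition for that mini-batch.

```julia
using OrdinaryDiffEq

# Hypothetical stand-in for the neural-network right-hand side `dudt_`.
decay(u, p, t) = -p[1] .* u

u0 = Float32[200.0]
datasize = 30
tspan = (0.0f0, 3.0f0)
θ = Float32[0.05]
prob = ODEProblem(decay, u0, tspan, θ)

# Synthetic data (assumption): a solution with a slightly different parameter.
t = range(tspan[1], tspan[2], length = datasize)
data = Array(solve(remake(prob, p = Float32[0.04]), Tsit5(), saveat = t))  # 1×30

# Start each solve at the first time point of the mini-batch and use the
# first sample of the mini-batch as initial condition, so the solver only
# integrates over the time window of that mini-batch.
function predict_from_batch_start(batch, time_batch)
    _prob = remake(prob;
        u0 = batch[:, 1],
        tspan = (first(time_batch), last(time_batch)),
        p = θ)
    Array(solve(_prob, Tsit5(), saveat = time_batch))
end

# Split the 30 samples into three mini-batches of 10, as in the output above.
for idx in Iterators.partition(1:datasize, 10)
    batch, time_batch = data[:, idx], t[idx]
    pred = predict_from_batch_start(batch, time_batch)
    loss = sum(abs2, batch - pred)
    @show first(time_batch) loss
end
```

Note that this changes the training objective slightly: the first point of each mini-batch now matches the data by construction, and any measurement noise in that sample ends up in the initial condition.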

ChrisRackauckas commented 1 year ago

> Wouldn't it be much more efficient to initialize the predict() of each mini-batch based on the initial condition of that specific mini-batch and, therefore, save the ODE solver's overhead?

Yes. It's just a demonstration. Maybe we can make a better example.