SciML / Optimization.jl

Mathematical Optimization in Julia. Local, global, gradient-based and derivative-free. Linear, Quadratic, Convex, Mixed-Integer, and Nonlinear Optimization in one simple, fast, and differentiable interface.
https://docs.sciml.ai/Optimization/stable/
MIT License

Introduce documentation on how to use mini-batches with Lux #554

Open: ghost opened this issue 1 year ago

ghost commented 1 year ago

I followed the documentation at https://docs.sciml.ai/Optimization/stable/tutorials/minibatch/#Data-Iterators-and-Minibatching and tried to replace the Flux library with Lux as follows:


using Lux, Optimization, OptimizationOptimisers, OrdinaryDiffEq, SciMLSensitivity

using StableRNGs
import MLUtils: DataLoader

function newtons_cooling(du, u, p, t)
    temp = u[1]
    k, temp_m = p
    du[1] = dT = -k * (temp - temp_m)
end

function true_sol(du, u, p, t)
    true_p = [log(2) / 8.0, 100.0]
    newtons_cooling(du, u, true_p, t)
end

rng = StableRNG(1111)

ann = Lux.Chain(Dense(1, 8, tanh), Dense(8, 1, tanh))
pp, st = Lux.setup(rng, ann)

function dudt_(u, p, t)
    ann(u,p,st)[1] .* u
end

callback = function (p, l) #callback function to observe training
    display(l)
    return false
end

u0 = Float32[200.0]
datasize = 30
tspan = (0.0f0, 1.5f0)

t = range(tspan[1], tspan[2], length = datasize)
true_prob = ODEProblem(true_sol, u0, tspan)
ode_data = Array(solve(true_prob, Tsit5(), saveat = t))

prob = ODEProblem{false}(dudt_, u0, tspan, pp)

function predict_adjoint(fullp, time_batch)
    Array(solve(prob, Tsit5(), p = fullp, saveat = time_batch))
end

function loss_adjoint(fullp, batch, time_batch)
    pred = predict_adjoint(fullp, time_batch)
    sum(abs2, batch .- pred)
end

k = 10
# Pass the data for the batches as separate vectors wrapped in a tuple
train_loader = DataLoader((ode_data, t), batchsize = k)

numEpochs = 300
l1 = loss_adjoint(pp, train_loader.data[1], train_loader.data[2])[1]

optfun = OptimizationFunction((θ, p, batch, time_batch) -> loss_adjoint(θ, batch,
        time_batch),
    Optimization.AutoZygote())
optprob = OptimizationProblem(optfun, pp)
using IterTools: ncycle
res1 = Optimization.solve(optprob, Optimisers.ADAM(0.05), ncycle(train_loader, numEpochs),
    callback = callback)

However, this yields an error:

ERROR: MethodError: no method matching copy(::NamedTuple{(:layer_1, :layer_2), Tuple{NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}}})

Closest candidates are:
  copy(::Union{DiffEqNoiseProcess.BoxWedgeTail, DiffEqNoiseProcess.NoiseApproximation, DiffEqNoiseProcess.NoiseGrid, DiffEqNoiseProcess.NoiseWrapper, DiffEqNoiseProcess.VirtualBrownianTree})
   @ DiffEqNoiseProcess ~/.julia/packages/DiffEqNoiseProcess/VQe6Y/src/copy_noise_types.jl:55
  copy(::Random123.Threefry4x{T, R}) where {T, R}
   @ Random123 ~/.julia/packages/Random123/u5oEp/src/threefry.jl:266
  copy(::Zygote.Buffer)
   @ Zygote ~/.julia/packages/Zygote/JeHtr/src/tools/buffer.jl:64

Would you consider adding documentation on how to use mini-batches with Lux? Lux is a library used for universal differential equations, and mini-batching would be a useful way to train UODEs across different initial conditions and parameters.

ChrisRackauckas commented 1 year ago

using ComponentArrays
function predict_adjoint(fullp, time_batch)
    Array(solve(prob, Tsit5(), p = ComponentArray(fullp), saveat = time_batch))
end
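Why ComponentArray helps: Lux.setup returns the parameters as a nested NamedTuple, and the MethodError above shows that no copy method exists for it. ComponentArray from ComponentArrays.jl stores the same values in one flat AbstractVector while keeping the layer_1.weight style access that Lux uses to read parameters. The sketch below assumes ComponentArrays.jl is installed and uses hypothetical all-ones/all-zeros parameters shaped like the 1-8-1 chain from the issue; it only illustrates the container, not the training loop.

```julia
using ComponentArrays

# Hypothetical parameters shaped like the Lux Chain(Dense(1, 8), Dense(8, 1))
# in the error message: a nested NamedTuple of Float32 weight/bias matrices.
nt = (layer_1 = (weight = ones(Float32, 8, 1), bias = zeros(Float32, 8, 1)),
      layer_2 = (weight = ones(Float32, 1, 8), bias = zeros(Float32, 1, 1)))

# Flatten into a single AbstractVector{Float32} that keeps the named structure.
ca = ComponentArray(nt)
println(length(ca))                 # 8 + 8 + 8 + 1 = 25 entries

# copy and in-place broadcasting work, as an array-based optimiser expects.
ca2 = copy(ca)
ca2 .-= 0.5f0 .* ca2                # toy gradient-descent-style update

# Named, shaped access is preserved, which is what Lux relies on.
println(size(ca2.layer_1.weight))   # (8, 1)
println(ca.layer_1.weight[1])       # 1.0: the copy did not alias the original
```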

ghost commented 1 year ago

Thanks for the input!

However, I still get an error:

ERROR: MethodError: no method matching copy(::NamedTuple{(:layer_1, :layer_2), Tuple{NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}}})

Closest candidates are:
  copy(::Union{DiffEqNoiseProcess.BoxWedgeTail, DiffEqNoiseProcess.NoiseApproximation, DiffEqNoiseProcess.NoiseGrid, DiffEqNoiseProcess.NoiseWrapper, DiffEqNoiseProcess.VirtualBrownianTree})
   @ DiffEqNoiseProcess ~/.julia/packages/DiffEqNoiseProcess/VQe6Y/src/copy_noise_types.jl:55
  copy(::Random123.Threefry4x{T, R}) where {T, R}
   @ Random123 ~/.julia/packages/Random123/u5oEp/src/threefry.jl:266
  copy(::Zygote.Buffer)
   @ Zygote ~/.julia/packages/Zygote/JeHtr/src/tools/buffer.jl:64
  ...

Edit: I think I got it to work by using a different approach.

I left the predict_adjoint method as is:

function predict_adjoint(fullp, time_batch)
    Array(solve(prob, Tsit5(), p = fullp, saveat=time_batch))
end

And changed the definition of optprob to:

optprob = OptimizationProblem(optfun, ComponentArray(pp))

Now it runs.
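Putting the thread together, a consolidated version of the original script with this fix might look like the following. It is a sketch that assumes the Optimization.jl and OptimizationOptimisers API in use when this issue was filed: the data iterator is passed as the third argument of Optimization.solve, each batch tuple is splatted into the loss after θ and p, and the callback receives (p, l). Later releases changed these signatures, so check the current minibatch tutorial before reusing it. The substantive change is passing ComponentArray(pp) as the initial point of the OptimizationProblem; the name p0 is introduced here for readability. Wrapping the parameters only inside predict_adjoint leaves that initial point a NamedTuple, which is presumably why the copy error persisted with the first suggestion.

```julia
using Lux, Optimization, OptimizationOptimisers, OrdinaryDiffEq, SciMLSensitivity
using ComponentArrays, StableRNGs
import MLUtils: DataLoader
using IterTools: ncycle

# Ground-truth Newton's cooling law used to generate the training data.
function newtons_cooling(du, u, p, t)
    k, temp_m = p
    du[1] = -k * (u[1] - temp_m)
end
true_sol(du, u, p, t) = newtons_cooling(du, u, [log(2) / 8.0, 100.0], t)

rng = StableRNG(1111)
ann = Lux.Chain(Dense(1, 8, tanh), Dense(8, 1, tanh))
pp, st = Lux.setup(rng, ann)

# Out-of-place neural ODE right-hand side; a Lux model returns (output, state).
dudt_(u, p, t) = first(ann(u, p, st)) .* u

u0 = Float32[200.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
t = range(tspan[1], tspan[2], length = datasize)
ode_data = Array(solve(ODEProblem(true_sol, u0, tspan), Tsit5(), saveat = t))

# The fix: flatten the Lux NamedTuple into a ComponentArray (p0 is our name).
p0 = ComponentArray(pp)
prob = ODEProblem{false}(dudt_, u0, tspan, p0)

predict_adjoint(fullp, time_batch) =
    Array(solve(prob, Tsit5(), p = fullp, saveat = time_batch))

loss_adjoint(fullp, batch, time_batch) =
    sum(abs2, batch .- predict_adjoint(fullp, time_batch))

callback = function (p, l)  # observe training; return true to stop early
    display(l)
    return false
end

# Each batch is a tuple (data columns, matching time points).
train_loader = DataLoader((ode_data, t), batchsize = 10)
numEpochs = 300

optfun = OptimizationFunction(
    (θ, p, batch, time_batch) -> loss_adjoint(θ, batch, time_batch),
    Optimization.AutoZygote())
optprob = OptimizationProblem(optfun, p0)
res1 = Optimization.solve(optprob, Optimisers.ADAM(0.05),
    ncycle(train_loader, numEpochs), callback = callback)
```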

ChrisRackauckas commented 12 months ago

@Vaibhavdixit02 can you add this to the docs?