SciML / DiffEqFlux.jl

Pre-built implicit layer architectures with O(1) backprop, GPUs, and stiff+non-stiff DE solvers, demonstrating scientific machine learning (SciML) and physics-informed machine learning methods
https://docs.sciml.ai/DiffEqFlux/stable
MIT License
871 stars 157 forks source link

Start using the less verbose `Lux.@compact` API #917

Closed avik-pal closed 2 months ago

avik-pal commented 7 months ago

Current version

# Neural ODE layer: bundles an explicit-layer model with the time span and the
# arguments that the layer's call method forwards to `solve`.
@concrete struct NeuralODE{M <: AbstractExplicitLayer} <: NeuralDELayer
    # Network evaluated as the ODE right-hand side; must be an `AbstractExplicitLayer`.
    model::M
    # Time span handed to `ODEProblem` when the layer is called.
    tspan
    # Positional arguments splatted into `solve` (presumably the solver algorithm
    # and friends — TODO confirm against callers).
    args
    # Keyword arguments splatted into `solve` after the default `sensealg`.
    kwargs
end

# Convenience constructor: collects the trailing positional and keyword arguments
# and stores them on the layer. A `model` that is not already an
# `AbstractExplicitLayer` is first passed through `Lux.transform`.
function NeuralODE(model, tspan, args...; kwargs...)
    layer = model isa AbstractExplicitLayer ? model : Lux.transform(model)
    return NeuralODE(layer, tspan, args, kwargs)
end

# Run the layer: solve the ODE whose right-hand side is the wrapped model, starting
# from `x` with parameters `p`, and return `(solution, updated_state)`.
function (n::NeuralODE)(x, p, st)
    # Wrap the model together with its state so the right-hand side only needs
    # `(u, p)`; the wrapper's `st` field is read back after the solve.
    stateful = StatefulLuxLayer(n.model, nothing, st)

    rhs(u, ps, t) = stateful(u, ps)
    odefun = ODEFunction{false}(rhs; tgrad = basic_tgrad)
    problem = ODEProblem{false}(odefun, x, n.tspan, p)

    # The default `sensealg` comes first so that an entry in `n.kwargs` overrides it.
    default_sensealg = InterpolatingAdjoint(; autojacvec = ZygoteVJP())
    sol = solve(problem, n.args...; sensealg = default_sensealg, n.kwargs...)
    return (sol, stateful.st)
end

This would become the following (positional-argument splatting `args...` won't work, but keyword-argument splatting `kwargs...` is fine):

# Build a neural ODE layer with `@compact` instead of a dedicated struct.
#
# Arguments:
#   - `model`: network used as the ODE right-hand side; anything that is not already
#     an `AbstractExplicitLayer` is converted with `FromFluxAdaptor()`.
#   - `tspan`: time span passed to `ODEProblem`.
#   - `solver`: algorithm passed positionally to `solve` (defaults to `nothing`).
#   - `kwargs...`: stored on the layer and splatted into `solve` after `sensealg`,
#     so a user-supplied `sensealg` takes precedence over the default.
function NeuralODE(model, tspan, solver = nothing; kwargs...)
    !(model isa AbstractExplicitLayer) && (model = FromFluxAdaptor()(model))
    return @compact(;
        model, tspan, solver,
        sensealg = InterpolatingAdjoint(; autojacvec = ZygoteVJP()),
        kwargs...) do x, p
        dudt(u, p, t) = model(u, p)
        ff = ODEFunction{false}(dudt; tgrad = basic_tgrad)
        # Use the `tspan` captured by `@compact`; there is no `n` in this scope, so
        # the previous `n.tspan` would raise an `UndefVarError`.
        prob = ODEProblem(ff, x, tspan, p.model)
        @return solve(prob, solver; sensealg, kwargs...)
    end
end

Also this handles all the boxing issues automatically (the reason we had to add the StatefulLuxLayer)

Not sure if this is considered breaking. The end user won't be able to dispatch on the layer type, e.g. `foo(::NeuralODE)`, after this. But we don't guarantee that (consider the NonlinearSolve.jl precedent, where we turned algorithms into functions rather than types).

Needs https://github.com/LuxDL/Lux.jl/pull/584 which will be released later today