# Neural ODE layer: the wrapped Lux layer `model` defines the right-hand side
# du/dt = model(u, p), which is integrated over `tspan` when the layer is called
# as `n(x, p, st)`.
#
# Fields:
# - `model`:  Lux layer used as the ODE right-hand side.
# - `tspan`:  time span forwarded to `ODEProblem`.
# - `args`:   extra positional arguments splatted into `solve` (presumably the
#             solver algorithm — confirm against callers).
# - `kwargs`: extra keyword arguments splatted into `solve`.
@concrete struct NeuralODE{M <: AbstractExplicitLayer} <: NeuralDELayer
model::M
tspan
args
kwargs
end
# Build a `NeuralODE` from any model plus the arguments for `solve`.
#
# A model that is not already an `AbstractExplicitLayer` is first converted with
# `Lux.transform`. Extra positional arguments are kept as the tuple `args` and keyword
# arguments as `kwargs`; both are later splatted into `solve`.
function NeuralODE(model, tspan, args...; kwargs...)
    lux_model = model isa AbstractExplicitLayer ? model : Lux.transform(model)
    return NeuralODE(lux_model, tspan, args, kwargs)
end
# Integrate the neural ODE starting from the initial condition `x` with parameters `p`.
#
# `st` is wrapped in a `StatefulLuxLayer` so the right-hand-side closure can call the
# model as `f(u, p)`. Returns a tuple of the `solve` result and the layer state as read
# back from the wrapper after solving.
function (n::NeuralODE)(x, p, st)
    stateful = StatefulLuxLayer(n.model, nothing, st)
    rhs(u, θ, t) = stateful(u, θ)

    odefunc = ODEFunction{false}(rhs; tgrad = basic_tgrad)
    problem = ODEProblem{false}(odefunc, x, n.tspan, p)

    # The default sensitivity algorithm comes first so user-supplied
    # `n.kwargs` (splatted after it) can override it.
    default_sensealg = InterpolatingAdjoint(; autojacvec = ZygoteVJP())
    sol = solve(problem, n.args...; sensealg = default_sensealg, n.kwargs...)

    # Read the state only after solving, since the wrapper is called during `solve`.
    return sol, stateful.st
end
With `@compact`, this would become the following (splatting positional `args...` won't work, but splatting keyword `kwargs...` is fine):
# Proposed `@compact`-based replacement for the `NeuralODE` struct (design sketch).
#
# Returns a Lux `@compact` layer rather than a `NeuralODE` instance, so downstream code
# can no longer dispatch on `::NeuralODE`.
#
# Arguments:
# - `model`:  the right-hand-side network; a non-`AbstractExplicitLayer` model is
#             converted with `FromFluxAdaptor()`.
# - `tspan`:  integration time span passed to `ODEProblem`.
# - `solver`: algorithm passed positionally to `solve` (defaults to `nothing`).
# - `kwargs...`: forwarded into `@compact` and to `solve`.
#
# NOTE(review): this assumes `@compact` hands `(x, p)` to the do-block and makes the
# captured names (`model`, `tspan`, `solver`, `sensealg`) visible inside it. Confirm
# against the Lux `@compact` version this targets.
function NeuralODE(model, tspan, solver = nothing; kwargs...)
    !(model isa AbstractExplicitLayer) && (model = FromFluxAdaptor()(model))
    return @compact(; model, tspan, solver,
        sensealg = InterpolatingAdjoint(; autojacvec = ZygoteVJP()), kwargs...) do x, p
        dudt(u, p, t) = model(u, p)
        # Use the captured `tspan` directly. No `n` is in scope here, so the original
        # `n.tspan` would throw an UndefVarError.
        prob = ODEProblem(ODEFunction{false}(dudt; tgrad = basic_tgrad), x, tspan, p.model)
        # `kwargs...` follows `sensealg`; for repeated keywords the rightmost one wins,
        # so a user-supplied `sensealg` overrides the default in this call.
        @return solve(prob, solver; sensealg, kwargs...)
    end
end
This also handles all the boxing issues automatically (the reason we had to add `StatefulLuxLayer`).
I'm not sure whether this counts as breaking: after this change, end users won't be able to define methods like `foo(::NeuralODE)`. However, we don't guarantee that anyway (see the NonlinearSolve.jl precedent, where we turned algorithms into functions instead of types).
Current version
This would become the `@compact` version above (splatting positional `args...` won't work, but splatting keyword `kwargs...` is fine).

This also handles all the boxing issues automatically (the reason we had to add `StatefulLuxLayer`).

I'm not sure whether this counts as breaking: after this change, end users won't be able to define methods like `foo(::NeuralODE)`. However, we don't guarantee that anyway (see the NonlinearSolve.jl precedent, where we turned algorithms into functions instead of types).

This needs https://github.com/LuxDL/Lux.jl/pull/584, which will be released later today.