SciML / DiffEqFlux.jl

Pre-built implicit layer architectures with O(1) backprop, GPUs, and stiff+non-stiff DE solvers, demonstrating scientific machine learning (SciML) and physics-informed machine learning methods
https://docs.sciml.ai/DiffEqFlux/stable
MIT License

The FFJORD copy-pasteable code doesn't work #931

Closed: a1ix2 closed this issue 5 months ago

a1ix2 commented 5 months ago

The copy-pasteable FFJORD example in the docs doesn't quite work as-is.

I copy-pasted the code into a brand-new --temp environment (the only change is maxiters = 1) and got the following error:

using ComponentArrays, DiffEqFlux, OrdinaryDiffEq, Optimization, Distributions, Random,
      OptimizationOptimisers, OptimizationOptimJL

nn = Chain(Dense(1, 3, tanh), Dense(3, 1, tanh))
tspan = (0.0f0, 10.0f0)

ffjord_mdl = FFJORD(nn, tspan, (1,), Tsit5(); ad = AutoZygote())
ps, st = Lux.setup(Xoshiro(0), ffjord_mdl)
ps = ComponentArray(ps)
model = StatefulLuxLayer{true}(ffjord_mdl, nothing, st)

# Training
data_dist = Normal(6.0f0, 0.7f0)
train_data = Float32.(rand(data_dist, 1, 100))

function loss(θ)
    logpx, λ₁, λ₂ = model(train_data, θ)
    return -mean(logpx)
end

function cb(p, l)
    @info "FFJORD Training" loss=loss(p)
    return false
end

adtype = Optimization.AutoForwardDiff()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ps)

res1 = Optimization.solve(
           optprob, OptimizationOptimisers.Adam(0.01); maxiters = 1, callback = cb)
┌ Error: Exception while generating log record in module Main at REPL[103]:2
│   exception =
│    type OptimizationState has no field layer_1
│    Stacktrace:
│      [1] getproperty
│        @ ./Base.jl:37 [inlined]
│      [2] macro expansion
│        @ ~/.julia/packages/Lux/7UzHr/src/layers/containers.jl:0 [inlined]
│      [3] applychain(layers::@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}}, x::SubArray{Float32, 2, Matrix{Float32}, Tuple{UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, ps::Optimization.OptimizationState{ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:6, Axis(weight = ViewAxis(1:3, ShapedAxis((3, 1))), bias = ViewAxis(4:6, ShapedAxis((3, 1))))), layer_2 = ViewAxis(7:10, Axis(weight = ViewAxis(1:3, ShapedAxis((1, 3))), bias = ViewAxis(4:4, ShapedAxis((1, 1))))))}}}, Float32, ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:6, Axis(weight = ViewAxis(1:3, ShapedAxis((3, 1))), bias = ViewAxis(4:6, ShapedAxis((3, 1))))), layer_2 = ViewAxis(7:10, Axis(weight = ViewAxis(1:3, ShapedAxis((1, 3))), bias = ViewAxis(4:4, ShapedAxis((1, 1))))))}}}, Nothing, Optimisers.Leaf{Adam, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}, st::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}})
│        @ Lux ~/.julia/packages/Lux/7UzHr/src/layers/containers.jl:478

The problem appears to be in the callback function. I don't quite understand why it doesn't work as-is, but simply replacing loss=loss(p) with loss=l does the trick.

function cb(p, l)
    @info "FFJORD Training" loss=l
    return false
end

ChrisRackauckas commented 5 months ago

Yes, thanks for the report. The callback's first argument is now an OptimizationState rather than the raw parameter vector, so the direct translation is:

function cb(state, l)
    @info "FFJORD Training" loss=loss(state.u)
    return false
end

But of course, as you found, the better approach is simply to use the precomputed l:

function cb(state, l)
    @info "FFJORD Training" loss=l
    return false
end

This is fixed in https://github.com/SciML/DiffEqFlux.jl/commit/2691d5915a6c20613905ca257208695e642af09e
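
For completeness, here is a minimal sketch of how the corrected callback slots into the reproducer above, assuming model, ps, train_data, and loss are already defined exactly as in the original snippet; the only change relative to the docs code is the callback body:

# The callback now receives an OptimizationState as its first argument,
# so log the precomputed loss `l` rather than calling loss(p) on the state.
function cb(state, l)
    @info "FFJORD Training" loss=l
    return false  # returning true would halt the optimization early
end

adtype = Optimization.AutoForwardDiff()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ps)

res1 = Optimization.solve(
           optprob, OptimizationOptimisers.Adam(0.01); maxiters = 1, callback = cb)

If the parameters themselves are needed inside the callback (for plotting, checkpointing, etc.), they are available as state.u, as shown in the direct translation above.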