YichengDWu / Sophon.jl

Efficient, Accurate, and Streamlined Training of Physics-Informed Neural Networks
https://yichengdwu.github.io/Sophon.jl/dev/
MIT License

Using Sophon with custom Layers #229

Closed: arthur-bizzi closed this issue 10 months ago

arthur-bizzi commented 10 months ago

What methods does a custom layer need to implement to work with the library? Currently, bare-bones Lux layers don't seem to work.

Take the following (very rough) MWE, based on the ODE tutorial: we attempt to solve a linear ODE for multiple initial conditions. The NormalLayer struct is meant to naively emulate a fully connected layer.

```julia
using ModelingToolkit, Lux, LinearAlgebra
using Optimization, OptimizationOptimJL
using Random, IntervalSets
using Sophon

# A rough stand-in for a fully connected network: widen the input by stacking
# identity matrices, apply `layers` dense layers, then sum back to one output.
struct NormalLayer{F1} <: Lux.AbstractExplicitLayer
    dims::Int
    widen_times::Int
    layers::Int
    fwd_activation::F1
end

# Non-trainable widen/shrink matrices live in the layer state.
function Lux.initialstates(rng::AbstractRNG, s::NormalLayer)
    id = Matrix{Float64}(I, s.dims, s.dims)
    widen = vcat([id for _ in 1:s.widen_times]...)
    shrink = ones(Float64, 1, s.dims * s.widen_times)
    return (widen=widen, shrink=shrink)
end

# Trainable parameters: a Tuple of weight matrices and a Vector of bias vectors.
function Lux.initialparameters(rng::AbstractRNG, s::NormalLayer)
    n = s.dims * s.widen_times
    weights = Tuple([randn(n, n) for _ in 1:s.layers])
    biases = [randn(n) for _ in 1:s.layers]
    return (weights=weights, biases=biases)
end

function (N::NormalLayer)(y::AbstractArray, ps, st::NamedTuple)
    fwd = st.widen * y
    for i in 1:N.layers
        fwd = N.fwd_activation.(ps.weights[i] * fwd .+ ps.biases[i])
    end
    return st.shrink * fwd, st
end
```
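For reference, here is the forward pass that NormalLayer implements, written with $d$ = `dims`, $k$ = `widen_times`, $L$ = `layers` and $\sigma$ = `fwd_activation` applied elementwise. This summarizes the code above only, not any Sophon internals:

$$
E = \begin{bmatrix} I_d \\ \vdots \\ I_d \end{bmatrix} \in \mathbb{R}^{kd \times d}, \qquad
h_0 = E\,y, \qquad
h_i = \sigma\!\left(W_i h_{i-1} + b_i\right),\ i = 1, \dots, L, \qquad
u(y) = \mathbf{1}_{kd}^{\top} h_L
$$

with $W_i \in \mathbb{R}^{kd \times kd}$ and $b_i \in \mathbb{R}^{kd}$. Only the $W_i$ and $b_i$ are trainable; $E$ (`widen`) and $\mathbf{1}_{kd}^{\top}$ (`shrink`) are kept in the layer state. For `NormalLayer(3,4,3,tanh)` this gives $kd = 12$ and $3\,(12^2 + 12) = 468$ trainable parameters per network.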

The layer works fine on its own, but using it with the library leads to a very mysterious error message:


```julia
# Defining parameters, variables and equations
vars = @parameters t, x0, y0
funs = @variables x(..), y(..)
funs_and_vars = [f(vars...) for f in funs]

Dₜ = Differential(t)
eqs = [Dₜ(x(vars...)) ~ y(vars...),
       Dₜ(y(vars...)) ~ -x(vars...)]
bcs = [x(0.0, x0, y0) ~ x0, y(0.0, x0, y0) ~ y0, x(2π, x0, y0) ~ x0, y(2π, x0, y0) ~ y0]

# Defining the domain
ti, tf = 0, 3π
xi, xf = -3, 3
yi, yf = -3, 3
domain = [t ∈ ti .. tf, x0 ∈ xi .. xf, y0 ∈ yi .. yf]

# Setting up the symbolic PDE system
@named linear = PDESystem(eqs, bcs, domain, vars, funs_and_vars)

rng = Random.default_rng()
Net1 = NormalLayer(3, 4, 3, tanh)
Net2 = NormalLayer(3, 4, 3, tanh)
pinn = PINN(rng; x=Net1, y=Net2)

# Setting up the sampler, training strategy and problem
dt = 1/100
samples = Integer(ceil((tf - ti) / dt))
sampler = QuasiRandomSampler(samples, 100)
strategy = NonAdaptiveTraining()

prob = Sophon.discretize(linear, pinn, sampler, strategy)

# Training callback
function callback(p, _)
    loss = sum(abs2, Sophon.residual_function_1(prob.p[1], p))
    println("loss: $loss")
    return false
end

# Solving the problem using BFGS optimization
@time res = Optimization.solve(prob, BFGS(); maxiters=25, callback=callback)
```
ERROR: MethodError: no method matching zero(::Type{Any})

Closest candidates are:
  zero(::Type{Union{Missing, T}}) where T
   @ Base missing.jl:105
  zero(::Union{Type{P}, P}) where P<:Dates.Period
   @ Dates C:\Users\55619\AppData\Local\Programs\Julia-1.9.0\share\julia\stdlib\v1.9\Dates\src\periods.jl:51
  zero(::Union{AbstractAlgebra.Generic.LaurentSeriesFieldElem{T}, AbstractAlgebra.Generic.LaurentSeriesRingElem{T}} where T<:AbstractAlgebra.RingElement)  
   @ AbstractAlgebra C:\Users\55619\.julia\packages\AbstractAlgebra\lDlYU\src\generic\LaurentSeries.jl:443
  ...

Stacktrace:
 [1] zero(#unused#::Type{Any})
   @ Base .\missing.jl:106
 [2] __solve(cache::OptimizationOptimJL.OptimJLOptimizationCache{OptimizationFunction{false, Optimization.AutoZygote, OptimizationFunction{true, Optimization.AutoZygote, Sophon.var"#full_loss_function#283"{typeof(Sophon.null_additional_loss), PINN{NamedTuple{(:x, :y), Tuple{ChainState{NormalLayer{typeof(tanh)}, NamedTuple{(:widen, :shrink), Tuple{Matrix{Float64}, Matrix{Float64}}}}, ChainState{NormalLayer{typeof(tanh)}, NamedTuple{(:widen, :shrink), Tuple{Matrix{Float64}, Matrix{Float64}}}}}}, NamedTuple{(:x, :y), Tuple{NamedTuple{(:weights, :biases), Tuple{Tuple{Matrix{Float64}, Matrix{Float64}, Matrix{Float64}}, Vector{Vector{Float64}}}}, NamedTuple{(:weights, :biases), Tuple{Tuple{Matrix{Float64}, Matrix{Float64}, Matrix{Float64}}, Vector{Vector{Float64}}}}}}}, Sophon.var"#289#290"}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Optimization.var"#252#261"{Optimization.var"#251#260"{OptimizationFunction{true, Optimization.AutoZygote, Sophon.var"#full_loss_function#283"{typeof(Sophon.null_additional_loss), PINN{NamedTuple{(:x, :y), Tuple{ChainState{NormalLayer{typeof(tanh)}, NamedTuple{(:widen, :shrink), Tuple{Matrix{Float64}, Matrix{Float64}}}}, ChainState{NormalLayer{typeof(tanh)}, NamedTuple{(:widen, :shrink), Tuple{Matrix{Float64}, Matrix{Float64}}}}}}, NamedTuple{(:x, :y), Tuple{NamedTuple{(:weights, :biases), Tuple{Tuple{Matrix{Float64}, Matrix{Float64}, Matrix{Float64}}, Vector{Vector{Float64}}}}, NamedTuple{(:weights, :biases), Tuple{Tuple{Matrix{Float64}, Matrix{Float64}, Matrix{Float64}}, Vector{Vector{Float64}}}}}}}, Sophon.var"#289#290"}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, 
Optimization.ReInitCache{ComponentArrays.ComponentVector{Any, Vector{Any}, Tuple{ComponentArrays.Axis{(x = ViewAxis(1:39, Axis(weights = 1, biases = ViewAxis(2:37, PartitionedAxis(12, FlatAxis())))), y = ViewAxis(40:78, Axis(weights = 1, biases = ViewAxis(2:37, PartitionedAxis(12, FlatAxis())))))}}}, Vector{Matrix{Float64}}}}}, Optimization.var"#255#264"{Optimization.var"#251#260"{OptimizationFunction{true, Optimization.AutoZygote, Sophon.var"#full_loss_function#283"{typeof(Sophon.null_additional_loss), PINN{NamedTuple{(:x, :y), Tuple{ChainState{NormalLayer{typeof(tanh)}, NamedTuple{(:widen, :shrink), Tuple{Matrix{Float64}, Matrix{Float64}}}}, ChainState{NormalLayer{typeof(tanh)}, NamedTuple{(:widen, :shrink), Tuple{Matrix{Float64}, Matrix{Float64}}}}}}, NamedTuple{(:x, :y), Tuple{NamedTuple{(:weights, :biases), Tuple{Tuple{Matrix{Float64}, Matrix{Float64}, Matrix{Float64}}, Vector{Vector{Float64}}}}, NamedTuple{(:weights, :biases), Tuple{Tuple{Matrix{Float64}, Matrix{Float64}, Matrix{Float64}}, Vector{Vector{Float64}}}}}}}, Sophon.var"#289#290"}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Optimization.ReInitCache{ComponentArrays.ComponentVector{Any, Vector{Any}, Tuple{ComponentArrays.Axis{(x = ViewAxis(1:39, Axis(weights = 1, biases = ViewAxis(2:37, PartitionedAxis(12, FlatAxis())))), y = ViewAxis(40:78, Axis(weights = 1, biases = ViewAxis(2:37, PartitionedAxis(12, FlatAxis())))))}}}, Vector{Matrix{Float64}}}}}, Optimization.var"#259#268", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Optimization.ReInitCache{ComponentArrays.ComponentVector{Any, Vector{Any}, Tuple{ComponentArrays.Axis{(x = ViewAxis(1:39, Axis(weights = 1, biases = 
ViewAxis(2:37, PartitionedAxis(12, FlatAxis())))), y = ViewAxis(40:78, Axis(weights = 1, biases = ViewAxis(2:37, PartitionedAxis(12, FlatAxis())))))}}}, Vector{Matrix{Float64}}}, Nothing, Nothing, Nothing, Nothing, Nothing, BFGS{LineSearches.InitialStatic{Float64}, LineSearches.HagerZhang{Float64, Base.RefValue{Bool}}, Nothing, Nothing, Flat}, Base.Iterators.Cycle{Tuple{Optimization.NullData}}, Bool, typeof(callback)})
   @ OptimizationOptimJL C:\Users\55619\.julia\packages\OptimizationOptimJL\uRfW9\src\OptimizationOptimJL.jl:231
 [3] solve!(cache::OptimizationOptimJL.OptimJLOptimizationCache{OptimizationFunction{false, Optimization.AutoZygote, OptimizationFunction{true, Optimization.AutoZygote, Sophon.var"#full_loss_function#283"{typeof(Sophon.null_additional_loss), PINN{NamedTuple{(:x, :y), Tuple{ChainState{NormalLayer{typeof(tanh)}, NamedTuple{(:widen, :shrink), Tuple{Matrix{Float64}, Matrix{Float64}}}}, ChainState{NormalLayer{typeof(tanh)}, NamedTuple{(:widen, :shrink), Tuple{Matrix{Float64}, Matrix{Float64}}}}}}, NamedTuple{(:x, :y), Tuple{NamedTuple{(:weights, :biases), Tuple{Tuple{Matrix{Float64}, Matrix{Float64}, Matrix{Float64}}, Vector{Vector{Float64}}}}, NamedTuple{(:weights, :biases), Tuple{Tuple{Matrix{Float64}, Matrix{Float64}, Matrix{Float64}}, Vector{Vector{Float64}}}}}}}, Sophon.var"#289#290"}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Optimization.var"#252#261"{Optimization.var"#251#260"{OptimizationFunction{true, Optimization.AutoZygote, Sophon.var"#full_loss_function#283"{typeof(Sophon.null_additional_loss), PINN{NamedTuple{(:x, :y), Tuple{ChainState{NormalLayer{typeof(tanh)}, NamedTuple{(:widen, :shrink), Tuple{Matrix{Float64}, Matrix{Float64}}}}, ChainState{NormalLayer{typeof(tanh)}, NamedTuple{(:widen, :shrink), Tuple{Matrix{Float64}, Matrix{Float64}}}}}}, NamedTuple{(:x, :y), Tuple{NamedTuple{(:weights, :biases), Tuple{Tuple{Matrix{Float64}, Matrix{Float64}, Matrix{Float64}}, Vector{Vector{Float64}}}}, NamedTuple{(:weights, :biases), Tuple{Tuple{Matrix{Float64}, Matrix{Float64}, Matrix{Float64}}, Vector{Vector{Float64}}}}}}}, Sophon.var"#289#290"}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, 
Optimization.ReInitCache{ComponentArrays.ComponentVector{Any, Vector{Any}, Tuple{ComponentArrays.Axis{(x = ViewAxis(1:39, Axis(weights = 1, biases = ViewAxis(2:37, PartitionedAxis(12, FlatAxis())))), y = ViewAxis(40:78, Axis(weights = 1, biases = ViewAxis(2:37, PartitionedAxis(12, FlatAxis())))))}}}, Vector{Matrix{Float64}}}}}, Optimization.var"#255#264"{Optimization.var"#251#260"{OptimizationFunction{true, Optimization.AutoZygote, Sophon.var"#full_loss_function#283"{typeof(Sophon.null_additional_loss), PINN{NamedTuple{(:x, :y), Tuple{ChainState{NormalLayer{typeof(tanh)}, NamedTuple{(:widen, :shrink), Tuple{Matrix{Float64}, Matrix{Float64}}}}, ChainState{NormalLayer{typeof(tanh)}, NamedTuple{(:widen, :shrink), Tuple{Matrix{Float64}, Matrix{Float64}}}}}}, NamedTuple{(:x, :y), Tuple{NamedTuple{(:weights, :biases), Tuple{Tuple{Matrix{Float64}, Matrix{Float64}, Matrix{Float64}}, Vector{Vector{Float64}}}}, NamedTuple{(:weights, :biases), Tuple{Tuple{Matrix{Float64}, Matrix{Float64}, Matrix{Float64}}, Vector{Vector{Float64}}}}}}}, Sophon.var"#289#290"}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Optimization.ReInitCache{ComponentArrays.ComponentVector{Any, Vector{Any}, Tuple{ComponentArrays.Axis{(x = ViewAxis(1:39, Axis(weights = 1, biases = ViewAxis(2:37, PartitionedAxis(12, FlatAxis())))), y = ViewAxis(40:78, Axis(weights = 1, biases = ViewAxis(2:37, PartitionedAxis(12, FlatAxis())))))}}}, Vector{Matrix{Float64}}}}}, Optimization.var"#259#268", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Optimization.ReInitCache{ComponentArrays.ComponentVector{Any, Vector{Any}, Tuple{ComponentArrays.Axis{(x = ViewAxis(1:39, Axis(weights = 1, biases = 
ViewAxis(2:37, PartitionedAxis(12, FlatAxis())))), y = ViewAxis(40:78, Axis(weights = 1, biases = ViewAxis(2:37, PartitionedAxis(12, FlatAxis())))))}}}, Vector{Matrix{Float64}}}, Nothing, Nothing, Nothing, Nothing, Nothing, BFGS{LineSearches.InitialStatic{Float64}, LineSearches.HagerZhang{Float64, Base.RefValue{Bool}}, Nothing, Nothing, Flat}, Base.Iterators.Cycle{Tuple{Optimization.NullData}}, Bool, typeof(callback)})
   @ SciMLBase C:\Users\55619\.julia\packages\SciMLBase\kTUaf\src\solve.jl:162
 [4] solve(::OptimizationProblem{true, OptimizationFunction{true, Optimization.AutoZygote, Sophon.var"#full_loss_function#283"{typeof(Sophon.null_additional_loss), PINN{NamedTuple{(:x, :y), Tuple{ChainState{NormalLayer{typeof(tanh)}, NamedTuple{(:widen, :shrink), Tuple{Matrix{Float64}, Matrix{Float64}}}}, ChainState{NormalLayer{typeof(tanh)}, NamedTuple{(:widen, :shrink), Tuple{Matrix{Float64}, Matrix{Float64}}}}}}, NamedTuple{(:x, :y), Tuple{NamedTuple{(:weights, :biases), Tuple{Tuple{Matrix{Float64}, Matrix{Float64}, Matrix{Float64}}, Vector{Vector{Float64}}}}, NamedTuple{(:weights, :biases), Tuple{Tuple{Matrix{Float64}, Matrix{Float64}, Matrix{Float64}}, Vector{Vector{Float64}}}}}}}, Sophon.var"#289#290"}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, ComponentArrays.ComponentVector{Any, Vector{Any}, Tuple{ComponentArrays.Axis{(x = ViewAxis(1:39, Axis(weights = 1, biases = ViewAxis(2:37, PartitionedAxis(12, FlatAxis())))), y = ViewAxis(40:78, Axis(weights = 1, biases = ViewAxis(2:37, PartitionedAxis(12, FlatAxis())))))}}}, Vector{Matrix{Float64}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::BFGS{LineSearches.InitialStatic{Float64}, LineSearches.HagerZhang{Float64, Base.RefValue{Bool}}, Nothing, Nothing, Flat}; kwargs::Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol}, NamedTuple{(:maxiters, :callback), Tuple{Int64, typeof(callback)}}})
   @ SciMLBase C:\Users\55619\.julia\packages\SciMLBase\kTUaf\src\solve.jl:83
 [5] top-level scope
   @ .\timing.jl:273
YichengDWu commented 10 months ago

I'm going to try running the code later on. Did you find out what causes the error?

arthur-bizzi commented 10 months ago

Yeah, it has to do with the ComponentArrays interface. Everything should work once I manage to wrestle the parameters into a format ComponentArrays understands.
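A quick way to see whether parameters are in a form ComponentArrays understands is to flatten them yourself before calling `solve` and look at the element type. In the stack trace above, the parameter vector is a `ComponentVector{Any, Vector{Any}, …}`, and the `__solve` frame then calls `zero` on that `Any` element type. The helper `flattens_cleanly` and the toy parameters below are hypothetical, for illustration only:

```julia
using ComponentArrays

# Hypothetical helper: report whether a Lux parameter container flattens into a
# ComponentVector with a concrete element type (e.g. Float64 rather than Any).
function flattens_cleanly(ps)
    flat = ComponentArray(ps)
    return isconcretetype(eltype(flat)), eltype(flat)
end

# A NamedTuple of NamedTuples of Float64 arrays flattens to a Float64 vector.
good = (layer_1 = (weight = randn(2, 2), bias = randn(2)),
        layer_2 = (weight = randn(2, 2), bias = randn(2)))

ok, T = flattens_cleanly(good)
println("concrete eltype: ", ok, " (", T, ")")
```

Judging by the stack trace, the issue's original `(weights=Tuple(...), biases=[...])` parameters end up with an `Any` element type instead, which is what the optimizer trips over.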

Thanks a lot anyway. On a related note: are there any plans to support vector equations and outputs, i.e., a single network for all outputs instead of a named tuple of single-output networks? Sophon currently accepts networks with n-dimensional outputs for each entry of that named tuple, but it's unclear how it interprets them.

YichengDWu commented 10 months ago

Alright. Lux's documentation does mention that a layer's parameters (the weights) should be a named tuple.
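Here is a minimal sketch of NormalLayer rewritten along those lines. It assumes the only change needed is the parameter container: each hidden layer gets its own NamedTuple `(weight, bias)` under hypothetical keys `layer_1`, `layer_2`, …, so `ComponentArray` can flatten everything into a `Float64` vector. It has not been run through Sophon itself. The issue used `Lux.AbstractExplicitLayer`; newer Lux releases rename it `AbstractLuxLayer`, hence the small compatibility alias.

```julia
using Lux, Random, LinearAlgebra, ComponentArrays

# Compatibility alias for the abstract layer type across Lux versions.
const AbstractLayerType = if isdefined(Lux, :AbstractLuxLayer)
    Lux.AbstractLuxLayer
else
    Lux.AbstractExplicitLayer
end

struct NormalLayer{F1} <: AbstractLayerType
    dims::Int
    widen_times::Int
    layers::Int
    fwd_activation::F1
end

# Non-trainable widen/shrink matrices stay in the state, as in the issue.
function Lux.initialstates(rng::AbstractRNG, s::NormalLayer)
    id = Matrix{Float64}(I, s.dims, s.dims)
    widen = reduce(vcat, [id for _ in 1:s.widen_times])
    shrink = ones(Float64, 1, s.dims * s.widen_times)
    return (widen=widen, shrink=shrink)
end

# Hypothetical key names :layer_1, :layer_2, ... (one NamedTuple per layer)
# instead of a Tuple of matrices plus a Vector of vectors.
layer_key(i) = Symbol(:layer_, i)

function Lux.initialparameters(rng::AbstractRNG, s::NormalLayer)
    n = s.dims * s.widen_times
    lnames = ntuple(layer_key, s.layers)
    vals = ntuple(_ -> (weight=randn(rng, n, n), bias=randn(rng, n)), s.layers)
    return NamedTuple{lnames}(vals)
end

function (N::NormalLayer)(y::AbstractArray, ps, st::NamedTuple)
    fwd = st.widen * y
    for i in 1:N.layers
        # getproperty works on both the NamedTuple and a ComponentArray.
        p = getproperty(ps, layer_key(i))
        fwd = N.fwd_activation.(p.weight * fwd .+ p.bias)
    end
    return st.shrink * fwd, st
end

# Usage: flatten the parameters the way an optimizer would see them.
rng = Random.default_rng()
layer = NormalLayer(3, 4, 3, tanh)
ps, st = Lux.setup(rng, layer)
flat = ComponentArray(ps)
out, _ = layer(rand(rng, 3, 5), flat, st)
```

The forward pass is unchanged; only the container shape differs, which is what decides the element type of the flattened vector.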

As for your suggestion of using a single network's output to represent multiple dependent variables, there are currently no plans to support this, as I'm not aware of any clear advantages. However, it is possible (and maybe quite easy) to implement it yourself. Please refer to the documentation on shared parameters in Lux. Accordingly, you will also need to change the way PINN initializes parameters.
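For the single-network idea, here is a rough sketch in plain Lux that is not wired into Sophon. It uses one `Chain` with a two-row output and two hypothetical helpers, `x_head` and `y_head`, each of which slices out one row. Both dependent variables are therefore driven by the same parameters. This shares the whole network rather than using Lux's parameter-sharing utilities, and plugging it into PINN would still need the parameter-initialization changes mentioned above.

```julia
using Lux, Random

# One trunk network mapping (t, x0, y0) to two outputs.
# Row 1 stands in for x(t, x0, y0), row 2 for y(t, x0, y0).
trunk = Chain(Dense(3 => 16, tanh), Dense(16 => 16, tanh), Dense(16 => 2))

rng = Random.default_rng()
ps, st = Lux.setup(rng, trunk)

# Hypothetical helpers: each exposes one dependent variable, and both read the
# same `ps`, so every parameter is shared between x and y.
x_head(input, ps, st) = first(trunk(input, ps, st))[1:1, :]
y_head(input, ps, st) = first(trunk(input, ps, st))[2:2, :]

input = rand(rng, Float32, 3, 10)   # 10 sample points (t, x0, y0)
xs = x_head(input, ps, st)
ys = y_head(input, ps, st)
```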