ChrisRackauckas / universal_differential_equations

Repository for the Universal Differential Equations for Scientific Machine Learning paper, describing a computational basis for high performance SciML
https://arxiv.org/abs/2001.04385
MIT License

Non Newtonian Fluids: FENE-P example error #46

Open RajDandekar opened 2 years ago

RajDandekar commented 2 years ago

Here is the code which I have for now. For simplicity, I am currently using a single strain rate and not summing over multiple strain rates.

using DiffEqFlux, Flux
using DifferentialEquations
using Plots, Statistics
using Sundials,  Optimization, OptimizationFlux

function FENEP!(out,du,u,p,t,γd)
  # DAE definition for FENE-p
  θ₁₁,θ₂₂,θ₁₂, τ₁₁,τ₂₂,τ₁₂ = u
  λ,η,L = [2.0, 2.0, 4.0]
  a = L^2 /(L^2 -3)
  fλ = (L^2 + (λ/η/a)*(τ₂₂+τ₁₁))/(L^2 - 3)
  out[1] =  τ₁₁ + du[1] - 2*λ*γd(t)*τ₁₂/fλ
  out[2] =  τ₂₂ + du[2]
  out[3] =  τ₁₂ + du[3] - λ*γd(t)*τ₂₂/fλ - η/fλ * γd(t)

  out[4] = θ₁₁ - λ*τ₁₁/fλ
  out[5] = θ₂₂ - λ*τ₂₂/fλ
  out[6] = θ₁₂ - λ*τ₁₂/fλ
end

function find_σ_exact(tsave,γd)
  # finds the exact solution of the FENE-p equations given a strain rate fn and
  # time points to save at
  λ = 2.0
  L = 2.0
  η = 4.0
  p = [λ,η,L]
  u₀ = zeros(6)
  du₀ = [0.0, 0.0, η*γd(0.0)*(L^2-3)/L^2, 0.0,0.0,0.0]
  tspan = (Float64(tsave[1]),Float64(tsave[end]))
  differential_vars = [true,true,true,false,false,false]
  h(out, du,u,p,t) = FENEP!(out, du,u,p,t,γd)

  prob = DAEProblem(h,du₀,u₀,tspan,differential_vars=differential_vars)

  sol = solve(prob,IDA(),saveat=Float64.(tsave))
  return [Float32(σ[6]) for σ in sol.u]
end

p0_vec = Float64[]

f0_n = FastChain(FastDense(2,4,tanh), FastDense(4,1))
p0 = initial_params(f0_n)
append!(p0_vec, p0)

f1_n = FastChain(FastDense(2,4,tanh), FastDense(4,1))
p0 = initial_params(f1_n)
append!(p0_vec, p0)

function dudt_(du, u, p, t, γd)
    NN1 = abs(f1_n([u[1],γd(t) ], p[1:17])[1])
    du[1] = NN1
end

# Define the problem

α = p0_vec

#=
function predict_adjoint(θ)
  x = Array(solve(prob_pred,Tsit5(),p=θ,saveat=tsave,
                  sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true))))
end
=#

function loss_adjoint(θ, γd)
    tspan = (0.0f0,6.2831f0)
    tsave = range(tspan[1],tspan[2],length=100)
    σ_exact = find_σ_exact(tsave,γd)
    σ0 =0.0f0
    u0 = [σ0]

    dudt2_(du,u,p,t) = dudt_(du,u,p,t,γd)
    prob_pred = ODEProblem{true}(dudt2_,u0,tspan)
    x = Array(solve(prob_pred,Tsit5(),p=θ,saveat=tsave,
                    sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true))))

    P_RD = vcat(x, γd.(tsave)')
    σ_out = [f0_n(P_RD[:,i], θ[18:34])[1] for i = 1:size(P_RD,2)]
    return sum( (σ_out .- σ_exact).^2 )
end

#loss_adjoint2(θ, γd) =  sum(loss_adjoint(θ,t -> 12.0f0*cos.(ω.*t))() for ω in 1.0f0:0.2f0:2.0f0)

#γd = t -> 12.0f0*cos.(1.6.*t)

iter = 0
function callback(θ,l)
  global iter
  iter += 1
  if iter%10 == 0
    println(l)
  end
  return false
end

ω = 1.0f0
γd = t -> 12.0f0*cos.(ω.*t)

The error which I get on the last line is this:

ERROR: MethodError: no method matching length(::SciMLBase.NullParameters)
Closest candidates are:
  length(::Union{Base.KeySet, Base.ValueIterator}) at /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/base/abstractdict.jl:58
  length(::Union{DataStructures.OrderedRobinDict, DataStructures.RobinDict}) at ~/.julia/packages/DataStructures/59MD0/src/ordered_robin_dict.jl:86
  length(::Union{DataStructures.SortedDict, DataStructures.SortedMultiDict, DataStructures.SortedSet}) at ~/.julia/packages/DataStructures/59MD0/src/container_loops.jl:322

@ChrisRackauckas Does this have something to do with how the DAEProblem is defined? Could you also add @ccrnn here so that we can debug this together?

ghost commented 2 years ago

Sure, I'll have a look tomorrow.

ghost commented 2 years ago

@RajDandekar I don't get any error when I run this code, so I guess part of it is missing. In any case, you need

tspan = (0.0f0,6.2831f0)
tsave = range(tspan[1],tspan[2],length=100)

inside function find_σ_exact(tsave,γd) if you want to solve the original problem. I don't know whether that's what you need, though.
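A minimal sketch of that time grid in plain Julia, assuming the same span and point count that loss_adjoint uses in the original post. The names tsave64 and tspan64 are introduced here for illustration only; the conversions mirror the Float64(...) calls already present in find_σ_exact.

```julia
# Time span and save points as written in loss_adjoint (Float32 values).
tspan = (0.0f0, 6.2831f0)
tsave = range(tspan[1], tspan[2], length=100)

# find_σ_exact converts both to Float64 before building the DAE problem.
tsave64 = Float64.(tsave)                 # hypothetical name, illustration only
tspan64 = (tsave64[1], tsave64[end])      # hypothetical name, illustration only

println(length(tsave64))   # 100
println(eltype(tsave64))
```

Defining the grid once and passing it in as an argument, as the original signature find_σ_exact(tsave,γd) already allows, keeps the exact solution and the prediction on the same save points.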

RajDandekar commented 2 years ago

Oh wait, I forgot to add the Optimization lines. I get the error when I try to optimize the parameters using the code below:

##ERROR: NO METHOD MATCHING LENGTH (SCIML NullParameters)
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,p) -> loss_adjoint(x,γd), adtype)
optprob = Optimization.OptimizationProblem(optf, α)
res1 = Optimization.solve(optprob, ADAM(0.0001), callback = callback, maxiters = 15000)

@ccrnn Now you should get the error.

ghost commented 2 years ago

So, I found an error, but it's not the one you're seeing. Something to do with tsave was causing a stack overflow. The solver works with a simplification:

using DiffEqFlux, Lux, Flux
using DifferentialEquations
using Plots, Statistics
using Sundials,  Optimization, OptimizationFlux

function FENEP!(out,du,u,p,t,γd)
  # DAE definition for FENE-p
  θ₁₁,θ₂₂,θ₁₂, τ₁₁,τ₂₂,τ₁₂ = u
  λ,η,L = [2.0, 2.0, 4.0]
  a = L^2 /(L^2 -3)
  fλ = (L^2 + (λ/η/a)*(τ₂₂+τ₁₁))/(L^2 - 3)
  out[1] =  τ₁₁ + du[1] - 2*λ*γd(t)*τ₁₂/fλ
  out[2] =  τ₂₂ + du[2]
  out[3] =  τ₁₂ + du[3] - λ*γd(t)*τ₂₂/fλ - η/fλ * γd(t)

  out[4] = θ₁₁ - λ*τ₁₁/fλ
  out[5] = θ₂₂ - λ*τ₂₂/fλ
  out[6] = θ₁₂ - λ*τ₁₂/fλ
end

#function find_σ_exact(tsave,γd)
  # finds the exact solution of the FENE-p equations given a strain rate fn and
  # time points to save at
  ω = 1.0f0
  γd = t -> 12.0f0*cos.(ω.*t)
  λ = 2.0
  L = 2.0
  η = 4.0
  p = [λ,η,L]
  u₀ = zeros(6)
  du₀ = [0.0, 0.0, η*γd(0.0)*(L^2-3)/L^2, 0.0,0.0,0.0]
  tspan = (0.00,6.00)
  tsave = 0.06
  differential_vars = [true,true,true,false,false,false]
  h(out, du,u,p,t) = FENEP!(out, du,u,p,t,γd)

  prob = DAEProblem(h,du₀,u₀,tspan,differential_vars=differential_vars)

  sol = solve(prob,IDA(),saveat=Float64.(tsave))
  return [Float32(σ[6]) for σ in sol.u]

But your error is still happening. Beyond this, I don't understand well enough what is going on: I don't really understand what is being predicted, how f0, f1, etc. are related to the model, or why f0 has 2 inputs to the chain. I never saw the original working version of this, which makes it (even) more difficult. Sorry I can't be of more help!
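On the question of how f0 and f1 relate to the parameters, here is a minimal sketch in plain Julia (no SciML packages) of how the flat vector α in the original post appears to be split between the two networks. The 2 → 4 → 1 layer sizes and the slices 1:17 and 18:34 come from the posted code; dense_nparams, n_per_net, θ_f1, and θ_f0 are names introduced here for illustration only.

```julia
# Parameter count of a dense layer with `nin` inputs and `nout` outputs:
# an nout×nin weight matrix plus nout biases. (Hypothetical helper.)
dense_nparams(nin, nout) = nout * nin + nout

# Both networks in the original post are 2 -> 4 -> 1.
n_per_net = dense_nparams(2, 4) + dense_nparams(4, 1)

# Stand-in for the flat parameter vector α (values are placeholders).
θ = collect(1.0:(2 * n_per_net))

θ_f1 = θ[1:n_per_net]                       # slice read by dudt_ (p[1:17])
θ_f0 = θ[(n_per_net + 1):(2 * n_per_net)]   # slice read by loss_adjoint (θ[18:34])

println(n_per_net)   # 17
println((length(θ_f1), length(θ_f0)))
```

Note that p0_vec is filled with f0_n's initial parameters first and f1_n's second, while dudt_ reads p[1:17] for f1_n and loss_adjoint reads θ[18:34] for f0_n. Since both networks share the same architecture, the sizes still match.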