RajDandekar opened this issue 2 years ago
Sure, I'll have a look tomorrow.
@RajDandekar I don't get any error when I run this code; I guess part of it is missing. But yes, you need
```julia
tspan = (0.0f0,6.2831f0)
tsave = range(tspan[1],tspan[2],length=100)
```
inside `function find_σ_exact(tsave,γd)` if you want to solve the original problem. I don't know if that's what you need, though...
Oh wait, I forgot to add the Optimization lines. I get the error when I try to optimize the parameters using the lines of code below:
```julia
# ERROR: no method matching length(::SciMLBase.NullParameters)
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,p) -> loss_adjoint(x,γd), adtype)
optprob = Optimization.OptimizationProblem(optf, α)
res1 = Optimization.solve(optprob, ADAM(0.0001), callback = callback, maxiters = 15000)
```
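A minimal sketch of one place a `NullParameters` value can come from in this setup, assuming only SciMLBase.jl (which Optimization.jl builds on): `OptimizationProblem(optf, α)` is called without a parameter argument, so the problem's `p` falls back to a `SciMLBase.NullParameters()` placeholder, and `length` has no method for that type. `toy_loss` and the starting point `[1.0, 2.0]` are hypothetical stand-ins for `loss_adjoint` and `α`; this only illustrates the default, it does not show that this particular placeholder is the one that triggers the error above.

```julia
using SciMLBase  # defines OptimizationFunction, OptimizationProblem and NullParameters

# Hypothetical stand-in for loss_adjoint: same (x, p) signature, p is ignored
toy_loss(x, p) = sum(abs2, x)

optf_toy = SciMLBase.OptimizationFunction(toy_loss)                # no adtype given
optprob_toy = SciMLBase.OptimizationProblem(optf_toy, [1.0, 2.0])  # no p passed

# With no p argument, the problem's parameters are a NullParameters() placeholder
println(typeof(optprob_toy.p))

# Calling length on that placeholder throws a MethodError, the same kind of
# failure as the one reported in the thread
threw = try
    length(optprob_toy.p)
    false
catch err
    err isa MethodError
end
println("length(p) threw a MethodError: ", threw)
```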
@ccrnn Now you should get the error.
So, I found an error, but it's not the one you're seeing... Something to do with tsave was causing a stack overflow. The solver works with this simplification:
```julia
using DiffEqFlux, Lux, Flux
using DifferentialEquations
using Plots, Statistics
using Sundials, Optimization, OptimizationFlux

function FENEP!(out,du,u,p,t,γd)
    # DAE definition for FENE-p
    θ₁₁,θ₂₂,θ₁₂, τ₁₁,τ₂₂,τ₁₂ = u
    λ,η,L = [2.0, 2.0, 4.0]  # hard-coded; the p argument is not used
    a = L^2 /(L^2 -3)
    fλ = (L^2 + (λ/η/a)*(τ₂₂+τ₁₁))/(L^2 - 3)
    out[1] = τ₁₁ + du[1] - 2*λ*γd(t)*τ₁₂/fλ
    out[2] = τ₂₂ + du[2]
    out[3] = τ₁₂ + du[3] - λ*γd(t)*τ₂₂/fλ - η/fλ * γd(t)
    out[4] = θ₁₁ - λ*τ₁₁/fλ
    out[5] = θ₂₂ - λ*τ₂₂/fλ
    out[6] = θ₁₂ - λ*τ₁₂/fλ
end

#function find_σ_exact(tsave,γd)
# finds the exact solution of the FENE-p equations given a strain rate fn and
# time points to save at
ω = 1.0f0
γd = t -> 12.0f0*cos.(ω.*t)
λ = 2.0
L = 2.0   # note: L and η here differ from the values hard-coded in FENEP!
η = 4.0
p = [λ,η,L]
u₀ = zeros(6)
du₀ = [0.0, 0.0, η*γd(0.0)*(L^2-3)/L^2, 0.0,0.0,0.0]
tspan = (0.00,6.00)
tsave = 0.06   # simplification: a scalar saveat saves every 0.06 time units
differential_vars = [true,true,true,false,false,false]
h(out, du,u,p,t) = FENEP!(out, du,u,p,t,γd)
# p is not passed here, so prob.p defaults to SciMLBase.NullParameters()
prob = DAEProblem(h,du₀,u₀,tspan,differential_vars=differential_vars)
sol = solve(prob,IDA(),saveat=Float64.(tsave))
# the function wrapper is commented out, so assign the result instead of returning it
σ_exact = [Float32(σ[6]) for σ in sol.u]  # u[6] = τ₁₂ at each saved time
```
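Reading the residuals of `FENEP!` back as equations may help when following the model. This is a transcription of the code above, not taken from a FENE-P reference, with $\dot\gamma(t) = 12\cos(\omega t)$ and state $u = (\theta_{11}, \theta_{22}, \theta_{12}, \tau_{11}, \tau_{22}, \tau_{12})$:

$$
\begin{aligned}
a &= \frac{L^2}{L^2-3}, \qquad f_\lambda = \frac{L^2 + \frac{\lambda}{\eta a}\,(\tau_{11}+\tau_{22})}{L^2-3},\\
\dot\theta_{11} &= -\tau_{11} + \frac{2\lambda\,\dot\gamma(t)\,\tau_{12}}{f_\lambda},\\
\dot\theta_{22} &= -\tau_{22},\\
\dot\theta_{12} &= -\tau_{12} + \frac{\lambda\,\dot\gamma(t)\,\tau_{22}}{f_\lambda} + \frac{\eta\,\dot\gamma(t)}{f_\lambda},\\
0 &= \theta_{ij} - \frac{\lambda\,\tau_{ij}}{f_\lambda}, \qquad ij \in \{11, 22, 12\}.
\end{aligned}
$$

The first three are the differential equations and the last three are algebraic, matching `differential_vars = [true,true,true,false,false,false]`. At $t = 0$ with $u_0 = 0$ we get $f_\lambda = L^2/(L^2-3)$, so a consistent initial derivative needs $\dot\theta_{12}(0) = \eta\,\dot\gamma(0)\,(L^2-3)/L^2$, which is the third entry of `du₀`. That entry only matches the residual if it uses the same $\eta$ and $L$ as `FENEP!`: the hard-coded $(\lambda, \eta, L) = (2, 2, 4)$ give $2 \cdot 12 \cdot 13/16 = 19.5$, whereas the script-level $\eta = 4$, $L = 2$ give $4 \cdot 12 \cdot 1/4 = 12$.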
But your error still occurs. Beyond this, I don't understand well enough what is going on: I don't really understand what is being predicted, how f0, f1, etc. relate to the model, why f0 has 2 inputs to the chain, etc. I never saw the original version working, which makes it (even) more difficult. Sorry I can't be more help!
Here is the code I have for now. For simplicity, I am currently using a single strain rate rather than summing over multiple strain rates.
The error I get on the last line is:
@ChrisRackauckas Does this have something to do with how the DAEProblem is defined? Could you also add @ccrnn here so that we can debug this together?
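On the `DAEProblem` question, one thing that can be checked directly: in the simplified script, the problem is built as `DAEProblem(h,du₀,u₀,tspan,differential_vars=differential_vars)` without a parameter argument, so its `p` is a `SciMLBase.NullParameters()` placeholder; passing a vector as the fifth positional argument gives it real parameters. The sketch below shows this on a hypothetical one-variable residual `toy_resid!` (not the FENE-P model) and assumes only SciMLBase.jl. Whether passing `p` (and reading λ, η, L from it inside `FENEP!`) removes the error during the adjoint computation is an untested assumption.

```julia
using SciMLBase  # defines DAEProblem and NullParameters

# Hypothetical in-place residual with the same (out, du, u, p, t) signature as `h`:
# it encodes du = -p[1] * u, so it reads its rate constant from p
function toy_resid!(out, du, u, p, t)
    out[1] = du[1] + p[1] * u[1]
    return nothing
end

du0 = [-1.0]
u0 = [1.0]
tspan = (0.0, 1.0)

# As in the thread: no p argument, so prob_nop.p is NullParameters()
# (solving this one would fail, because toy_resid! indexes p)
prob_nop = SciMLBase.DAEProblem(toy_resid!, du0, u0, tspan; differential_vars = [true])

# Parameters passed explicitly as the 5th positional argument
prob_p = SciMLBase.DAEProblem(toy_resid!, du0, u0, tspan, [1.0]; differential_vars = [true])

println(typeof(prob_nop.p))
println(prob_p.p)
```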