Closed avik-pal closed 1 year ago
Another:
using Enzyme
using LinearAlgebra
using LinearSolve
using Interpolations
# Switch on two Enzyme API debug flags before running the reproducer below.
# NOTE(review): presumably printall! dumps the IR being differentiated and
# printperf! reports performance warnings — confirm against the Enzyme docs.
Enzyme.API.printall!(true)
Enzyme.API.printperf!(true)
"""
    model_parameters

Container for the scalar constants read by the flowline solver. All values are
meant to be given in SI units.
"""
struct model_parameters
    rho_ice::Float64     # read as ρi when building the forcing term
    rho_water::Float64
    g::Float64
    yts::Float64         # default is 3600*24*365, i.e. seconds in a 365-day year
    rheology_B::Float64  # may change to Vector{Float64} later
    rheology_n::Float64
end

"""
    model_parameters()

Build a `model_parameters` populated with the default constants.
"""
model_parameters() =
    model_parameters(917.0, 1023.0, 9.81, 3600 * 24 * 365.0, 6.68067e7, 3.0)
"""
    zsToH(x, zs, zb)

Return `(H, hₓ)` where `H = zs - zb` and `hₓ` holds, for every entry of `x`, the
first gradient component of the linear interpolant of `zs` over `x`.
"""
function zsToH(x, zs, zb)
    thickness = zs - zb
    # gradient of the piecewise-linear interpolant, evaluated at each grid point
    surface = LinearInterpolation(x, zs)
    slope = [first(Interpolations.gradient(surface, xi)) for xi in x]
    return thickness, slope
end
# Differential operator
"""
    Dm(N, dx)

Return a dense `N×N` matrix with `1/dx` on the diagonal and `-1/dx` on the first
superdiagonal. The entry `[N, 1]` is set to `-1/dx`, so the last row wraps around
to the first column.
"""
function Dm(N, dx)
    stencil = diagm(0 => ones(N), 1 => -ones(N - 1)) ./ dx
    # wrap-around coupling of the last row to the first column
    stencil[N, 1] = -1.0 ./ dx
    return stencil
end
# Weertman's friction law
"""
    frictionWeertman(C, u; m=1.0)

Return the element-wise friction coefficient `C^2 * |u|^(m - 1)`.
"""
function frictionWeertman(C, u; m=1.0)
    return @. C^2 * abs(u)^(m - 1.0)
end
# SSA solver
"""
    flowlineSSA(u0, x, H, b, hₓ, dx, C; frictionlaw=frictionWeertman,
                params=model_parameters(), maxIter=100, tol=1.0e-4)

Iteratively solve the periodic flowline SSA system for the velocity vector.

Each iteration evaluates the viscosity term `etaH` from the current velocity,
assembles a cyclic tridiagonal matrix, solves the linear system against the
forcing `rho_ice * g * H * hₓ`, and measures the relative change of the
velocity.

# Arguments
- `u0`: initial velocity guess (used as the first iterate, never mutated).
- `x`, `b`: accepted for interface compatibility; not read by the solver.
- `H`, `hₓ`: vectors entering the forcing term; `size(H, 1)` sets the system size.
- `dx`: grid spacing.
- `C`: friction parameter forwarded to `frictionlaw(C, u)`.

# Keywords
- `frictionlaw`: callable `(C, u) -> β` giving the friction coefficient.
- `params`: a `model_parameters` instance supplying `rho_ice`, `g`,
  `rheology_B` and `rheology_n`.
- `maxIter`: maximum number of nonlinear iterations.
- `tol`: threshold on `norm(u - uold) / norm(uold)`.

# Returns
The first iterate whose relative change is below `tol`, otherwise the last
iterate computed after `maxIter` iterations (`u0` itself if `maxIter < 1`).

The residual of every iteration is printed to standard output.
"""
function flowlineSSA(u0, x, H, b, hₓ, dx, C; frictionlaw=frictionWeertman, params::model_parameters=model_parameters(), maxIter=100, tol=1.0e-4)
    # get the model parameters
    ρi = params.rho_ice
    n = params.rheology_n
    g = params.g
    B = params.rheology_B
    N = size(H, 1)

    # forcing term
    rhs = ρi .* g .* H .* hₓ
    D1m = Dm(N, dx)

    # start nonlinear iteration
    uold = u0
    for _ in 1:maxIter
        u = uold

        # viscosity
        uₓ² = (D1m * u) .^ 2.0
        uₓⁿ = (uₓ² .+ 1e-16) .^ ((1.0 - n) / n * 0.5)
        uₓⁿ[uₓⁿ .< 1] .= 1.0
        etaH = 0.5 .* H .* uₓⁿ .* B
        etaHp = circshift(etaH, -1)   # etaHp[i] == etaH[i+1], wrapping at the end

        # friction
        β = frictionlaw(C, u)

        # construct matrix
        avg = 2.0 .* (etaH .+ etaHp) ./ (dx .^ 2)
        delta = 2.0 .* (etaHp .- etaH) ./ (dx .^ 2)
        cen = -2.0 .* avg .- β
        # The lower and upper bands must be separate arrays. Binding `up = dn`
        # and then updating both in place made the two updates cancel, leaving
        # both bands equal to `avg`.
        dn = avg .- delta   # == 4 .* etaH  ./ dx^2
        up = avg .+ delta   # == 4 .* etaHp ./ dx^2

        # periodic boundary: the wrap-around couplings are stored in the matrix
        # corners instead of being moved to the right-hand side
        sysQ = Array(Tridiagonal(dn[2:N], cen, up[1:N-1]))
        sysQ[1, N] = dn[1]
        sysQ[N, 1] = up[N]

        # Diagonal preconditioning (P = Diagonal(1.0 ./ cen), applied to both
        # sysQ and rhs) was tried and led to wrong solutions, so it is left out.
        prob = LinearProblem(sysQ, rhs)
        sol = solve(prob)
        u = sol.u

        # check convergence
        res = norm(u .- uold) / norm(uold)
        print("residual=$res\n")
        if res < tol
            return u
        end
        uold = u
    end
    return uold
end
using Enzyme
# Grid: Nx equally spaced points covering [0, L]
Nx = 1000
L = 20e3
x = collect(LinRange(0, L, Nx))

# Friction parameter: one sine period over the domain, scaled by `yts`
ω = 2π / L
C = sqrt.((1000 .+ 1000 .* sin.(ω .* x)) .* model_parameters().yts)
"""
    cost(C::Vector{Float64})

Run the flowline SSA solver on a uniformly sloping slab for the friction
parameter `C` and return the first entry of the resulting velocity vector.

Reads the globals `x` and `Nx`.
"""
function cost(C::Vector{Float64})
    slope = 0.1 / 180 * π
    spacing = abs(x[2] - x[1])

    # ISMIP-HOM D geometry: sloping surface with the bed a constant 1000 below it
    surface = -x .* tan(slope)
    bed = surface .- 1000
    thickness, surface_grad = zsToH(x, surface, bed)

    # initial velocity guess
    u_init = 1 * ones(Nx) ./ model_parameters().yts

    # the cost is the first velocity entry
    velocity = flowlineSSA(u_init, x, thickness, bed, surface_grad, spacing, C)
    return velocity[1]
end
# Zero-filled array shaped like C, passed below as the shadow of C.
∂J_∂C = zero(C)
# Evaluate the cost once on its own before differentiating it.
J = cost(C)
# plot(x, J.*model_parameters().yts)
#Call enzyme to get derivative of cost function
# Enzyme.API.looseTypeAnalysis!(true)
# Enzyme.API.strictAliasing!(false)
# Reverse-mode differentiation of `cost` with an Active return; C is wrapped as
# Duplicated together with ∂J_∂C.
# NOTE(review): presumably the gradient is accumulated into ∂J_∂C — confirm
# against the Enzyme docs.
autodiff(Reverse, cost, Active, Duplicated(C, ∂J_∂C))
cc @bitmyte
# Reduced reproducer: only a gridded linear interpolation is constructed inside
# the differentiated function.
using Enzyme
using Interpolations
Enzyme.API.printall!(true)
Enzyme.API.printperf!(true)
using Enzyme
# Non-const global, empty on purpose; it is read inside `cost`.
x = Float64[]
# The argument C is never used; the function only negates the global `x`, builds
# a gridded linear interpolant from it, discards the result and returns nothing.
function cost(C::Vector{Float64})
zs = -x
interpolate((x,), zs, Gridded(Linear()))
return nothing
end
# Reverse-mode call with a Const return and empty primal/shadow vectors.
autodiff(Reverse, cost, Const, Duplicated(Float64[], Float64[]))
Partial bug: we need to cache a phi node of a decayed value:
%.pre153 = addrspacecast double addrspace(13)* addrspace(10)* %.pre151 to double addrspace(13)* addrspace(11)*, !dbg !3115
br label %L112, !dbg !3114
L112: ; preds = %ok42, %pass, %L92, %L34, %L26
%.pre-phi154 = phi double addrspace(13)* addrspace(11)* [ %.pre153, %L92 ], [ %83, %pass ], [ %83, %L34 ], [ %83, %L26 ], [ %83, %ok42 ], !dbg !3115
%.pre-phi = phi i64 [ %.pre, %L92 ], [ %51, %pass ], [ %51, %L34 ], [ %51, %L26 ], [ %51, %ok42 ], !dbg !3115
Current:
# Current reduced reproducer: the interpolation is built through the five-argument
# `interpolate` method, using the same array for knots and values.
using Enzyme
using Interpolations
Enzyme.API.printall!(true)
Enzyme.API.printperf!(true)
# Non-const global, empty on purpose; it is read inside `cost`.
x = Float64[]
# The argument C is never used; the function only constructs the interpolant
# from the global `x`, discards it and returns nothing.
function cost(C::Vector{Float64})
# Bad store
# Interpolations.GriddedInterpolation(Float64, knots, copy(zs), Gridded(Linear()))
# deduplicate_knots!(C)
# @inbounds C[1] = 0
zs = x
knots = (zs,)
# Interpolations.GriddedInterpolation{Float64, 1, Vector{Float64}, Gridded{Linear{Throw{OnGrid}}}, Tuple{Vector{Float64}}}(knots, A, it)
# bug
# the call below is the statement marked as triggering the bug
interpolate(Float64, Float64, knots, zs, Gridded(Linear()))
return nothing
end
# Primal and shadow arrays passed to autodiff.
A = Float64[1, 3, 3, 7]
dA = Float64[1, 1, 1, 1]
# TODO this currently hits a GC segfault, this should be enabled
autodiff(Reverse, cost, Const, Duplicated(A, dA))
This is now fixed on main. @avik-pal, could you retry?