DARPA-ASKEM / sciml-service

Simulation Service provides an interface and job runner for ASKEM models.
MIT License
3 stars 1 forks source link

Glacial Flow Zero Gradient Reproducer #181

Open jClugstor opened 1 month ago

jClugstor commented 1 month ago

This reproduces the problem I've been getting: the calls to `adjoint_sensitivities` and to `Zygote.gradient` both return all zeros. This happens for every combination of `sensealg` and `autodiff` keyword arguments that I've tried.

using Pkg
Pkg.activate(".")
Pkg.instantiate()

# AlgebraicJulia Dependencies
using Catlab
using Catlab.Graphics
using CombinatorialSpaces
using Decapodes
using ComponentArrays

# External Dependencies
using MLStyle
using MultiScaleArrays
using LinearAlgebra
using OrdinaryDiffEq
using JLD2
using SparseArrays
using Statistics
using GeometryBasics: Point2, Point3
# Concrete point types used for the mesh embedding. They are type aliases, so
# declare them `const`: a non-const global can be rebound at any time, which
# prevents the compiler from relying on its value.
const Point2D = Point2{Float64}
const Point3D = Point3{Float64}

using DiagrammaticEquations
using DiagrammaticEquations.Deca

@info("Packages Loaded")

# Ice-thickness evolution for h (the `halfar_eq2` model):
#   ∂ₜh = ⋆(d(⋆(Γ · dh · avg₀₁(mag(♯dh)^(n-1)) · avg₀₁(h^(n+2)))))
# `♯` and `mag` are implemented in `generate` below; `avg₀₁` presumably averages
# 0-form (vertex) values onto edges so they can multiply the 1-form `dh`.
# TODO: use NaNMath for the powers in this model. NOTE(review): with a
# non-integer exponent, Julia's `^` throws a DomainError for a negative base
# (e.g. a negative h), which NaNMath would turn into NaN instead.
halfar_eq2 = @decapode begin
  h::Form0
  Γ::Form1
  n::Constant

  ḣ == ∂ₜ(h)
  ḣ == ∘(⋆, d, ⋆)(Γ * d(h) * avg₀₁(mag(♯(d(h)))^(n-1)) * avg₀₁(h^(n+2)))
end

# Glen's flow law coefficient: Γ = 2/(n+2) · A · (ρg)ⁿ, where A is the flow-rate
# parameter, ρ the ice density and g the gravitational acceleration (their
# values are set further down: `A`, `ρ`, `g`, `n`).
glens_law = @decapode begin
  Γ::Form1
  (A,ρ,g,n)::Constant

  Γ == (2/(n+2))*A*(ρ*g)^n
end

@info("Decapodes Defined")

# Wiring diagram: the `dynamics` box (halfar_eq2) and the `stress` box
# (glens_law) share the variables Γ and n.
ice_dynamics_composition_diagram = @relation () begin
  dynamics(Γ,n)
  stress(Γ,n)
end

# Glue the two models along the shared Γ and n. The composed model's variables
# are prefixed by box name (e.g. `dynamics_h`, `stress_ρ`, as used below).
ice_dynamics_cospan = oapply(ice_dynamics_composition_diagram,
  [Open(halfar_eq2, [:Γ,:n]),
  Open(glens_law, [:Γ,:n])])
ice_dynamics = apex(ice_dynamics_cospan)
# Specialise to 1D: expand composite operators, then infer forms and resolve
# overloads in place. resolve_overloads! presumably relies on the types inferred
# by infer_types!, so keep this order.
ice_dynamics1D = expand_operators(ice_dynamics)
infer_types!(ice_dynamics1D, op1_inf_rules_1D, op2_inf_rules_1D)
resolve_overloads!(ice_dynamics1D, op1_res_rules_1D, op2_res_rules_1D)

# Primal 1-D mesh: 25 evenly spaced vertices on [-2, 2] (placed in the plane at
# y = 0), with one edge between each pair of consecutive vertices.
s_prime = EmbeddedDeltaSet1D{Bool, Point2D}()
add_vertices!(s_prime, 25; point = Point2D.(range(-2, 2; length = 25), 0))
add_edges!(s_prime, 1:(nv(s_prime) - 1), 2:nv(s_prime))
orient!(s_prime)

# Dual complex of that mesh, subdivided at circumcenters.
s = EmbeddedDeltaDualComplex1D{Bool, Float64, Point2D}(s_prime)
subdivide_duals!(s, Circumcenter())

@info("Spaces Defined")

# Build the "sharp" operator (co-vector field -> vector field) for mesh `sd`.
# The mesh is fixed, so all geometry (edge vectors, vertex/edge incidence and the
# per-vertex normalisers) is computed once here. The returned closure does only
# the per-call arithmetic on the 1-form `x`, and no mesh queries run inside the
# simulated (and differentiated) right-hand side.
# This could be up-streamed to CombinatorialSpaces.jl; it is not specific to this
# simulation.
function _sharp_operator(sd)
  # Edge vector of each edge: its ∂v0 point minus its ∂v1 point.
  e_vecs = map(edges(sd)) do e
    point(sd, sd[e, :∂v0]) - point(sd, sd[e, :∂v1])
  end
  # Edges incident to each vertex (at either endpoint).
  neighbors = map(vertices(sd)) do v
    union(incident(sd, v, :∂v0), incident(sd, v, :∂v1))
  end
  n_vecs = map(es -> [e_vecs[e] for e in es], neighbors)
  # Total length of the edges around each vertex; used as the normaliser.
  n_lens = map(nvs -> sum(norm.(nvs)), n_vecs)
  # At each vertex: Σ (edge vector · edge length · x[edge]) / Σ edge length.
  return x -> map(neighbors, n_vecs, n_lens) do es, nvs, total_len
    sum([nv * norm(nv) * x[e] for (e, nv) in zip(es, nvs)]) / total_len
  end
end

"""
    generate(sd, my_symbol; hodge=GeometricHodge())

Return the implementation of the operator named `my_symbol` on the dual complex
`sd`. Supported operators:

- `:♯`   — "sharp": maps a 1-form (one value per edge) to a vector at each vertex.
  Each incident edge vector is scaled by its length and by the 1-form's value on
  that edge; the results are summed and divided by the total incident edge length.
- `:mag` — elementwise Euclidean norm.

Any other symbol raises an error. The `hodge` keyword is accepted but not used.
"""
function generate(sd, my_symbol; hodge=GeometricHodge())
  op = @match my_symbol begin
    :♯ => _sharp_operator(sd)
    :mag => x -> norm.(x)
    _ => error("Unmatched operator $my_symbol")
  end
  return (args...) -> op(args...)
end

# Generate simulation code for the 1D model and write it to ice_sheet1D.jl as
# `decapode_f = <code>`. Including that file then defines `decapode_f`.
decapode_code = gensim(ice_dynamics1D, dimension=1, preallocate = false)
# The do-block form of `open` closes the file even if `write` throws.
open("ice_sheet1D.jl", "w") do file
  write(file, string("decapode_f = ", decapode_code))
end
include("ice_sheet1D.jl")

# Instantiate the simulation on mesh `s`, with operators supplied by `generate`.
fₘ = decapode_f(s, generate)

"""
    f(constants_and_parameters)

Solve the model `fₘ` from `u₀` over `(0, tₑ)` with `Tsit5`, using
`constants_and_parameters` as the parameters. Return the sum of the final state.
"""
function f(constants_and_parameters)
  tspan = (0, tₑ)
  problem = ODEProblem{true, SciMLBase.FullSpecialize}(fₘ, u₀, tspan, constants_and_parameters)
  @info("Solving")
  solution = solve(problem, Tsit5())
  @info("Done")
  # Take the last saved state rather than interpolating (`solution(tₑ)`), which
  # avoids interpolation failures when AD fails.
  final_state = last(solution)
  return sum(final_state)
end

# Initial ice thickness: a Gaussian bump sampled at the primal vertices. (An
# earlier profile with sharp corners led to out-of-domain errors for n ≠ 3.)
h₀ = map(p -> exp(-2 * p[1]^2), point(s_prime))

flow_rate = 1e-16
ice_density = 910.0
u_init_arr = h₀
n = 4.0
ρ = ice_density
g = 9.8101
A = fill(flow_rate, ne(s))   # one flow-rate value per edge of the mesh
tₑ = 5e3                      # end of the simulated time span (0, tₑ)

u₀ = ComponentArray(dynamics_h = u_init_arr)

# The parameters must be a ComponentArray so that they can be differentiated.
constants_and_parameters = ComponentArray(; n, stress_ρ = ρ, stress_g = g, stress_A = A)

y = f(constants_and_parameters)

using Optimization, OptimizationPolyalgorithms, OptimizationBBO, OptimizationOptimJL
using SciMLSensitivity, Zygote, Enzyme, ReverseDiff, ForwardDiff
# Turn on Enzyme's runtime activity setting before any Enzyme-based
# differentiation (e.g. EnzymeVJP in `loss`) runs.
Enzyme.API.runtimeActivity!(true)

# Reference ("observed") run: solve once with the nominal parameters.
data_prob = ODEProblem{true, SciMLBase.FullSpecialize}(fₘ, u₀, (0, tₑ), constants_and_parameters)
decapode_sol  = solve(data_prob, Tsit5())

# Final ice thickness of the reference run; `loss` compares against this.
reference_dat = last(decapode_sol).dynamics_h

"""
    loss(u)

Return the squared-error mismatch between `reference_dat` and the final ice
thickness (`dynamics_h` at the last saved time). The model is solved with the
ice density `stress_ρ` set to `u[1]`; every other parameter keeps its global
value. Only the last time step is compared.
"""
function loss(u)
    newp = ComponentArray(n = n, stress_ρ = u[1], stress_g = g, stress_A = A)
    prob = remake(data_prob, p = newp)
    sol = solve(prob, FBDF(autodiff = false), sensealg = InterpolatingAdjoint(autojacvec = EnzymeVJP(), autodiff = false))
    current_dat = last(sol).dynamics_h
    return sum(abs2, reference_dat .- current_dat)
end

# Gradient of the same terminal loss, computed directly with `adjoint_sensitivities`.
# - The `g` keyword is a *continuous* running cost g(u, p, t), where u is the ODE
#   state. `loss` takes a parameter vector and re-solves the ODE, so it cannot
#   serve as `g`. A terminal mismatch is a *discrete* cost at t = tₑ, supplied
#   through its state derivative `dgdu_discrete`.
# - `decapode_sol` is the reference solution itself. The mismatch, and therefore
#   its gradient, is zero there by construction. So differentiate at the same
#   perturbed density that the Zygote check uses.
perturbed_p = ComponentArray(n = n, stress_ρ = 800.0, stress_g = g, stress_A = A)
perturbed_sol = solve(remake(data_prob, p = perturbed_p), FBDF(autodiff = false))
du0, dp = adjoint_sensitivities(perturbed_sol, FBDF(autodiff = false);
    t = [tₑ],
    # ∂/∂u of sum(abs2, reference_dat .- u). NOTE(review): confirm this matches the
    # `dgdu_discrete` sign convention of the installed SciMLSensitivity version.
    dgdu_discrete = (out, u, p, t, i) -> (out .= 2 .* (u .- reference_dat)),
    sensealg = GaussAdjoint(autodiff = false))

# `dp.stress_ρ` is expected to agree (up to solver tolerances) with this gradient.
Zygote.gradient(loss,[800.0])

Status

Status `~/Documents/Work/dev/DecapodeCalibrateDemos/Project.toml`
  [13f3f980] CairoMakie v0.12.11
  [134e5e36] Catlab v0.16.17
  [b1c52339] CombinatorialSpaces v0.6.7
  [b0b7db55] ComponentArrays v0.15.17
  [679ab3ea] Decapodes v0.5.6
  [6f00c28b] DiagrammaticEquations v0.1.7
⌅ [7da242da] Enzyme v0.12.36
  [5789e2e9] FileIO v1.16.3
  [6a86dc24] FiniteDiff v2.24.0
  [f6369f11] ForwardDiff v0.10.36
  [5c1252a2] GeometryBasics v0.4.11
  [a98d9a8b] Interpolations v0.15.1
⌃ [033835bb] JLD2 v0.5.2
  [d8e11817] MLStyle v0.4.17
  [f9640e96] MultiScaleArrays v1.12.0
  [77ba4419] NaNMath v1.0.2
⌅ [7f7a1694] Optimization v3.28.0
⌃ [3e6eede4] OptimizationBBO v0.3.0
⌅ [36348300] OptimizationOptimJL v0.3.2
  [500b13db] OptimizationPolyalgorithms v0.2.1
  [1dea7af3] OrdinaryDiffEq v6.89.0 `~/Documents/Work/dev/OrdinaryDiffEq.jl`
  [bbf590c4] OrdinaryDiffEqCore v1.6.1 `~/Documents/Work/dev/OrdinaryDiffEq.jl/lib/OrdinaryDiffEqCore`
  [37e2e3b7] ReverseDiff v1.15.3
  [1ed8b502] SciMLSensitivity v7.66.1 `~/.julia/dev/SciMLSensitivity`
  [90137ffa] StaticArrays v1.9.7
  [7a5d1e54] TrackedFloats v2.0.0
  [e88e6eb3] Zygote v0.6.70
  [37e2e46d] LinearAlgebra
  [2f01184e] SparseArrays v1.10.0
jClugstor commented 1 month ago

Originally the initial conditions had two very sharp corners. I changed them to a similar Gaussian, which fixed the out-of-domain errors I was getting for values of n other than 3.0. It did not fix the zero-gradient problem, however. By dramatically increasing the values in the `A` array I can get a non-zero gradient. That makes the parameters non-physical, but the gradient does work.

jClugstor commented 1 month ago

With this version of the parameters and initial conditions, and this combination of `sensealg` and solver, we're able to get the gradient quickly:

using Pkg
Pkg.activate(".")
Pkg.instantiate()

# AlgebraicJulia Dependencies
using Catlab
using Catlab.Graphics
using CombinatorialSpaces
using Decapodes
using ComponentArrays

# External Dependencies
using MLStyle
using MultiScaleArrays
using LinearAlgebra
using OrdinaryDiffEq
using JLD2
using SparseArrays
using Statistics
using GeometryBasics: Point2, Point3
# Concrete point types used for the mesh embedding. They are type aliases, so
# declare them `const`: a non-const global can be rebound at any time, which
# prevents the compiler from relying on its value.
const Point2D = Point2{Float64}
const Point3D = Point3{Float64}

using DiagrammaticEquations
using DiagrammaticEquations.Deca

@info("Packages Loaded")

# Ice-thickness evolution for h (the `halfar_eq2` model):
#   ∂ₜh = ⋆(d(⋆(Γ · dh · avg₀₁(mag(♯dh)^(n-1)) · avg₀₁(h^(n+2)))))
# `♯` and `mag` are implemented in `generate` below; `avg₀₁` presumably averages
# 0-form (vertex) values onto edges so they can multiply the 1-form `dh`.
# TODO: use NaNMath for the powers in this model. NOTE(review): with a
# non-integer exponent, Julia's `^` throws a DomainError for a negative base
# (e.g. a negative h), which NaNMath would turn into NaN instead.
halfar_eq2 = @decapode begin
  h::Form0
  Γ::Form1
  n::Constant

  ḣ == ∂ₜ(h)
  ḣ == ∘(⋆, d, ⋆)(Γ * d(h) * avg₀₁(mag(♯(d(h)))^(n-1)) * avg₀₁(h^(n+2)))
end

# Glen's flow law coefficient: Γ = 2/(n+2) · A · (ρg)ⁿ, where A is the flow-rate
# parameter, ρ the ice density and g the gravitational acceleration (their
# values are set further down: `A`, `ρ`, `g`, `n`).
glens_law = @decapode begin
  Γ::Form1
  (A,ρ,g,n)::Constant

  Γ == (2/(n+2))*A*(ρ*g)^n
end

@info("Decapodes Defined")

# Wiring diagram: the `dynamics` box (halfar_eq2) and the `stress` box
# (glens_law) share the variables Γ and n.
ice_dynamics_composition_diagram = @relation () begin
  dynamics(Γ,n)
  stress(Γ,n)
end

# Glue the two models along the shared Γ and n. The composed model's variables
# are prefixed by box name (e.g. `dynamics_h`, `stress_ρ`, as used below).
ice_dynamics_cospan = oapply(ice_dynamics_composition_diagram,
  [Open(halfar_eq2, [:Γ,:n]),
  Open(glens_law, [:Γ,:n])])
ice_dynamics = apex(ice_dynamics_cospan)
# Specialise to 1D: expand composite operators, then infer forms and resolve
# overloads in place. resolve_overloads! presumably relies on the types inferred
# by infer_types!, so keep this order.
ice_dynamics1D = expand_operators(ice_dynamics)
infer_types!(ice_dynamics1D, op1_inf_rules_1D, op2_inf_rules_1D)
resolve_overloads!(ice_dynamics1D, op1_res_rules_1D, op2_res_rules_1D)

# Primal 1-D mesh: 100 evenly spaced vertices on [-2, 2] (placed in the plane at
# y = 0), with one edge between each pair of consecutive vertices.
s_prime = EmbeddedDeltaSet1D{Bool, Point2D}()
add_vertices!(s_prime, 100; point = Point2D.(range(-2, 2; length = 100), 0))
add_edges!(s_prime, 1:(nv(s_prime) - 1), 2:nv(s_prime))
orient!(s_prime)

# Dual complex of that mesh, subdivided at circumcenters.
s = EmbeddedDeltaDualComplex1D{Bool, Float64, Point2D}(s_prime)
subdivide_duals!(s, Circumcenter())

@info("Spaces Defined")

# Build the "sharp" operator (co-vector field -> vector field) for mesh `sd`.
# The mesh is fixed, so all geometry (edge vectors, vertex/edge incidence and the
# per-vertex normalisers) is computed once here. The returned closure does only
# the per-call arithmetic on the 1-form `x`, and no mesh queries run inside the
# simulated (and differentiated) right-hand side.
# This could be up-streamed to CombinatorialSpaces.jl; it is not specific to this
# simulation.
function _sharp_operator(sd)
  # Edge vector of each edge: its ∂v0 point minus its ∂v1 point.
  e_vecs = map(edges(sd)) do e
    point(sd, sd[e, :∂v0]) - point(sd, sd[e, :∂v1])
  end
  # Edges incident to each vertex (at either endpoint).
  neighbors = map(vertices(sd)) do v
    union(incident(sd, v, :∂v0), incident(sd, v, :∂v1))
  end
  n_vecs = map(es -> [e_vecs[e] for e in es], neighbors)
  # Total length of the edges around each vertex; used as the normaliser.
  n_lens = map(nvs -> sum(norm.(nvs)), n_vecs)
  # At each vertex: Σ (edge vector · edge length · x[edge]) / Σ edge length.
  return x -> map(neighbors, n_vecs, n_lens) do es, nvs, total_len
    sum([nv * norm(nv) * x[e] for (e, nv) in zip(es, nvs)]) / total_len
  end
end

"""
    generate(sd, my_symbol; hodge=GeometricHodge())

Return the implementation of the operator named `my_symbol` on the dual complex
`sd`. Supported operators:

- `:♯`   — "sharp": maps a 1-form (one value per edge) to a vector at each vertex.
  Each incident edge vector is scaled by its length and by the 1-form's value on
  that edge; the results are summed and divided by the total incident edge length.
- `:mag` — elementwise Euclidean norm.

Any other symbol raises an error. The `hodge` keyword is accepted but not used.
"""
function generate(sd, my_symbol; hodge=GeometricHodge())
  op = @match my_symbol begin
    :♯ => _sharp_operator(sd)
    :mag => x -> norm.(x)
    _ => error("Unmatched operator $my_symbol")
  end
  return (args...) -> op(args...)
end

# Generate simulation code for the 1D model and write it to ice_sheet1D.jl as
# `decapode_f = <code>`. Including that file then defines `decapode_f`.
decapode_code = gensim(ice_dynamics1D, dimension=1, preallocate = false)
# The do-block form of `open` closes the file even if `write` throws.
open("ice_sheet1D.jl", "w") do file
  write(file, string("decapode_f = ", decapode_code))
end
include("ice_sheet1D.jl")

# Instantiate the simulation on mesh `s`, with operators supplied by `generate`.
fₘ = decapode_f(s, generate)

"""
    f(constants_and_parameters)

Solve the model `fₘ` from `u₀` over `(0, tₑ)` with `Tsit5`, using
`constants_and_parameters` as the parameters. Return the sum of the final state.
"""
function f(constants_and_parameters)
  tspan = (0, tₑ)
  problem = ODEProblem{true, SciMLBase.FullSpecialize}(fₘ, u₀, tspan, constants_and_parameters)
  @info("Solving")
  solution = solve(problem, Tsit5())
  @info("Done")
  # Take the last saved state rather than interpolating (`solution(tₑ)`), which
  # avoids interpolation failures when AD fails.
  final_state = last(solution)
  return sum(final_state)
end

# Initial ice thickness: a Gaussian bump sampled at the primal vertices. (An
# earlier profile with sharp corners led to out-of-domain errors for n ≠ 3.)
h₀ = map(p -> exp(-2 * p[1]^2), point(s_prime))

# NOTE: this flow rate is far larger than a physical value; per the discussion in
# this issue, it was raised so that the gradient becomes non-zero.
flow_rate = 1e-3
ice_density = 910.0
u_init_arr = h₀
n = 3.0
ρ = ice_density
g = 9.8101
A = fill(flow_rate, ne(s))   # one flow-rate value per edge of the mesh
tₑ = 8e3                      # end of the simulated time span (0, tₑ)

u₀ = ComponentArray(dynamics_h = u_init_arr)

# The parameters must be a ComponentArray so that they can be differentiated.
constants_and_parameters = ComponentArray(; n, stress_ρ = ρ, stress_g = g, stress_A = A)

y = f(constants_and_parameters)

using Optimization, OptimizationPolyalgorithms, OptimizationBBO, OptimizationOptimJL
using SciMLSensitivity, Zygote, Enzyme, ReverseDiff, ForwardDiff
# Turn on Enzyme's runtime activity setting before any Enzyme-based
# differentiation runs.
Enzyme.API.runtimeActivity!(true)

# Reference ("observed") run: solve once with the nominal parameters.
data_prob = ODEProblem{true, SciMLBase.FullSpecialize}(fₘ, u₀, (0, tₑ), constants_and_parameters)
decapode_sol  = solve(data_prob, Tsit5())

# Final ice thickness of the reference run; `loss` compares against this.
reference_dat = last(decapode_sol).dynamics_h

"""
    loss(u)

Return the squared-error mismatch between `reference_dat` and the final ice
thickness (`dynamics_h` at the last saved time). The model is solved with the
ice density `stress_ρ` set to `u[1]`; every other parameter keeps its global
value. Only the last time step is compared.
"""
function loss(u)
    params = ComponentArray(n = n, stress_ρ = u[1], stress_g = g, stress_A = A)
    adjoint_method = InterpolatingAdjoint(autodiff = true, autojacvec = true)
    solution = solve(remake(data_prob, p = params), FBDF(), sensealg = adjoint_method)
    final_h = last(solution).dynamics_h
    return sum(abs2, reference_dat .- final_h)
end

# Gradient of the loss with respect to the ice density, at ρ = 700.
Zygote.gradient(loss, [700.0])
jClugstor commented 1 month ago

If I try to use `EnzymeVJP`, I get the following error. This is only a portion of the stack trace, since the full trace is huge:

Illegal updateAnalysis prev:{[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer, [-1,2]:Integer, [-1,3]:Integer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer} new: {[-1]:Pointer, [-1,0]:Float@double, [-1,8]:Float@double}
val:   %134 = bitcast i8 addrspace(11)* %69 to i64 addrspace(11)*, !dbg !81 origin=  %134 = bitcast i8 addrspace(11)* %69 to i64 addrspace(11)*, !dbg !81
MethodInstance for (::var"#15#24"{EmbeddedDeltaDualComplex1D{Bool, Float64, GeometryBasics.Point{2, Float64}}})(::Int64)

Caused by:
Stacktrace:
 [1] #15
   @ ~/Documents/Work/dev/DecapodeCalibrateDemos/GlacialFlow/glacialflow1D_calibrate_nonalloc.jl:79
jClugstor commented 1 month ago

@ChrisRackauckas