I am using a combination of packages to use ANNs within coupled ordinary differential equations.
When coupling fewer than 20 ODEs everything runs smoothly, whereas for more than 20 ODEs I get a lot of strange REPL output (possibly from Enzyme?) before the Julia session is terminated.
The warning reads:
not handling more than 6 pointer lookups deep dt:{[0]:Integer, [1]:Integer, [2]:Integer, [3]:Integer, [4]:Integer, [5]:Integer, [6]:Integer, [7]:Integer, [8]:Pointer, [16]:Pointer, [24]:Pointer, [24,0]:Pointer, [24,0,-1]:Float@float, [24,8]:Integer, [24,9]: [...]
I tried to use modelingtoolkitize on my ODE system to potentially increase compatibility with the SciML packages, instead of optimizing the RHS from NetworkDynamics directly. Again, everything is fine for <20 oscillators, but for more oscillators I get the warning:
# ┌ Warning: Recursive type
# │ T = ODESystem
# └ @ Enzyme ~/.julia/packages/Enzyme/3dAID/src/typetree.jl:148
Do you have any idea what is going on? Since the problem depends on the interaction of various packages, I was not able to come up with an MWE. Here is the code that produces the problem:
## Fit coupling term of swing equation with an ANN
using DiffEqFlux
using NetworkDynamics
using Graphs
using OrdinaryDiffEq
using GalacticOptim
using Random
## Defining the graph
N = 20
k = 4
g = barabasi_albert(N, k)
### Defining the network dynamics
@inline function diffusion_vertex!(dv, v, edges, p, t)
    dv[1] = 0.0f0
    for e in edges
        dv[1] += e[1]
    end
    nothing
end
@inline function diffusion_edge!(e, v_s, v_d, p, t)
    e[1] = 1 / 3 * (v_s[1] - v_d[1])
    nothing
end
odevertex = ODEVertex(; f=diffusion_vertex!, dim=1)
staticedge = StaticEdge(; f=diffusion_edge!, dim=1, coupling=:antisymmetric)
diffusion_network! = network_dynamics(odevertex, staticedge, g)
## Simulation
# generating random values for the parameter value ω_0 of the vertices
v_pars = randn(nv(g))
# coupling strengths of the edges are set to 1/3
e_pars = 1 / 3 * ones(ne(g))
p = (v_pars, e_pars)
# random initial conditions
x0 = randn(Float32, nv(g))
dx = similar(x0)
datasize = 30 # Number of data points
tspan = (0.0f0, 5.0f0) # Time range
tsteps = range(tspan[1], tspan[2], length=datasize)
diff_prob = ODEProblem(diffusion_network!, x0, tspan, nothing)
diff_sol = solve(diff_prob, Tsit5(); reltol=1e-6, saveat=tsteps)
diff_data = Array(diff_sol)
## Learning the coupling function
const ann_diff = FastChain(FastDense(2, 20, tanh),
                           FastDense(20, 1))
@inline function ann_edge!(e, v_s, v_d, p, t)
    e[1] = ann_diff([v_s[1], v_d[1]], p)[1]
    nothing
end
annedge = StaticEdge(; f=ann_edge!, dim=1, coupling=:antisymmetric)
ann_network = network_dynamics(odevertex, annedge, g)
prob_neuralode = ODEProblem(ann_network, x0, tspan, initial_params(ann_diff))
# ## Using MTK to help Enzyme
# using ModelingToolkit
# sys = modelingtoolkitize(prob_neuralode)
# prob_neuralode = ODEProblem(sys, [], tspan)
function predict_neuralode(p)
    tmp_prob = remake(prob_neuralode, p=p)
    Array(solve(tmp_prob, Tsit5(), saveat=tsteps))
end
function loss_neuralode(p)
    pred = predict_neuralode(p)
    loss = sum(abs2, diff_data .- pred)
    return loss, pred
end
callback = function (p, l, pred)
    display(l)
    return false
end
callback(initial_params(ann_diff), loss_neuralode(initial_params(ann_diff))...)
result_neuralode = DiffEqFlux.sciml_train(loss_neuralode,
                                          prob_neuralode.p, cb=callback, maxiters=5)
# For N > 19 modelingtoolkitized system warns:
# ┌ Warning: Recursive type
# │ T = ODESystem
# └ @ Enzyme ~/.julia/packages/Enzyme/3dAID/src/typetree.jl:148
All of the features used here have been deprecated, so this issue is basically moot, but I'd like to track down what could be an Enzyme issue if we can.
The warning output goes on for many more lines before causing a termination. A partial dump is here: https://gist.github.com/lindnemi/de8f03571323c4f14ed94ab1685fea36
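Since no MWE was available, one possible way to narrow this down is to take NetworkDynamics out of the picture and write the right-hand side by hand. The following is only a sketch in plain Julia with no package dependencies. The names `N_HIDDEN`, `ann`, `rhs!`, `edge_list` and `coupling` are hypothetical and not part of any package. The parameter layout of the small network and the sign convention for the antisymmetric coupling (destination gets `+e`, source gets `-e`) are assumptions, not taken from the DiffEqFlux or NetworkDynamics sources. It does not reproduce the warning by itself; it only gives a smaller function that could be handed to the same differentiation tooling.

```julia
# Hypothetical package-free stand-in for the network RHS in the issue.
const N_HIDDEN = 20

# Small 2 -> N_HIDDEN -> 1 network with tanh hidden layer.
# Assumed flat layout of p: W1 (N_HIDDEN x 2, column-major), b1, W2, b2.
function ann(x1, x2, p)
    out = p[4 * N_HIDDEN + 1]                      # output bias b2
    for h in 1:N_HIDDEN
        z = p[h] * x1 + p[N_HIDDEN + h] * x2 + p[2 * N_HIDDEN + h]
        out += p[3 * N_HIDDEN + h] * tanh(z)       # W2[h] * hidden activation
    end
    return out
end

# One scalar state per vertex; edge_list is a vector of (source, destination) pairs.
# Assumed convention: the destination vertex receives +e, the source vertex -e.
function rhs!(dx, x, p, t, edge_list, coupling)
    fill!(dx, zero(eltype(dx)))
    for (s, d) in edge_list
        e = coupling(x[s], x[d], p)
        dx[d] += e
        dx[s] -= e
    end
    return nothing
end

# Example call on a 3-vertex ring with random parameters.
edge_list = [(1, 2), (2, 3), (3, 1)]
x = Float32[1, 0, -1]
dx = similar(x)
p = 0.1f0 .* randn(Float32, 4 * N_HIDDEN + 1)
rhs!(dx, x, p, 0.0f0, edge_list, ann)
```

Passing the coupling as an argument makes it easy to swap `ann` for the known diffusion term `(a, b, p) -> (a - b) / 3` and check the hand-written RHS against the NetworkDynamics one before involving any automatic differentiation.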