SciML / DiffEqFlux.jl

Pre-built implicit layer architectures with O(1) backprop, GPUs, and stiff+non-stiff DE solvers, demonstrating scientific machine learning (SciML) and physics-informed machine learning methods
https://docs.sciml.ai/DiffEqFlux/stable
MIT License

sciml_train: Strange fatal error with complicated ODE, depends on system dimension #706

Status: Open. lindnemi opened this issue 2 years ago.

lindnemi commented 2 years ago

I am combining several packages to use ANNs inside a system of coupled ordinary differential equations.

With fewer than 20 coupled ODEs everything runs smoothly, but with 20 or more I get a huge amount of strange REPL output (possibly from Enzyme?) before the Julia session is terminated.

The warning reads:

not handling more than 6 pointer lookups deep dt:{[0]:Integer, [1]:Integer, [2]:Integer, [3]:Integer, [4]:Integer, [5]:Integer, [6]:Integer, [7]:Integer, [8]:Pointer, [16]:Pointer, [24]:Pointer, [24,0]:Pointer, [24,0,-1]:Float@float, [24,8]:Integer, [24,9]: [...]

It goes on for many more lines before the session terminates. A partial dump is here: https://gist.github.com/lindnemi/de8f03571323c4f14ed94ab1685fea36

I also tried calling modelingtoolkitize on my ODE system, hoping to improve compatibility with the SciML packages, instead of optimizing the RHS generated by NetworkDynamics directly. Again, everything is fine for fewer than 20 oscillators, but for 20 or more I get the warning:

# ┌ Warning: Recursive type
# │   T = ODESystem
# └ @ Enzyme ~/.julia/packages/Enzyme/3dAID/src/typetree.jl:148

Do you have any idea what is going on? Since the problem depends on the interaction of several packages, I was not able to come up with an MWE. Here is the code that reproduces the problem:

## Fit coupling term of swing equation with an ANN

using DiffEqFlux
using NetworkDynamics
using Graphs
using OrdinaryDiffEq
using GalacticOptim
using Random

## Defining the graph

N = 20
k = 4
g = barabasi_albert(N, k)

### Defining the network dynamics

@inline function diffusion_vertex!(dv, v, edges, p, t)
    dv[1] = 0.0f0
    for e in edges
        dv[1] += e[1]
    end
    nothing
end

@inline function diffusion_edge!(e, v_s, v_d, p, t)
    e[1] = 1 / 3 * (v_s[1] - v_d[1])
    nothing
end

odevertex = ODEVertex(; f=diffusion_vertex!, dim=1)
staticedge = StaticEdge(; f=diffusion_edge!, dim=1, coupling=:antisymmetric)
diffusion_network! = network_dynamics(odevertex, staticedge, g)

## Simulation 

# generate random values for the vertex parameter ω_0
v_pars = randn(nv(g))
# set the coupling strength of all edges to 1/3
e_pars = 1 / 3 * ones(ne(g))
p = (v_pars, e_pars)

# random initial conditions
x0 = randn(Float32, nv(g))
dx = similar(x0)
datasize = 30 # Number of data points
tspan = (0.0f0, 5.0f0) # Time range
tsteps = range(tspan[1], tspan[2], length=datasize)

diff_prob = ODEProblem(diffusion_network!, x0, tspan, nothing)
diff_sol = solve(diff_prob, Tsit5(); reltol=1e-6, saveat=tsteps)
diff_data = Array(diff_sol)

## Learning the coupling function

const ann_diff = FastChain(FastDense(2, 20, tanh),
    FastDense(20, 1))

@inline function ann_edge!(e, v_s, v_d, p, t)
    e[1] = ann_diff([v_s[1], v_d[1]], p)[1]
    nothing
end

annedge = StaticEdge(; f=ann_edge!, dim=1, coupling=:antisymmetric)
ann_network = network_dynamics(odevertex, annedge, g)

prob_neuralode = ODEProblem(ann_network, x0, tspan, initial_params(ann_diff))

# ## Using MTK to help Enzyme
# using ModelingToolkit
# sys = modelingtoolkitize(prob_neuralode)
# prob_neuralode = ODEProblem(sys, [], tspan)

function predict_neuralode(p)
    tmp_prob = remake(prob_neuralode, p=p)
    Array(solve(tmp_prob, Tsit5(), saveat=tsteps))
end

function loss_neuralode(p)
    pred = predict_neuralode(p)
    loss = sum(abs2, diff_data .- pred)
    return loss, pred
end

callback = function (p, l, pred)
    display(l)
    return false
end

callback(initial_params(ann_diff), loss_neuralode(initial_params(ann_diff))...)

result_neuralode = DiffEqFlux.sciml_train(loss_neuralode,
    prob_neuralode.p, cb=callback, maxiters=5)

# For N > 19 modelingtoolkitized system warns:
# ┌ Warning: Recursive type
# │   T = ODESystem
# └ @ Enzyme ~/.julia/packages/Enzyme/3dAID/src/typetree.jl:148
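The NetworkDynamics calls in the script hide the right-hand side that is being differentiated. As a debugging aid, here is a minimal plain-Julia sketch of what diffusion_network! and ann_network compute, which can be run, profiled or differentiated without NetworkDynamics. It rests on two assumptions, both labeled in the code: that coupling=:antisymmetric adds the edge value to the destination vertex and subtracts it from the source, and that FastDense stores each layer's weights (column-major) before its bias. The 4-vertex graph (EDGES, NV) and the names network_rhs!, diffusion_edge and mlp_edge are hypothetical; this is not NetworkDynamics' implementation.

```julia
using LinearAlgebra  # for dot and the matrix-vector product in the cross-check

# Hypothetical 4-vertex undirected graph standing in for barabasi_albert(N, k)
const EDGES = [(1, 2), (2, 3), (3, 4), (4, 1), (1, 3)]
const NV = 4

# ASSUMPTION: coupling=:antisymmetric adds the edge value e to the destination
# vertex and subtracts it from the source vertex.
function network_rhs!(dx, x, p, edge_fn)
    fill!(dx, zero(eltype(dx)))
    for (s, d) in EDGES
        e = edge_fn(x[s], x[d], p)
        dx[d] += e
        dx[s] -= e
    end
    return nothing
end

# Diffusion edge from the script: e = 1/3 * (v_s - v_d)
diffusion_edge(xs, xd, p) = (xs - xd) / 3

# Hypothetical stand-in for FastChain(FastDense(2, 20, tanh), FastDense(20, 1)).
# ASSUMPTION: each layer stores its out×in weights (column-major), then its bias.
const NHID = 20
const NPARAMS = 2NHID + NHID + NHID + 1  # 40 + 20 + 20 + 1 = 81

function mlp_edge(xs, xd, p)
    W1 = reshape(view(p, 1:2NHID), NHID, 2)  # first-layer weights, 20×2
    b1 = view(p, 2NHID+1:3NHID)              # first-layer bias
    w2 = view(p, 3NHID+1:4NHID)              # 1×20 output weights, kept as a vector
    b2 = p[4NHID+1]                          # output bias
    h = tanh.(W1 * [xs, xd] .+ b1)
    return dot(w2, h) + b2
end

# Cross-check: the diffusion edge should give dx = -(1/3) * L * x,
# with L the graph Laplacian of EDGES.
Lap = zeros(NV, NV)
for (s, d) in EDGES
    Lap[s, s] += 1; Lap[d, d] += 1
    Lap[s, d] -= 1; Lap[d, s] -= 1
end
x = [1.0, 0.0, 0.0, 0.0]
dx = similar(x)
network_rhs!(dx, x, nothing, diffusion_edge)
@assert dx ≈ -(Lap * x) / 3

# With the ANN edge, antisymmetric coupling still conserves sum(x)
x = randn(NV)
network_rhs!(dx, x, randn(NPARAMS), mlp_edge)
@assert abs(sum(dx)) < 1e-10
```

Because every edge value enters one vertex with a plus sign and the other with a minus sign, sum(x) is conserved for any edge function, including the ANN. That gives a cheap sanity check that such a reduced model matches the original network before comparing AD backends on it.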

ChrisRackauckas commented 2 years ago

@wsmoses do you know what this could be?

ChrisRackauckas commented 2 years ago

All of the features used here have been deprecated, so this issue is basically moot, but I'd like to track down what could be an Enzyme issue if we can.