SciML / MethodOfLines.jl

Automatic Finite Difference PDE solving with Julia SciML
https://docs.sciml.ai/MethodOfLines/stable/
MIT License

Generating solution output for PDE takes longer than generating solution #374

Open · ctessum opened this issue 1 year ago

ctessum commented 1 year ago

Hello!

I'm trying to debug an issue where I have a PDE system for which solve takes a very long time at a low point count (equivalent to ~1000 equations) and runs out of memory, even on an HPC node with 40 GB of RAM, at a still-relatively-low point count (equivalent to ~25,000 equations).

Trying to boil it down as much as possible, it's similar to this advection system:

```julia
# Shared code

using ModelingToolkit, MethodOfLines, DifferentialEquations, DomainSets
using Plots, Distributions

x_min = t_min = 0.0
x_max, t_max = 1.0, 1.0

# Gaussian emission source over (x, t)
emis = MvNormal([x_max/10, t_min], [0.05, 1])

@parameters x y t
@variables s(..)
Dt = Differential(t)
Dx = Differential(x)
Dy = Differential(y)

domains = [x ∈ Interval(x_min, x_max), t ∈ Interval(t_min, t_max)]
bcs = [s(x, t_min) ~ 0.0, s(x_min, t) ~ s(x_max, t)]  # zero initial condition, periodic in x
discretization = MOLFiniteDifference([x => 20], t, approx_order=2, grid_align=center_align)

eq = [ # Emission and advection
    Dt(s(x, t)) ~ -Dx(s(x, t)) + pdf(emis, [t, x])
]
@named pdesys = PDESystem(eq, bcs, domains, [x, t], [s(x, t)])
@time prob = discretize(pdesys, discretization)
@profview sol = solve(prob, TRBDF2(), saveat=0.1)  # @profview from the VS Code Julia extension (or ProfileView.jl)
```

The code above generates the following profile, zoomed in to the actual solve call:

[Screenshot: profile flame graph of the solve call (2023-02-09)]

If I'm reading this correctly, the actual time stepping, represented by solve_up, takes a little more than a quarter of the time, and the rest is taken up by PDETimeSeriesSolution, which I understand to be everything that happens between finishing the time stepping and returning the solution.
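One rough way to separate the two timings (this is only a sketch; I'm assuming here that wrap = Val(false) skips the PDETimeSeriesSolution wrapping and returns the raw ODESolution):

```julia
# Time the raw ODE solve only, skipping the PDE solution reconstruction.
# (Assumes `wrap = Val(false)` returns the unwrapped ODESolution.)
@time odesol = solve(prob, TRBDF2(), saveat=0.1, wrap=Val(false))

# Time the full call, including the wrapping/interpolation back onto the PDE grid.
@time sol = solve(prob, TRBDF2(), saveat=0.1)
```

If most of the difference shows up in the second call, that would point to the reconstruction rather than the time stepping.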

I understand that a lot of this code is in MethodOfLines.jl instead of this package, but the observed function that's highlighted in the screenshot above is in this package, so I'm posting the issue here. In the larger system that this issue is trying to represent, observed takes up a much larger portion of the overall time (pretty much all of it). It seemed like what was happening was that a new observed function had to be generated and compiled for each variable at each grid point, and then type inference had to be run each time it was called, and that seemed to be taking up a lot of time.

So I guess my question is whether this seems like a plausible explanation for the problem, and, if so, whether there's a way to statically type the observed function and/or use a single function that gets called for every variable rather than generating a new function for each one.
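To illustrate what I mean by a single function, here's a rough sketch. I'm assuming the discretized ODESystem is reachable as prob.f.sys and that ModelingToolkit.build_explicit_observed_function accepts a vector of observed variables and returns a function of (u, p, t); I haven't verified either against this system:

```julia
using ModelingToolkit

# Assumption: the ODESystem built by discretize() is stored on the problem.
sys = prob.f.sys

# All observed variables of the discretized system.
obs_vars = [eq.lhs for eq in ModelingToolkit.observed(sys)]

# Build one function that evaluates every observed variable at once,
# instead of compiling a separate function per variable per grid point.
obs_f = ModelingToolkit.build_explicit_observed_function(sys, obs_vars)

# Evaluate at one saved state (odesol from the wrap = Val(false) call above).
vals = obs_f(odesol.u[end], prob.p, odesol.t[end])
```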

I'd also be happy to post the larger system where the problem is more obvious, but I haven't done so here because it takes several hours (or more) to run each time, which makes it difficult to iterate on.

Thanks!

ctessum commented 1 year ago

I guess I should link to the system that's causing the actual problem: https://data.earthsci.dev/dev/geosfp/#Using-data-from-GEOS-FP

This is also related to this PR: https://github.com/SciML/MethodOfLines.jl/pull/240

ChrisRackauckas commented 1 year ago

> I understand that a lot of this code is in MethodOfLines.jl instead of this package, but the observed function that's highlighted in the screenshot above is in this package, so I'm posting the issue here. In the larger system that this issue is trying to represent, observed takes up a much larger portion of the overall time (pretty much all of it). It seemed like what was happening was that a new observed function had to be generated and compiled for each variable at each grid point, and then type inference had to be run each time it was called, and that seemed to be taking up a lot of time.
>
> So I guess my question is whether this seems like a plausible explanation for the problem, and, if so, whether there's a way to statically type the observed function and/or use a single function that gets called for every variable rather than generating a new function for each one.

It's possible, and somewhat related to the issue of how lowering currently occurs via scalars and generates O(n^2) amounts of code on 2D PDEs. Function folding is required there; that's the reason for the large compile times, and it's the biggest piece of work going on in MTK right now. Once that is cleaned up, I think preserving structure in observation functions and Jacobians is next. But it's all one thread of preserving structure in the code generator.
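For readers following along, a rough illustration of what the scalarized lowering means (this is not MTK's actual generated code, just the shape of the problem): on an n×n grid, scalar lowering emits one expression per grid point, so the generated RHS has O(n^2) lines that each need to be compiled and inferred, whereas a structure-preserving form is a single loop over the grid:

```julia
# Illustrative only: a structure-preserving RHS for a 2D Laplacian on an n×n grid.
# Scalarized code generation would instead emit n^2 separate expressions like
#   du[i + (j-1)*n] = (u[...] + u[...] + u[...] + u[...] - 4u[...]) / dx^2
# one per interior grid point, each compiled and type-inferred on its own.
function laplacian_rhs!(du, u, p, t)
    dx, n = p
    U  = reshape(u, n, n)
    dU = reshape(du, n, n)
    fill!(dU, 0.0)
    for j in 2:n-1, i in 2:n-1
        dU[i, j] = (U[i-1, j] + U[i+1, j] + U[i, j-1] + U[i, j+1] - 4U[i, j]) / dx^2
    end
    return nothing
end
```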