SciML / DiffEqCallbacks.jl

A library of useful callbacks for hybrid scientific machine learning (SciML) with augmented differential equation solvers
https://docs.sciml.ai/DiffEqCallbacks/stable/
85 stars, 44 forks

Memory allocations (Still) of `PresetTimeCallback` after v3.2.0 #217

Closed albertomercurio closed 2 months ago

albertomercurio commented 3 months ago

Describe the example

Hello, a couple of weeks ago I opened issue #211, showing an increase in memory allocations for versions after v3.2.0. I then noticed that the MWE was not very minimal: adding the callback also greatly increased the number of saved states. To rule this out, here is a very similar example in which I set `saveat = [tlist[end]]`, so the state is saved only once. Even so, I get 1.07k allocations, compared with 56 on versions < v3.2.0. This is a large difference, especially since it also scales with the length of `tlist`.

Minimal Reproducible Example 👇

using LinearAlgebra  # provides mul!
using OrdinaryDiffEq
using DiffEqCallbacks

# Define a simple ODE system
function simple_ode!(du, u, p, t)
    # In-place du = α[1] * A * u (5-arg mul!: du = A*u*α + du*β, with β = 0)
    mul!(du, p.A, u, p.α[1], 0)
    return nothing
end

# Initial condition
u0 = rand(ComplexF64, 10)
A = rand(ComplexF64, 10, 10)

# Time span for the solution
tlist = range(0, 10, 1000)
tspan = (tlist[1], tlist[end])

# α is a mutable vector, so the callback can change its entries in place
p = (A = A, α = rand(ComplexF64, 2))

# Define the ODE problem
prob = ODEProblem(simple_ode!, u0, tspan, p, saveat=[tlist[end]])

# Callback function to modify the parameter
function change_param!(integrator)
    integrator.p.α[1] = 0.5 # Change the parameter value
end

callback = PresetTimeCallback(tlist, change_param!, save_positions=(false,false))

# Solve the ODE with the callback
sol = solve(prob, Tsit5(), callback=callback);

@time solve(prob, Tsit5(), callback=callback);
 0.000657 seconds (1.07 k allocations: 68.859 KiB)

solve(prob, Tsit5());
@time solve(prob, Tsit5());
0.000063 seconds (48 allocations: 6.250 KiB)
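To check whether the extra allocations grow with the number of callback times, a sketch like the following repeats the MWE for several lengths of `tlist` and reports the allocation count and size of a single solve. It assumes Julia 1.9 or later (for the `@allocations` macro); `allocation_stats` is a hypothetical helper introduced here for illustration, and the numbers it prints will depend on the installed DiffEqCallbacks version, so no expected output is given.

```julia
using LinearAlgebra
using OrdinaryDiffEq
using DiffEqCallbacks

# Same right-hand side as the MWE: du = α[1] * A * u
function simple_ode!(du, u, p, t)
    mul!(du, p.A, u, p.α[1], 0)
    return nothing
end

# Same affect as the MWE: overwrite one parameter entry in place
change_param!(integrator) = (integrator.p.α[1] = 0.5; nothing)

# Hypothetical helper: allocation count and bytes for one solve with `n` callback times
function allocation_stats(n::Int)
    tlist = range(0, 10, n)
    p = (A = rand(ComplexF64, 10, 10), α = rand(ComplexF64, 2))
    prob = ODEProblem(simple_ode!, rand(ComplexF64, 10), (tlist[1], tlist[end]), p;
                      saveat = [tlist[end]])
    cb = PresetTimeCallback(tlist, change_param!; save_positions = (false, false))
    solve(prob, Tsit5(); callback = cb)  # warm-up run so compilation is not counted
    nallocs = @allocations solve(prob, Tsit5(); callback = cb)  # requires Julia >= 1.9
    nbytes = @allocated solve(prob, Tsit5(); callback = cb)
    return (; nallocs, nbytes)
end

for n in (10, 100, 1000)
    stats = allocation_stats(n)
    println("length(tlist) = $n: $(stats.nallocs) allocations, $(stats.nbytes) bytes")
end
```

If the count grows roughly linearly with `n` while the no-callback solve stays flat, that points at per-event work inside the callback machinery rather than at saving.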
ChrisRackauckas commented 3 months ago

Can you share an allocation profile?
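For reference, `Profile.print()` reports the CPU sampling profiler (`@profile`); allocations are recorded separately by Julia's built-in `Profile.Allocs` module (Julia 1.8 and later). The following is a minimal sketch, assuming the same setup as the MWE, of recording every allocation made during one solve; tallying the results by allocated type is just one illustrative way to summarize what `Profile.Allocs.fetch()` returns.

```julia
using LinearAlgebra
using OrdinaryDiffEq
using DiffEqCallbacks
using Profile

function simple_ode!(du, u, p, t)
    mul!(du, p.A, u, p.α[1], 0)  # du = α[1] * A * u
    return nothing
end
change_param!(integrator) = (integrator.p.α[1] = 0.5; nothing)

tlist = range(0, 10, 1000)
p = (A = rand(ComplexF64, 10, 10), α = rand(ComplexF64, 2))
prob = ODEProblem(simple_ode!, rand(ComplexF64, 10), (tlist[1], tlist[end]), p;
                  saveat = [tlist[end]])
callback = PresetTimeCallback(tlist, change_param!; save_positions = (false, false))

solve(prob, Tsit5(); callback = callback)  # compile first so compilation is not profiled

# Record every allocation (sample_rate=1) made during one solve
Profile.Allocs.clear()
Profile.Allocs.@profile sample_rate=1 solve(prob, Tsit5(); callback = callback)
results = Profile.Allocs.fetch()

# Tally the recorded allocations by allocated type, most frequent first
counts = Dict{Any,Int}()
for a in results.allocs
    counts[a.type] = get(counts, a.type, 0) + 1
end
for (T, n) in sort(collect(counts); by = last, rev = true)
    println(rpad(string(n), 8), T)
end
```

Each entry in `results.allocs` also carries a `stacktrace`, which can be inspected to find where the dominant types are allocated.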

albertomercurio commented 3 months ago
julia> Profile.print()
Overhead ╎ [+additional indent] Count File:Line; Function
=========================================================
  ╎1  @Base/Base.jl:608; (::Base.var"#1055#1056")()
  ╎ 1  @Base/Base.jl:572; profile_printing_listener()
  ╎  1  @Base/asyncevent.jl:159; wait
  ╎   1  @Base/asyncevent.jl:142; _trywait(t::Base.AsyncCondition)
  ╎    1  @Base/condition.jl:125; wait
  ╎     1  @Base/condition.jl:130; wait(c::Base.GenericCondition{Base.Threads.SpinLock}; first::Bool)
  ╎    ╎ 1  @Base/task.jl:994; wait()
 1╎    ╎  1  @Base/task.jl:985; poptask(W::Base.IntrusiveLinkedListSynchronized{Task})
  ╎29 @Base/task.jl:675; task_done_hook(t::Task)
  ╎ 29 @Base/task.jl:994; wait()
28╎  29 @Base/task.jl:985; poptask(W::Base.IntrusiveLinkedListSynchronized{Task})
  ╎1  …lib/v1.10/Distributed/src/remotecall.jl:279; (::Distributed.var"#137#139")()
  ╎ 1  @Base/condition.jl:78; lock
  ╎  1  @Base/lock.jl:229; lock(f::Distributed.var"#138#140", l::ReentrantLock)
  ╎   1  …ib/v1.10/Distributed/src/remotecall.jl:281; #138
  ╎    1  @Base/condition.jl:125; wait
  ╎     1  @Base/condition.jl:130; wait(c::Base.GenericCondition{ReentrantLock}; first::Bool)
  ╎    ╎ 1  @Base/task.jl:994; wait()
 1╎    ╎  1  @Base/task.jl:985; poptask(W::Base.IntrusiveLinkedListSynchronized{Task})
  ╎1  @VSCodeServer/src/eval.jl:34; (::VSCodeServer.var"#64#65")()
  ╎ 1  @Base/essentials.jl:889; invokelatest(::Any)
  ╎  1  @Base/essentials.jl:892; #invokelatest#2
  ╎   1  @VSCodeServer/src/repl.jl:193; (::VSCodeServer.var"#111#113"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Pro…
  ╎    1  @Base/logging.jl:627; with_logger
  ╎     1  @Base/logging.jl:515; with_logstate(f::Function, logstate::Any)
  ╎    ╎ 1  @VSCodeServer/src/repl.jl:192; (::VSCodeServer.var"#112#114"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.…
  ╎    ╎  1  @VSCodeServer/src/repl.jl:229; repleval(m::Module, code::Expr, ::String)
  ╎    ╎   1  @Base/Base.jl:88; eval
 1╎    ╎    1  @Base/boot.jl:385; eval
Total snapshots: 32. Utilization: 3% across all threads and tasks. Use the `groupby` kwarg to break down by thread and/or task.
ChrisRackauckas commented 3 months ago

@oscardssmith can you take another look at this?

oscardssmith commented 2 months ago

should be fixed by https://github.com/SciML/DiffEqCallbacks.jl/pull/218