SciML / SciMLBase.jl

The Base interface of the SciML ecosystem
https://docs.sciml.ai/SciMLBase/stable
MIT License

progress bars for EnsembleProblem #514

Closed: pepijndevos closed this 9 months ago

pepijndevos commented 10 months ago

This is the first part of a series of PRs that add progress support to EnsembleProblem.

When progress is enabled, this part initializes a progress bar for every trajectory and gives every solve a unique ID.

For very large ensembles, the expectation is that the logger displays this sensibly, for example by showing the overall progress:

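using Logging, LoggingExtras, TerminalLoggers

# Collapse all per-trajectory progress records into a single "Total" bar.
# Note: the `progress` Dict is not protected by a lock.
sum_logger = let progress = Dict{Symbol, Float64}()
    TransformerLogger(TerminalLogger()) do log
        # Progress records are logged at LogLevel(-1) with a `progress` kwarg.
        if log.level == LogLevel(-1) && haskey(log.kwargs, :progress)
            pr = log.kwargs[:progress]
            if pr isa Number
                progress[log.id] = pr          # latest fraction for this trajectory
            elseif pr == "done"
                progress[log.id] = 1.0
            end
            # Overall progress is the mean over all trajectories seen so far.
            tot = sum(values(progress))/length(progress)
            if tot>=1.0
                tot="done"
                empty!(progress)               # reset for the next ensemble
            end
            log = merge(log, (;id=:total, message="Total", kwargs=Dict(:progress=>tot)))
        else
            log                                # non-progress records pass through
        end
    end
end
global_logger(sum_logger)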
devmotion commented 10 months ago

For very large ensembles the expectation is that the logger shows this in a sensible way, for example, show the overall progress:

Last time I checked, none of the common loggers (the default logger in VSCode and TerminalLoggers) summarized large numbers of progress bars. Every bar was displayed, which led to a terrible user experience IMO.

pepijndevos commented 10 months ago

Yeah, that's a fair criticism. I did it like this for a few reasons:

Maybe it'd be a good idea to contribute this log aggregation somewhere. We could even make it the default, so that in __solve we just do with_logger(sum_logger(current_logger())).
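Calling sum_logger(current_logger()) implies turning the aggregation into a function that wraps an existing logger. A minimal sketch of that idea, using only the Logging stdlib instead of LoggingExtras, could look like this. AggregatingLogger is a hypothetical name (not a SciMLBase or LoggingExtras type), the averaging mirrors the snippet earlier in the thread, and the lock is an assumption added because records can arrive from several threads.

```julia
using Logging, Test

# Hypothetical wrapper logger: averages the `progress` kwarg of LogLevel(-1)
# records per id and forwards a single "Total" record to the wrapped logger.
# All other records pass through unchanged.
struct AggregatingLogger{L<:AbstractLogger} <: AbstractLogger
    inner::L
    progress::Dict{Any,Float64}
    lock::ReentrantLock
end
AggregatingLogger(inner::AbstractLogger) =
    AggregatingLogger(inner, Dict{Any,Float64}(), ReentrantLock())

# Delegate filtering decisions to the wrapped logger.
Logging.min_enabled_level(l::AggregatingLogger) = Logging.min_enabled_level(l.inner)
Logging.shouldlog(l::AggregatingLogger, args...) = Logging.shouldlog(l.inner, args...)
Logging.catch_exceptions(l::AggregatingLogger) = Logging.catch_exceptions(l.inner)

function Logging.handle_message(l::AggregatingLogger, level, message, _module,
                                group, id, file, line; kwargs...)
    pr = get(kwargs, :progress, nothing)
    if level != LogLevel(-1) || !(pr isa Number || pr == "done")
        # Not a progress record: forward it as-is.
        return Logging.handle_message(l.inner, level, message, _module, group,
                                      id, file, line; kwargs...)
    end
    tot = lock(l.lock) do
        l.progress[id] = pr == "done" ? 1.0 : Float64(pr)
        t = sum(values(l.progress)) / length(l.progress)
        t >= 1.0 && empty!(l.progress)   # every trajectory finished: reset
        t
    end
    Logging.handle_message(l.inner, level, "Total", _module, group, :total,
                           file, line; progress = tot >= 1.0 ? "done" : tot)
end

# Example: capture the forwarded records with the Test stdlib's TestLogger.
inner = Test.TestLogger(min_level = Logging.Debug)
with_logger(AggregatingLogger(inner)) do
    @logmsg LogLevel(-1) "traj" _id = :a progress = 0.5
    @logmsg LogLevel(-1) "traj" _id = :b progress = 0.0
    @logmsg LogLevel(-1) "traj" _id = :b progress = "done"
    @logmsg LogLevel(-1) "traj" _id = :a progress = "done"
end
```

In a solver, the same wrapper would be installed with with_logger(AggregatingLogger(current_logger())) do ... end around the ensemble loop.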

devmotion commented 10 months ago

I'd like to have better progress bars for ensemble simulations (not only for SciML but also, e.g., in Turing, where we currently use a single progress bar for multi-chain sampling that is updated only when a full chain has been sampled, which also doesn't lead to a good user experience). But I think it's crucial to first add better support for such summaries, ideally not only in SciML but more generally in the most common loggers such as the one in VSCode or TerminalLoggers (see e.g. https://github.com/julia-vscode/julia-vscode/issues/3297, and possibly also https://github.com/julia-vscode/julia-vscode/issues/3317 and https://github.com/JuliaLogging/ProgressLogging.jl/issues/23).

oscardssmith commented 10 months ago

A possible alternative approach here would be to base the EnsembleProblem's progress bar purely on the number of solutions in the ensemble that have finished.
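The completion-count approach needs only a shared atomic counter. A minimal sketch, assuming a hypothetical helper run_with_completion_progress (not a SciMLBase API) and the LogLevel(-1) progress convention used elsewhere in this thread:

```julia
using Logging

# Hypothetical helper: run `n` jobs across threads and report ensemble
# progress only as the fraction of jobs that have finished.
function run_with_completion_progress(job, n::Integer; id::Symbol = :ensemble)
    finished = Threads.Atomic{Int}(0)
    results = Vector{Any}(undef, n)
    Threads.@threads for i in 1:n
        results[i] = job(i)
        # atomic_add! returns the value *before* the increment.
        done = Threads.atomic_add!(finished, 1) + 1
        # Records from different threads may reach the logger out of order.
        @logmsg LogLevel(-1) "Ensemble" _id = id progress = (done == n ? "done" : done / n)
    end
    return results
end

results = run_with_completion_progress(i -> i^2, 8)
```

The granularity is 1/n per update, so with only a few long trajectories the bar sits near 0 and then jumps to "done", which is exactly the drawback raised in the reply below.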

devmotion commented 10 months ago

Yeah that's what we do in Turing currently. But the user experience is a bit suboptimal, in particular if you only sample a few chains in parallel - then it can happen that the progress bar is at 0 all the time and jumps to 1 when all chains are sampled at approx the same time.

codecov[bot] commented 10 months ago

Codecov Report

Merging #514 (a1370a0) into master (06d5c2c) will decrease coverage by 0.31%. The diff coverage is 34.24%.

@@            Coverage Diff             @@
##           master     #514      +/-   ##
==========================================
- Coverage   42.19%   41.88%   -0.31%     
==========================================
  Files          53       53              
  Lines        4072     4121      +49     
==========================================
+ Hits         1718     1726       +8     
- Misses       2354     2395      +41     
File                                    Coverage Δ
src/ensemble/basic_ensemble_solve.jl    51.87% <34.24%> (-15.70%) ⬇


pepijndevos commented 10 months ago

I proposed for Cedar that we track progress by completed simulations, but indeed the granularity is quite low if you have a lot of cores or not that many problems. So IMO the correct solution is for the loggers to handle this better, or to aggregate the results internally as I suggested.

pepijndevos commented 10 months ago

I've added an internal aggregation step that is enabled by default. This makes sure the current UX isn't horrible, and then we can take it from there.


pepijndevos commented 10 months ago

It turns out that if you have many cores and you set progress_steps too low it can cause quite a lot of lock contention, so I made the aggregator use trylock to mitigate that.
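The trylock idea can be sketched as follows: intermediate updates are best-effort and are dropped when the lock is contended, while final updates wait so that completion is never lost. ProgressAggregator and update! are hypothetical names, not the PR's actual implementation; the try/finally ensures the lock is always released, which is the pattern discussed later in the thread.

```julia
# Hypothetical aggregator state (illustrative, not SciMLBase's implementation).
struct ProgressAggregator
    lock::ReentrantLock
    progress::Dict{Symbol,Float64}
end
ProgressAggregator() = ProgressAggregator(ReentrantLock(), Dict{Symbol,Float64}())

# Record progress `p` for trajectory `id` and return the mean over all
# trajectories. Intermediate updates use `trylock` and are dropped (returning
# `nothing`) if another task holds the lock; `force = true` waits for it.
function update!(agg::ProgressAggregator, id::Symbol, p::Float64; force::Bool = false)
    if force
        lock(agg.lock)
    elseif !trylock(agg.lock)
        return nothing          # contended: skip this update instead of blocking
    end
    try
        agg.progress[id] = p
        return sum(values(agg.progress)) / length(agg.progress)
    finally
        unlock(agg.lock)        # always release, even if the body throws
    end
end

agg = ProgressAggregator()
update!(agg, :traj_1, 0.5)                 # lock is free, so this is recorded
update!(agg, :traj_2, 1.0; force = true)   # final update: always recorded
```

Dropping an intermediate update is harmless because a later update for the same trajectory overwrites it anyway.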

pepijndevos commented 10 months ago

FTR we're still investigating how to reduce overhead from progress logging, so some changes are still expected.

pepijndevos commented 10 months ago

Alright, performance overhead is now negligible using several strategies:

Profile: (screenshot not included)

I've also added the requested weakdep on https://github.com/SciML/DiffEqBase.jl/releases/tag/v6.132.0

ChrisRackauckas commented 10 months ago

As long as https://github.com/SciML/DiffEqBase.jl/blob/master/test/downstream/inference.jl passes this should generally be good.

pepijndevos commented 10 months ago

That particular test seems to pass. It's hard for me to tell which of the other failures across the 70 test runs are my fault.

ChrisRackauckas commented 10 months ago

I rebased. We'll see how tests go.

ChrisRackauckas commented 10 months ago

It looks like a backwards compat dispatch is missing: https://github.com/SciML/SciMLBase.jl/actions/runs/6479781336/job/17731842462?pr=512

pepijndevos commented 10 months ago

Wait, I'm confused about what's going on here. I think my confusion cancelled out and maybe I did the right thing? Your comment seems to be about the stats PR, so I pushed an update there.

https://github.com/SciML/SciMLBase.jl/pull/512/commits/45efdc5a9325b02518aaba6a0542246e052a1cbf

Is there anything that needs doing for this PR?

staticfloat commented 9 months ago

@ChrisRackauckas What else needs to happen for this to get in?

ChrisRackauckas commented 9 months ago

The try/catch handling here is incompatible with Zygote https://github.com/SciML/SciMLBase.jl/actions/runs/6532475287/job/17735718809?pr=514#step:6:1447

pepijndevos commented 9 months ago

So what can we do? You need a lock for thread-safe coordination, and you need a finally block to release the lock. Do we need to register handle_message as a Zygote primitive or something?

ChrisRackauckas commented 9 months ago

Even easier: just wrap the logging calls in ignore_derivatives so that AD ignores everything in the function.

Have you done regression testing to check the performance?

ChrisRackauckas commented 9 months ago

https://github.com/SciML/SciMLBase.jl/actions/runs/6532475287/job/17735719200?pr=514 is pointing out that the extra kwargs are causing issues. In order for this to get in, every downstream solver package needs to support the new keyword arguments. This includes:

It would be fine if, when the value is not the default, the solver just warns that progress isn't implemented (behind an if verbose), as long as the kwarg is accepted.
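A minimal sketch of that suggestion for a solver package without progress support: accept the keywords, and warn only when a non-default value asks for progress and verbose is on. mysolve is a hypothetical entry point, and the progress, progress_id, and verbose defaults shown here are assumptions for illustration.

```julia
# Hypothetical solver entry point for a package that has no progress support.
# It accepts the new keywords so callers can pass them unconditionally.
function mysolve(prob; progress::Bool = false, progress_id::Symbol = :mysolve,
                 verbose::Bool = true, kwargs...)
    if progress && verbose
        @warn "Progress logging is not implemented for this solver; ignoring `progress = true`." progress_id
    end
    return prob   # stand-in for the real solve
end

mysolve(1.0; progress = true, progress_id = :traj_1)   # warns, still solves
mysolve(1.0; progress = true, verbose = false)         # silent
```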

pepijndevos commented 9 months ago

I've never worked with Zygote. Are there some docs about how to do the derivative, or is it a one-line thing that would take you less time than finding the docs?

We've tested performance in Cedar, and the overhead from logging is minor. I go to great lengths to avoid blocking on a lock or sending more updates than necessary. Julia's logging macros also don't even compute the message if the log level is disabled, so there should be no impact unless you enable progress logging.
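The claim that a disabled record costs almost nothing can be checked with the Logging stdlib alone. In this sketch, expensive_message is a hypothetical stand-in for building a progress message; ConsoleLogger is used only for illustration, and LogLevel(-1) matches the level of the progress records shown in this thread.

```julia
using Logging

# Count how often the (pretend-expensive) message expression is evaluated.
const message_evaluations = Ref(0)
expensive_message() = (message_evaluations[] += 1; "dt, t and max u summary")

# LogLevel(-1) is below this logger's Info threshold, so the macro skips the
# record before evaluating the message or the kwargs.
with_logger(ConsoleLogger(stderr, Logging.Info)) do
    @logmsg LogLevel(-1) expensive_message() progress = 0.5
end
println("evaluations with progress logging disabled: ", message_evaluations[])

# Lowering the threshold to LogLevel(-1) enables the record; only now is the
# message computed (and printed to stderr).
with_logger(ConsoleLogger(stderr, LogLevel(-1))) do
    @logmsg LogLevel(-1) expensive_message() progress = 0.5
end
println("evaluations with progress logging enabled: ", message_evaluations[])
```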

ChrisRackauckas commented 9 months ago

Can you show it on something simple like the Lorenz equation with SimpleTsit5?

pepijndevos commented 9 months ago

I'm looking at the docs and all I can find is how to define an rrule? https://fluxml.ai/Zygote.jl/latest/limitations/#Solutions-1 Feels like this should go in ChainRules then? Or do you mean something else by

just set ignore derivatives needs to be set on the logging calls so it ignores everything in the function.

Ah! https://juliadiff.org/ChainRulesCore.jl/stable/api.html#ChainRulesCore.@ignore_derivatives

pepijndevos commented 9 months ago

Well that didn't work. What gives? https://github.com/SciML/SciMLBase.jl/actions/runs/6690536216/job/18176044626?pr=514#step:6:1470

pepijndevos commented 9 months ago

Keeping track of all the PRs:

Well, that's not what I was expecting. @ChrisRackauckas, what did you have in mind here? I think a simple solution would be to only pass down progress_id if progress=true, so problems without progress support don't have to care about the kwargs.
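The idea of forwarding the progress keywords only when progress=true can be sketched with a conditional NamedTuple splat. Both inner_solve and ensemble_solve are hypothetical stand-ins, and the per-trajectory ID naming is an assumption for illustration.

```julia
# Hypothetical downstream solver that knows nothing about progress keywords:
# passing `progress` or `progress_id` to it throws a MethodError.
inner_solve(prob; saveat = nothing) = prob

# Hypothetical ensemble driver: progress keywords are forwarded only when
# progress was requested, so solvers without progress support are unaffected.
function ensemble_solve(prob, n::Integer; progress::Bool = false, kwargs...)
    map(1:n) do i
        extra = if progress
            (progress = true, progress_id = Symbol("trajectory_", i))
        else
            NamedTuple()
        end
        inner_solve(prob; kwargs..., extra...)
    end
end

ensemble_solve(1.0, 3)                 # works: no progress kwargs passed down
ensemble_solve(1.0, 3; saveat = 0.1)   # other kwargs still pass through
```

With progress=true, the keywords do reach inner_solve and it fails, so solvers that should show progress still need to accept them.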

pepijndevos commented 9 months ago

FTR I'm getting weird precompile errors about the adjoint so maybe that's still not quite right?

┌ DiffEqBase → DiffEqBaseZygoteExt
│  WARNING: Method definition adjoint(ZygoteRules.AContext, typeof(ZygoteRules.literal_getproperty), SciMLBase.EnsembleSolution{T, N, S} where S where N where T, Base.Val{:u}) in module DiffEqBase overwritten in module SciMLBaseZygoteExt.
│  ┌ Error: Error during loading of extension SciMLBaseZygoteExt of SciMLBase, use `Base.retry_load_extensions()` to retry.
│  │   exception =
│  │    1-element ExceptionStack:
│  │    Method overwriting is not permitted during Module precompile.
pepijndevos commented 9 months ago

Am I going crazy? https://github.com/SciML/SciMLBase.jl/actions/runs/6769050135/job/18394689941#step:6:579 https://github.com/pepijndevos/SciMLBase.jl/blob/pv/progress/src/SciMLBase.jl#L25

pepijndevos commented 9 months ago

I went back to the alternative logging workaround and did some more import fixes that I guess were hidden by my REPL state. This passes the downstream AD tests for me, but I think it requires a new ChainRules release before it'll work on CI.

oxinabox commented 9 months ago

OK, a new ChainRules release is out, so I'm restarting CI and it should pass.

oscardssmith commented 9 months ago

This PR needs to bump the ChainRules version requirement.

pepijndevos commented 9 months ago

Last time Chris asked for a version requirement I added it and then he removed it again so idk

ChrisRackauckas commented 9 months ago

I don't think this PR ever had a version requirement on ChainRules/ChainRulesCore?

pepijndevos commented 9 months ago

No, I added one for DiffEqBase because I thought you wanted that to guarantee progress_id is supported, but then you removed it again. So I'm fine adding a version constraint if it's clear that that's actually what you want.

ChrisRackauckas commented 9 months ago

I don't see why that's related. @oscardssmith is saying you need to version bound on ChainRules/Core since you're implicitly relying on @oxinabox 's latest release for the Zygote fix on logging.

pepijndevos commented 9 months ago

Done. Just wanted to make sure.

I'm still getting that weird DiffEqBaseZygoteExt precompile error btw

pepijndevos commented 9 months ago

Here are your performance numbers btw

using OrdinaryDiffEq, BenchmarkTools

# Out-of-place Lorenz system: returns a new derivative vector on every call.
function lorenz(u, p, t)
    σ = p[1]
    ρ = p[2]
    β = p[3]
    du1 = σ * (u[2] - u[1])
    du2 = u[1] * (ρ - u[3]) - u[2]
    du3 = u[1] * u[2] - β * u[3]
    return [du1, du2, du3]
end

u0 = [1.0f0; 0.0f0; 0.0f0]
tspan = (0.0f0, 10.0f0)
p = [10.0f0, 28.0f0, 8 / 3.0f0]
prob = ODEProblem{false}(lorenz, u0, tspan, p)

# Each trajectory scales every parameter by an independent uniform random factor in [0, 1).
prob_func = (prob, i, repeat) -> remake(prob, p = rand(Float32, 3) .* p)
monteprob = EnsembleProblem(prob, prob_func = prob_func, safetycopy = false)

# The same 10k-trajectory ensemble, first with progress logging off, then on.
@benchmark sol = solve(monteprob, Tsit5(), EnsembleThreads(), trajectories = 10_000, saveat = 1.0f0, progress=false)
@benchmark sol = solve(monteprob, Tsit5(), EnsembleThreads(), trajectories = 10_000, saveat = 1.0f0, progress=true)
With progress=false:
BenchmarkTools.Trial: 12 samples with 1 evaluation.
 Range (min … max):  186.853 ms … 720.231 ms  ┊ GC (min … max):  0.00% … 73.95%
 Time  (median):     500.877 ms               ┊ GC (median):    62.15%
 Time  (mean ± σ):   446.853 ms ± 206.921 ms  ┊ GC (mean ± σ):  57.52% ± 33.01%

  ██                       ▁         █ ▁    ▁             ▁ ▁ ▁  
  ██▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁█▁█▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁█▁█▁█ ▁
  187 ms           Histogram: frequency by time          720 ms <

 Memory estimate: 3.40 GiB, allocs estimate: 45697823.

With progress=true:
BenchmarkTools.Trial: 10 samples with 1 evaluation.
 Range (min … max):  265.698 ms … 743.625 ms  ┊ GC (min … max):  0.00% … 64.56%
 Time  (median):     558.434 ms               ┊ GC (median):    52.23%
 Time  (mean ± σ):   516.208 ms ± 154.807 ms  ┊ GC (mean ± σ):  49.51% ± 23.63%

  █                    ▁         ▁    ▁▁▁   ▁      ▁          ▁  
  █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁█▁▁▁▁███▁▁▁█▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁█ ▁
  266 ms           Histogram: frequency by time          744 ms <

 Memory estimate: 3.55 GiB, allocs estimate: 47195479.

After all these optimizations, the logging output of 10k simulation runs is just a handful of lines:

┌ LogLevel(-1): Total
│   progress = 0.0
└ @ SciMLBase ~/code/SciMLBase.jl/src/ensemble/basic_ensemble_solve.jl:134
┌ LogLevel(-1): Total
│   message = "dt=0.07805729\nt=10.0\nmax u=25.905659"
│   progress = 0.211799999999993
└ @ OrdinaryDiffEq ~/code/OrdinaryDiffEq.jl/src/integrators/integrator_utils.jl:153
┌ LogLevel(-1): Total
│   progress = 0.756299999999933
└ @ OrdinaryDiffEq ~/code/OrdinaryDiffEq.jl/src/solve.jl:103
┌ LogLevel(-1): Total
│   message = "dt=0.046221733\nt=10.0\nmax u=23.776852"
│   progress = "done"
└ @ OrdinaryDiffEq ~/code/OrdinaryDiffEq.jl/src/integrators/integrator_utils.jl:153
pepijndevos commented 9 months ago

For reference, on master:

With progress=false:
BenchmarkTools.Trial: 12 samples with 1 evaluation.
 Range (min … max):  189.615 ms … 701.238 ms  ┊ GC (min … max):  0.00% … 72.88%
 Time  (median):     493.090 ms               ┊ GC (median):    61.75%
 Time  (mean ± σ):   457.536 ms ± 160.567 ms  ┊ GC (mean ± σ):  58.25% ± 25.18%

  ██           █        █           ███  ██  █          █     █  
  ██▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁███▁▁██▁▁█▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁█ ▁
  190 ms           Histogram: frequency by time          701 ms <

 Memory estimate: 3.40 GiB, allocs estimate: 45635653.

With progress=true:
BenchmarkTools.Trial: 1 sample with 1 evaluation.
 Single result which took 6.611 s (3.17% GC) to evaluate,
 with a memory estimate of 3.77 GiB, over 51008757 allocations.

thousands of log messages omitted for brevity ;)
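The PR and master numbers can be compared with a few lines of arithmetic. The values are copied from the output above; mapping the first and second trial of each pair to progress=false and progress=true assumes the output order matches the order of the @benchmark calls. GC time varies from 0% to over 70% between samples, so these medians are noisy and the ratios are only rough.

```julia
# Run times in ms, copied from the BenchmarkTools output above
# (10_000 trajectories). Trial order is assumed to match the @benchmark order.
pr_off = 500.877      # this PR, progress=false (median)
pr_on = 558.434       # this PR, progress=true (median)
master_off = 493.090  # master, progress=false (median)
master_on = 6611.0    # master, progress=true (a single 6.611 s sample, not a median)

overhead = (pr_on - pr_off) / pr_off               # cost of enabling progress on this PR
baseline = (pr_off - master_off) / master_off      # progress=false: this PR vs. master
speedup = master_on / pr_on                        # progress=true: master time / PR time

println("progress overhead on this PR: ", round(100 * overhead; digits = 1), "%")
println("progress=false vs. master: ", round(100 * baseline; digits = 1), "%")
println("progress=true speedup over master: ", round(speedup; digits = 1), "x")
```

With these inputs it reports about 11.5% median overhead for progress=true on this PR, a 1.6% difference for progress=false versus master, and an 11.8x speedup for progress=true versus master.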

ChrisRackauckas commented 9 months ago

https://github.com/SciML/SciMLBase.jl/actions/runs/6795944602/job/18475031150?pr=514 looks like the last issue.

pepijndevos commented 9 months ago

Ahhh I thought that must have been some CI glitch from the stats PR, but turns out that in one of the many rebase conflicts of this PR over its long lifetime the stats got lost.

pepijndevos commented 9 months ago

Seems fine now? Could it be? Could it really be done?

ChrisRackauckas commented 9 months ago

Before:

BenchmarkTools.Trial: 198 samples with 1 evaluation.
 Range (min … max):  22.291 ms … 43.583 ms  ┊ GC (min … max): 0.00% … 43.89%
 Time  (median):     23.324 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   25.246 ms ±  5.548 ms  ┊ GC (mean ± σ):  6.91% ± 12.90%

▆█▇▇▅▃▂
███████▆▁▆▅▅▅▅▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▁▁▅▁▁▁▁▁▁▁▁▅▁▁▅▆▆▆▆▁▇█ ▅
22.3 ms      Histogram: log(frequency) by time      42.7 ms <

 Memory estimate: 24.42 MiB, allocs estimate: 320057.

BenchmarkTools.Trial: 89 samples with 1 evaluation.
 Range (min … max):  41.775 ms … 83.893 ms  ┊ GC (min … max): 0.00% … 34.04%
 Time  (median):     44.861 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   56.267 ms ± 15.007 ms  ┊ GC (mean ± σ):  22.08% ± 19.77%

█▄
████▇▄▃▁▁▁▁▁▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▄▅▃▄▅▆▇▆▄▄▅▁▁▁▁▃▃ ▁
41.8 ms         Histogram: frequency by time         80.1 ms <

 Memory estimate: 152.59 MiB, allocs estimate: 1190057.

After:

BenchmarkTools.Trial: 221 samples with 1 evaluation.
 Range (min … max):  19.745 ms … 51.209 ms  ┊ GC (min … max): 0.00% … 56.97%
 Time  (median):     20.502 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   22.704 ms ±  7.274 ms  ┊ GC (mean ± σ):  8.31% ± 14.29%

▇█▅▂▁
█████▄▅▁▁▁▄▁▄▁▁▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▄▇▄▅▁▆ ▅
19.7 ms      Histogram: log(frequency) by time      50.6 ms <

 Memory estimate: 24.42 MiB, allocs estimate: 320076.

BenchmarkTools.Trial: 31 samples with 1 evaluation.
 Range (min … max):  141.556 ms … 179.966 ms  ┊ GC (min … max):  4.08% … 11.54%
 Time  (median):     166.417 ms               ┊ GC (median):    11.92%
 Time  (mean ± σ):   166.345 ms ±   8.579 ms  ┊ GC (mean ± σ):  12.01% ±  2.93%

                       ▃         █  ▃  ▃                  ▃  
▇▁▁▁▁▁▁▁▁▇▁▁▁▁▁▁▁▁▁▁▁▁▁▁▇█▁▁▁▁▇▇▁▇▇█▁▇█▇▇█▁▇▇▇▇▇▁▁▁▇▇▁▇▁▇▇▁▁█ ▁
142 ms           Histogram: frequency by time          180 ms <

 Memory estimate: 220.28 MiB, allocs estimate: 1785509.