CliMA / ClimaDiagnostics.jl

A framework to define and output observables and statistics from CliMA simulations

Long compile times #83

Closed charleskawczynski closed 1 month ago

charleskawczynski commented 1 month ago

It seems that ClimaDiagnostics may be largely responsible for the long compile times. This may be related to #82, to complex type signatures + https://github.com/JuliaLang/julia/issues/55807, or to a combination of all three. This job in ClimaAtmos took ~98 minutes to compile the driver up to the end of the first call to step!(integrator) with diagnostics, and only ~24 minutes without. So adding diagnostics increases compile time by roughly 4x. Here is a reproducer in ClimaAtmos:

# julia --project=examples

# Make the driver perform one step:
open("examples/hybrid/driver.jl", "w+") do io
    println(io, "redirect_stderr(IOContext(stderr, :stacktrace_types_limited => Ref(false)))")
    println(io, "import ClimaComms")
    println(io, "ClimaComms.@import_required_backends")
    println(io, "import ClimaAtmos as CA")
    println(io, "import Random")
    println(io, "Random.seed!(1234)")
    println(io, "if !(@isdefined config)")
    println(io, "    (; config_file, job_id) = CA.commandline_kwargs()")
    println(io, "    config = CA.AtmosConfig(config_file; job_id)")
    println(io, "end")
    println(io, "simulation = CA.get_simulation(config)")
    println(io, "(; integrator) = simulation")
    println(io, "import SciMLBase")
    println(io, "sol_res = SciMLBase.step!(integrator)")
end

# Common driver configuration
config_vec = [
    "initial_condition: TRMM_LBA",
    "rad: TRMM_LBA",
    "surface_setup: TRMM_LBA",
    "turbconv: diagnostic_edmfx",
    "implicit_diffusion: true",
    "approximate_linear_solve_iters: 2",
    "prognostic_tke: true",
    "edmfx_upwinding: first_order",
    "edmfx_entr_model: \"Generalized\"",
    "edmfx_detr_model: \"Generalized\"",
    "edmfx_nh_pressure: true",
    "edmfx_sgs_mass_flux: true",
    "edmfx_sgs_diffusive_flux: true",
    "moist: equil",
    "cloud_model: \"quadrature_sgs\"",
    "call_cloud_diagnostics_per_stage: true",
    "precip_model: \"1M\"",
    "override_τ_precip: false",
    "config: box",
    "x_max: 1e8",
    "y_max: 1e8",
    "x_elem: 2",
    "y_elem: 2",
    "z_elem: 30",
    "z_max: 16400",
    "dz_bottom: 50",
    "dt: 300secs",
    "t_end: 6hours",
    "dt_save_state_to_disk: 10mins",
    "FLOAT_TYPE: \"Float64\"",
    "toml: [toml/diagnostic_edmfx.toml]",
    "netcdf_interpolation_num_points: [8, 8, 30]",
    "ode_algo: ARS343"
]

#---------------------------- No diagnostics
open("config/model_configs/diagnostic_edmfx_trmm_stretched_box.yml", "w+") do io
    for entry in config_vec
        println(io, entry)
    end
    println(io, "enable_diagnostics: false")
end
empty!(ARGS);
push!(ARGS, "--config_file", "config/model_configs/diagnostic_edmfx_trmm_stretched_box.yml");
push!(ARGS, "--job_id", "diagnostic_edmfx_trmm_stretched_box");
@time include("examples/hybrid/driver.jl")
# 1430.385636 seconds (1.33 G allocations: 78.446 GiB, 87.09% gc time, 99.79% compilation time: <1% of which was recompilation)

#---------------------------- with diagnostics
open("config/model_configs/diagnostic_edmfx_trmm_stretched_box.yml", "w+") do io
    for entry in config_vec
        println(io, entry)
    end
    println(io, "diagnostics:")
    println(io, "  - short_name: [ts, ta, thetaa, ha, pfull, rhoa, ua, va, wa, hur, hus, cl, clw, cli, hussfc, evspsbl, pr]")
    println(io, "    period: 10mins")
    println(io, "  - short_name: [arup, waup, taup, thetaaup, haup, husup, hurup, clwup, cliup, waen, tke, lmix]")
    println(io, "    period: 10mins")
    println(io, "  - short_name: [husra, hussn]")
    println(io, "    period: 10mins")
end
empty!(ARGS);
push!(ARGS, "--config_file", "config/model_configs/diagnostic_edmfx_trmm_stretched_box.yml");
push!(ARGS, "--job_id", "diagnostic_edmfx_trmm_stretched_box");
@time include("examples/hybrid/driver.jl")
# 5878.586855 seconds (1.44 G allocations: 85.778 GiB, 92.98% gc time, 99.94% compilation time: <1% of which was recompilation)
charleskawczynski commented 1 month ago

cc @Sbozzolo, @szy21

Sbozzolo commented 1 month ago

Wow! That's dramatic! Thanks for finding this. Could you run a test where you use the HDF5 writer instead? The NetCDF writer is much more complex (and inherits type instability from NCDatasets). That would already give us a sense of where to look.
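For context, the instability could be checked in isolation with something like the sketch below (the file and variable names are made up; this is just a sketch, not writer code):

import NCDatasets

# Create a tiny file with one variable.
NCDatasets.NCDataset("check.nc", "c") do ds
    NCDatasets.defDim(ds, "x", 4)
    v = NCDatasets.defVar(ds, "temp", Float64, ("x",))
    v[:] = 1.0:4.0
end

# If the type reported here is abstract, that instability propagates into any
# struct (such as the NetCDF writer) that stores these dataset/variable objects.
ds = NCDatasets.NCDataset("check.nc", "r")
@code_warntype ds["temp"]
close(ds)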

Do you also know if this is inference or LLVM time?

Sbozzolo commented 1 month ago

I know that the loops in orchestrate_diagnostics cause a lot of compiler allocations. I tried unrolling them and it led to a massive explosion in compile time.
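As a standalone illustration of what I mean by unrolling (toy closures, not the actual orchestrate_diagnostics code):

# A heterogeneous tuple: every element has a different type, like our diagnostics.
callbacks = (x -> x + 1, x -> 2x, x -> x^2)

# Plain loop: compiled once, each call is dispatched at runtime.
function run_loop(cbs, x)
    for cb in cbs
        cb(x)
    end
end

# "Unrolled" recursion: the compiler specializes one call per element, so compile
# time grows with the number (and type complexity) of elements.
run_unrolled(cbs::Tuple{}, x) = nothing
function run_unrolled(cbs::Tuple, x)
    first(cbs)(x)
    run_unrolled(Base.tail(cbs), x)
end

run_loop(callbacks, 1.0)
run_unrolled(callbacks, 1.0)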

If changing the writer to HDF5 improves things, we can restructure things to store symbols instead of references to the entire object. That would simplify the types quite a lot, and we can re-evaluate unrolling at that point.
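Roughly this kind of change (the field names here are made up, not the current ScheduledDiagnostic layout):

# Now (schematically): each diagnostic carries the writer object itself, so the
# diagnostic's type depends on the full writer type.
struct DiagnosticWithWriter{W}
    short_name::String
    writer::W
end

# Proposed: store a Symbol key and keep the writers in a single lookup table.
# The diagnostic type no longer mentions the writer, which keeps the collection
# of diagnostics much simpler for the compiler.
struct DiagnosticWithKey
    short_name::String
    writer_key::Symbol
end

writers = Dict{Symbol, Any}(:netcdf => "NetCDF writer here", :hdf5 => "HDF5 writer here")
d = DiagnosticWithKey("ta", :netcdf)
writer = writers[d.writer_key]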

charleskawczynski commented 1 month ago

Yeah, it's pretty dramatic. I have a feeling it's a combination of things, and yes, we can try with HDF5.

I haven't narrowed it down between inference and LLVM time yet.
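One way to split it up would be SnoopCompile. The macro names below are from recent SnoopCompileCore releases (older versions call these @snoopi_deep / @snoopl), and each measurement should be taken in a fresh session so that step! is actually compiled under the snoop:

import SnoopCompileCore

# Inference time: wrap the first call to step!.
tinf = SnoopCompileCore.@snoop_inference SciMLBase.step!(integrator)

import SnoopCompile  # load the analysis code only after collecting
SnoopCompile.flatten(tinf)  # per-method inference timings, sorted

# LLVM/codegen time (run in a separate fresh session): writes per-method
# timings to the two files below.
SnoopCompileCore.@snoop_llvm "func_names.csv" "llvm_timings.yaml" begin
    SciMLBase.step!(integrator)
end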

Sbozzolo commented 1 month ago

Here are some tests I performed. I took a baroclinic wave case and added a bunch of diagnostics like this:

  - short_name: [pfull, ua, wa, va, rv, ta, ke]
    period: 1days
  - short_name: [pfull, ua, wa, va, rv, ta, ke]
    period: 2days
  - short_name: [pfull, ua, wa, va, rv, ta, ke]
    period: 3days
  - short_name: [pfull, ua, wa, va, rv, ta, ke]
....

in the config.

I tested the case with the NetCDF and HDF5 writers; listed below is the approximate time to compile the first step!:

NetCDF:
- 0 diagnostics: 32 seconds
- 10 diagnostics: 32 seconds
- 20 diagnostics: 36 seconds
- 40 diagnostics: 42 seconds
- 60 diagnostics: 48 seconds

HDF5:
- 50 diagnostics: 33 seconds
- 150 diagnostics: 38 seconds

Next, I tried turning the diagnostics from a tuple into a vector (a sketch of that change is after the timings below); these are the new times:

NetCDF:
- 50 diagnostics: 30 seconds
- 150 diagnostics: 33 seconds

HDF5:
- 50 diagnostics: 27 seconds
- 150 diagnostics: 38 seconds
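For reference, the tuple-to-vector change in toy form (not the actual ClimaDiagnostics types):

# Toy sketch: each diagnostic has its own type (e.g. via its compute function).
struct ToyDiagnostic{F}
    short_name::String
    compute::F
end

diags_tuple = (ToyDiagnostic("ta", x -> x), ToyDiagnostic("ua", x -> 2x))

# As a Tuple, the loop is specialized on every element type, so compile time
# grows with the number of distinct diagnostics. As a Vector with an abstract
# element type, the loop is compiled once and calls are dispatched at runtime.
diags_vector = ToyDiagnostic[diags_tuple...]

for d in diags_vector
    d.compute(1.0)
end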