CliMA / ClimaCore.jl

CliMA model dycore
https://clima.github.io/ClimaCore.jl/dev
Apache License 2.0
87 stars 8 forks source link

Inference failure in broadcast expression #1981

Closed charleskawczynski closed 1 month ago

charleskawczynski commented 1 month ago

Found in https://github.com/CliMA/ClimaAtmos.jl/pull/3290:

ci/src/parameterized_tendencies/microphysics/microphysics_wrappers.jl:176
--
  | │││┌ materialize!(dest::ClimaCore.Fields.Field{…}, bc::Base.Broadcast.Broadcasted{…}) @ Base.Broadcast ./broadcast.jl:911
  | ││││┌ materialize!(::ClimaCore.Fields.FieldStyle{…}, dest::ClimaCore.Fields.Field{…}, bc::Base.Broadcast.Broadcasted{…}) @ Base.Broadcast ./broadcast.jl:914
  | │││││┌ copyto!(dest::ClimaCore.Fields.Field{…}, bc::Base.Broadcast.Broadcasted{…}) @ ClimaCore.Fields /central/scratch/esm/slurm-buildkite/climaatmos-ci/depot/default/packages/ClimaCore/HTRMg/src/Fields/broadcast.jl:149
  | ││││││┌ copyto!(dest::ClimaCore.DataLayouts.VIJFH{…}, bc::Base.Broadcast.Broadcasted{…}) @ ClimaCore.DataLayouts /central/scratch/esm/slurm-buildkite/climaatmos-ci/depot/default/packages/ClimaCore/HTRMg/src/DataLayouts/copyto.jl:5
  | │││││││┌ copyto!(dest::ClimaCore.DataLayouts.VIJFH{…}, bc::Base.Broadcast.Broadcasted{…}, ::ClimaCore.DataLayouts.ToCPU) @ ClimaCore.DataLayouts /central/scratch/esm/slurm-buildkite/climaatmos-ci/depot/default/packages/ClimaCore/HTRMg/src/DataLayouts/copyto.jl:148
  | ││││││││┌ copyto!(dest::ClimaCore.DataLayouts.VF{…}, bc::Base.Broadcast.Broadcasted{…}) @ ClimaCore.DataLayouts /central/scratch/esm/slurm-buildkite/climaatmos-ci/depot/default/packages/ClimaCore/HTRMg/src/DataLayouts/copyto.jl:5
  | │││││││││┌ copyto!(dest::ClimaCore.DataLayouts.VF{…}, bc::Base.Broadcast.Broadcasted{…}, ::ClimaCore.DataLayouts.ToCPU) @ ClimaCore.DataLayouts /central/scratch/esm/slurm-buildkite/climaatmos-ci/depot/default/packages/ClimaCore/HTRMg/src/DataLayouts/copyto.jl:120
  | ││││││││││┌ getindex(bc::Base.Broadcast.Broadcasted{…}, I::CartesianIndex{…}) @ Base.Broadcast ./broadcast.jl:635
  | │││││││││││┌ checkbounds(bc::Base.Broadcast.Broadcasted{…}, I::CartesianIndex{…}) @ Base.Broadcast ./broadcast.jl:647
  | ││││││││││││┌ axes(bc::Base.Broadcast.Broadcasted{ClimaCore.DataLayouts.VFStyle{…}, Nothing, typeof(ifelse), Tuple{…}}) @ Base.Broadcast ./broadcast.jl:234
  | │││││││││││││┌ _axes(bc::Base.Broadcast.Broadcasted{ClimaCore.DataLayouts.VFStyle{…}, Nothing, typeof(ifelse), Tuple{…}}, ::Nothing) @ Base.Broadcast ./broadcast.jl:236
  | ││││││││││││││┌ combine_axes(::Base.Broadcast.Broadcasted{…}, ::Base.Broadcast.Broadcasted{…}, ::Base.Broadcast.Broadcasted{…}) @ Base.Broadcast ./broadcast.jl:523
  | │││││││││││││││┌ combine_axes(A::Base.Broadcast.Broadcasted{…}, B::Base.Broadcast.Broadcasted{…}) @ Base.Broadcast ./broadcast.jl:524
  | ││││││││││││││││┌ axes(bc::Base.Broadcast.Broadcasted{…}) @ Base.Broadcast ./broadcast.jl:234
  | │││││││││││││││││┌ _axes(bc::Base.Broadcast.Broadcasted{…}, ::Nothing) @ Base.Broadcast ./broadcast.jl:236
  | ││││││││││││││││││ failed to optimize due to recursion: Base.Broadcast._axes(::Base.Broadcast.Broadcasted{…}, ::Nothing)
  | │││││││││││││││││└────────────────────
  | ││││││││││││││││┌ axes(bc::Base.Broadcast.Broadcasted{…}) @ Base.Broadcast ./broadcast.jl:234
  | │││││││││││││││││ failed to optimize due to recursion: axes(::Base.Broadcast.Broadcasted{…})
  | ││││││││││││││││└────────────────────
  | │││││││││││││││┌ combine_axes(A::Base.Broadcast.Broadcasted{…}, B::Base.Broadcast.Broadcasted{…}) @ Base.Broadcast ./broadcast.jl:524

The trace points to https://github.com/CliMA/ClimaAtmos.jl/blob/a22b643fc6a22dccd5a8b8d17d12222b114333eb/src/parameterized_tendencies/microphysics/microphysics_wrappers.jl#L176-L180. The same issue also appears to occur at https://github.com/CliMA/ClimaAtmos.jl/blob/a22b643fc6a22dccd5a8b8d17d12222b114333eb/src/parameterized_tendencies/microphysics/microphysics_wrappers.jl#L254-L264.

I expect this to be a significant performance hit, because the inference failure occurs inside `getindex` on the broadcasted object (see the `copyto!` → `getindex` frames in the trace above), so fixing it is a high priority.

charleskawczynski commented 1 month ago

Here is a reproducer:

using Test
using StaticArrays, IntervalSets
import ClimaCore
import ClimaComms
import ClimaCore.Utilities: PlusHalf, half
import ClimaCore.DataLayouts: IJFH
import ClimaCore:
    Fields,
    slab,
    Domains,
    Topologies,
    Meshes,
    Operators,
    Spaces,
    Geometry,
    Quadratures

using FastBroadcast
using LinearAlgebra: norm
using Statistics: mean
using ForwardDiff

include(
    joinpath(pkgdir(ClimaCore), "test", "TestUtilities", "TestUtilities.jl"),
)
import .TestUtilities as TU
"""
    toy_sphere(::Type{FT}) where {FT}

Build a small extruded finite-difference space pair for tests and return it as
`(center_space, face_space)`.

The horizontal part is a 2D spectral-element space on an equiangular cubed-sphere
mesh (`SphereDomain(FT(1e7))`, 2 elements, `GLL{3}` quadrature). The vertical
part is a center finite-difference space on a 4-element interval mesh spanning
`ZPoint(0)` to `ZPoint(1e4)` with boundaries named `:bottom` and `:top`.
"""
function toy_sphere(::Type{FT}) where {FT}
    comms_ctx = ClimaComms.context()

    # The horizontal element count and the polynomial degree are both 2.
    n_helem = 2
    poly_degree = 2

    # Horizontal: cubed sphere -> 2D topology -> spectral element space.
    sphere = Domains.SphereDomain(FT(1e7))
    sphere_mesh = Meshes.EquiangularCubedSphere(sphere, n_helem)
    sphere_topology = Topologies.Topology2D(comms_ctx, sphere_mesh)
    quadrature = Quadratures.GLL{poly_degree + 1}()
    horizontal_space = Spaces.SpectralElementSpace2D(sphere_topology, quadrature)

    # Vertical: interval column -> interval topology -> center FD space.
    z_bottom = Geometry.ZPoint{FT}(zero(FT))
    z_top = Geometry.ZPoint{FT}(FT(1e4))
    column = Domains.IntervalDomain(
        z_bottom,
        z_top;
        boundary_names = (:bottom, :top),
    )
    column_mesh = Meshes.IntervalMesh(column; nelems = 4)
    column_topology = Topologies.IntervalTopology(comms_ctx, column_mesh)
    vertical_space = Spaces.CenterFiniteDifferenceSpace(column_topology)

    # Extrude the horizontal space along the column, then derive the face space.
    center_space =
        Spaces.ExtrudedFiniteDifferenceSpace(horizontal_space, vertical_space)
    face_space = Spaces.FaceExtrudedFiniteDifferenceSpace(center_space)
    return (center_space, face_space)
end

"""
    VarTimescaleAcnv{FT}

Parameter container holding two values, `τ` and `α`, of the same type `FT`.
In this reproducer it is only used to select the `conv_q_liq_to_q_rai` method
and to fix the floating-point type `FT`.
"""
struct VarTimescaleAcnv{FT}
    τ::FT
    α::FT
end

# Wrap the parameter set in a 1-tuple so that broadcasting passes the same
# object to every element instead of trying to iterate over it.
Base.broadcastable(x::VarTimescaleAcnv) = tuple(x)

"""
    conv_q_liq_to_q_rai(::VarTimescaleAcnv{FT}, q_liq::FT, ρ::FT, N_d::FT) where {FT}

Return `max(0, q_liq) / (N_d / 1e8)`, evaluated entirely in `FT`.

# Arguments
- `::VarTimescaleAcnv{FT}`: used for dispatch only; its fields are not read.
- `q_liq`: clipped below at zero before the division.
- `ρ`: accepted for signature compatibility; not used in the computation.
- `N_d`: scaled by `1e8` in the denominator.

# Returns
A value of type `FT` for floating-point `FT`. The constants are converted to
`FT`, so `Float32` inputs are no longer promoted to `Float64` by the `1e8`
literal. `Float64` results are unchanged. `FT` must be able to represent `1e8`.
"""
function conv_q_liq_to_q_rai(
    ::VarTimescaleAcnv{FT},
    q_liq::FT,
    ρ::FT,
    N_d::FT,
) where {FT}
    # The `1 * (...)^1` shape is kept from the original expression; only the
    # literals are made `FT`-typed so the arithmetic stays in `FT`.
    return max(zero(FT), q_liq) / (1 * (N_d / FT(1e8))^1)
end
"""
    ifelsekernel!(Sᵖ, ρ)

Overwrite `Sᵖ` in place with the fused broadcast
`ifelse(false, 1.0, conv_q_liq_to_q_rai(params, 2.0, ρ, 2.0))`, where `params`
is `VarTimescaleAcnv(1.0, 2.0)`. Returns `nothing`.

The `ifelse` wrapped around a nested broadcasted call is the expression shape
under test here, so it is kept as a single fused broadcast.
"""
function ifelsekernel!(Sᵖ, ρ)
    acnv_params = VarTimescaleAcnv(1.0, 2.0)
    # Explicitly dotted form of the fused in-place broadcast.
    Sᵖ .= ifelse.(false, 1.0, conv_q_liq_to_q_rai.(acnv_params, 2.0, ρ, 2.0))
    return nothing
end

using JET
# https://github.com/CliMA/ClimaCore.jl/issues/1981
@testset "ifelse kernel" begin
    FT = Float64
    # Only the cell-center space is needed; the face space is discarded.
    center_space, _ = toy_sphere(FT)
    ρ = Fields.Field(FT, center_space)
    S = Fields.Field(FT, center_space)
    # Run the kernel once directly before handing it to JET.
    ifelsekernel!(S, ρ)
    # JET optimization analysis of the broadcast kernel.
    @test_opt ifelsekernel!(S, ρ)
end