Status: closed by charleskawczynski 1 month ago.
Here is a reproducer:
using Test
using StaticArrays, IntervalSets
import ClimaCore
import ClimaComms
import ClimaCore.Utilities: PlusHalf, half
import ClimaCore.DataLayouts: IJFH
import ClimaCore:
Fields,
slab,
Domains,
Topologies,
Meshes,
Operators,
Spaces,
Geometry,
Quadratures
using FastBroadcast
using LinearAlgebra: norm
using Statistics: mean
using ForwardDiff
include(
joinpath(pkgdir(ClimaCore), "test", "TestUtilities", "TestUtilities.jl"),
)
import .TestUtilities as TU
"""
    toy_sphere(FT) -> (center_space, face_space)

Build a small extruded sphere geometry with floating-point type `FT` for tests.

Horizontally: an `EquiangularCubedSphere` mesh (with `2` passed as its element
count) on a `SphereDomain` of radius `1e7`, discretized with `Quadratures.GLL{3}`
(polynomial degree 2). Vertically: a 4-element `IntervalMesh` over
`z ∈ [0, 1e4]` with boundaries named `:bottom` and `:top`.

Returns the center-staggered extruded space and the matching face space.
"""
function toy_sphere(::Type{FT}) where {FT}
    comms_ctx = ClimaComms.context()

    # Horizontal discretization
    n_horz_elems = 2
    poly_degree = 2
    sphere = Domains.SphereDomain(FT(1e7))
    horz_mesh = Meshes.EquiangularCubedSphere(sphere, n_horz_elems)
    horz_topo = Topologies.Topology2D(comms_ctx, horz_mesh)
    gll = Quadratures.GLL{poly_degree + 1}()
    horz_space = Spaces.SpectralElementSpace2D(horz_topo, gll)

    # Vertical discretization
    column = Domains.IntervalDomain(
        Geometry.ZPoint{FT}(zero(FT)),
        Geometry.ZPoint{FT}(FT(1e4));
        boundary_names = (:bottom, :top),
    )
    vert_mesh = Meshes.IntervalMesh(column, nelems = 4)
    vert_topo = Topologies.IntervalTopology(comms_ctx, vert_mesh)
    vert_space = Spaces.CenterFiniteDifferenceSpace(vert_topo)

    # Extrude horizontal × vertical, then derive the face-staggered space
    center_space = Spaces.ExtrudedFiniteDifferenceSpace(horz_space, vert_space)
    face_space = Spaces.FaceExtrudedFiniteDifferenceSpace(center_space)
    return center_space, face_space
end
"""
    VarTimescaleAcnv{FT}(τ, α)

Minimal parameter container used by the reproducer's autoconversion function.
Both fields share the type parameter `FT`; the names presumably mirror a
timescale `τ` and an exponent `α` (TODO confirm against the upstream scheme).
"""
struct VarTimescaleAcnv{FT}
    τ::FT
    α::FT
end

# Make a `VarTimescaleAcnv` behave as a scalar under broadcasting by wrapping it
# in a one-element tuple, so `@.` expressions pass it through unchanged.
Base.broadcastable(scheme::VarTimescaleAcnv) = (scheme,)
"""
    conv_q_liq_to_q_rai(scheme::VarTimescaleAcnv{FT}, q_liq::FT, ρ::FT, N_d::FT) where {FT}

Simplified liquid-to-rain conversion used by the reproducer. Returns
`max(0, q_liq) / (1 * (N_d / 1e8)^1)` computed in type `FT`.

The scheme's fields and `ρ` are accepted (to match the call signature) but are
not read. The literal `1` factor and `^1` exponent are presumably stand-ins for
the scheme's timescale and exponent parameters (TODO confirm).
"""
function conv_q_liq_to_q_rai(
    ::VarTimescaleAcnv{FT},
    q_liq::FT,
    ρ::FT,
    N_d::FT,
) where {FT}
    # Use `FT(1e8)` instead of the Float64 literal `1e8`: the literal would
    # promote a Float32 `N_d` to Float64 and make the result type differ from
    # `FT`. For `FT == Float64` this is identical to the literal.
    return max(0, q_liq) / (1 * (N_d / FT(1e8))^1)
end
"""
    ifelsekernel!(Sᵖ, ρ)

Reproducer kernel: overwrite `Sᵖ` in place with the broadcast of
`ifelse(false, 1.0, conv_q_liq_to_q_rai(var, 2.0, ρ, 2.0))` and return `nothing`.

Because `ifelse` is an ordinary function, both branches are evaluated and the
literal `false` condition selects the `conv_q_liq_to_q_rai` value. The exact
form (constant condition, struct argument inside `@.`) is intentionally kept
as-is, since it is what exhibits the problem checked by `@test_opt`.
"""
function ifelsekernel!(Sᵖ, ρ)
    var = VarTimescaleAcnv(1.0, 2.0)
    # `@.` dots every call, so `var` enters the broadcast through its
    # `Base.broadcastable` method and is treated as a scalar.
    @. Sᵖ = ifelse(false,1.0, conv_q_liq_to_q_rai(var, 2.0, ρ, 2.0))
    return nothing
end
using JET
# Regression test for https://github.com/CliMA/ClimaCore.jl/issues/1981:
# JET's `@test_opt` checks that the `ifelse` broadcast kernel has no
# optimization failures (e.g. runtime dispatch).
@testset "ifelse kernel" begin
    center_space, _ = toy_sphere(Float64)  # face space is not needed here
    ρ = Fields.Field(Float64, center_space)
    S = Fields.Field(Float64, center_space)
    # Execute the kernel once before running the static JET analysis on the same call.
    ifelsekernel!(S, ρ)
    @test_opt ifelsekernel!(S, ρ)
end
Found in https://github.com/CliMA/ClimaAtmos.jl/pull/3290:
This points to [lines 176–180 of `microphysics_wrappers.jl`](https://github.com/CliMA/ClimaAtmos.jl/blob/a22b643fc6a22dccd5a8b8d17d12222b114333eb/src/parameterized_tendencies/microphysics/microphysics_wrappers.jl#L176-L180). The same issue also appears in [lines 254–264 of the same file](https://github.com/CliMA/ClimaAtmos.jl/blob/a22b643fc6a22dccd5a8b8d17d12222b114333eb/src/parameterized_tendencies/microphysics/microphysics_wrappers.jl#L254-L264).
I think this will be a big performance hit, since it occurs inside `getindex` on the broadcasted object, so it is important to fix.