CliMA / MultiBroadcastFusion.jl

A Julia package for fusing multiple broadcast expressions together
MIT License
6 stars 0 forks source link

Provide mechanism to automatically partition kernels #43

Open charleskawczynski opened 5 days ago

charleskawczynski commented 5 days ago

We have a large set of kernels in ClimaAtmos, and we have to partition them by hand, as shown below. Without this, we quite easily run into `ERROR: LoadError: Kernel invocation uses too much parameter memory; 4.586 KiB exceeds the 4.000 KiB limit imposed by sm_60 / PTX v8.2`. This limit depends on the device, so we should probably offer a mechanism that splits the fused broadcasts into segments whose parameter memory stays within the limit.

    # The fused broadcasts below are split by hand into several `@fused_direct`
    # blocks. Each block becomes its own kernel, and the split keeps each kernel's
    # parameter memory under the device limit (4 KiB on sm_60 / PTX v8.2).
    # Block boundaries are therefore dictated by that limit, not by the physics.
    # The scratch field `Sᵖ` is sometimes computed at the end of one block and
    # consumed at the start of the next.

    # Float type used to build the literal constants (FT(0), FT(-1)) below.
    FT = eltype(thp)
    # Local helper: cᵥₗ / Lf * (Tₐ - T_freeze). It is used below to bound how much
    # snow melts when accreted liquid is above freezing.
    α(thp, ts) = cᵥₗ(thp) / Lf(thp, ts) * (Tₐ(thp, ts) - mp.ps.T_freeze)
    # Block 1: zero the tendency fields, then apply rain autoconversion.
    @fused_direct begin
        @. Sqₜᵖ = FT(0)
        @. Sqᵣᵖ = FT(0)
        @. Sqₛᵖ = FT(0)
        @. Seₜᵖ = FT(0)

        # NOTE(review): no matching `#! format: on` appears in this excerpt.
        #! format: off
        # rain autoconversion: q_liq -> q_rain
        # Choose the 1-moment scheme when mp.Ndp <= 0; otherwise use the 2-moment scheme.
        @. Sᵖ = ifelse(
            mp.Ndp <= 0,
            CM1.conv_q_liq_to_q_rai(mp.pr.acnv1M, qₗ(thp, ts), true),
            CM2.conv_q_liq_to_q_rai(mp.var, qₗ(thp, ts), ρ, mp.Ndp),
        )
        # Cap the rate at limit(q_liq, dt, 5).
        @. Sᵖ = min(limit(qₗ(thp, ts), dt, 5), Sᵖ)
        @. Sqₜᵖ -= Sᵖ
        @. Sqᵣᵖ += Sᵖ
        @. Seₜᵖ -= Sᵖ * (Iₗ(thp, ts) + Φ)
    end

    # Block 2: snow autoconversion. The block then computes (but does not apply)
    # the q_liq + q_rain accretion rate; Block 3 applies it.
    @fused_direct begin
        # snow autoconversion assuming no supersaturation: q_ice -> q_snow
        @. Sᵖ = min(
            limit(qᵢ(thp, ts), dt, 5),
            CM1.conv_q_ice_to_q_sno_no_supersat(mp.ps.acnv1M, qᵢ(thp, ts), true),
        )
        @. Sqₜᵖ -= Sᵖ
        @. Sqₛᵖ += Sᵖ
        @. Seₜᵖ -= Sᵖ * (Iᵢ(thp, ts) + Φ)

        # accretion: q_liq + q_rain -> q_rain
        @. Sᵖ = min(
            limit(qₗ(thp, ts), dt, 5),
            CM1.accretion(mp.cl, mp.pr, mp.tv.rain, mp.ce, qₗ(thp, ts), qᵣ, ρ),
        )
    end

    # Block 3: apply the q_liq + q_rain accretion rate left in `Sᵖ` by Block 2,
    # then compute and apply q_ice + q_snow accretion.
    @fused_direct begin
        @. Sqₜᵖ -= Sᵖ
        @. Sqᵣᵖ += Sᵖ
        @. Seₜᵖ -= Sᵖ * (Iₗ(thp, ts) + Φ)

        # accretion: q_ice + q_snow -> q_snow
        @. Sᵖ = min(
            limit(qᵢ(thp, ts), dt, 5),
            CM1.accretion(mp.ci, mp.ps, mp.tv.snow, mp.ce, qᵢ(thp, ts), qₛ, ρ),
        )
        @. Sqₜᵖ -= Sᵖ
        @. Sqₛᵖ += Sᵖ
        @. Seₜᵖ -= Sᵖ * (Iᵢ(thp, ts) + Φ)
    end

    # Block 4: cloud liquid + snow accretion. The outcome depends on the
    # temperature, with freezing below T_freeze and melting into rain otherwise.
    @fused_direct begin
        # accretion: q_liq + q_sno -> q_sno or q_rai
        # sink of cloud water via accretion cloud water + snow
        @. Sᵖ = min(
            limit(qₗ(thp, ts), dt, 5),
            CM1.accretion(mp.cl, mp.ps, mp.tv.snow, mp.ce, qₗ(thp, ts), qₛ, ρ),
        )
        # if T < T_freeze cloud droplets freeze to become snow
        # else the snow melts and both cloud water and snow become rain
        # (in the melting branch Sᵖ_snow is <= 0, i.e. a snow sink capped by limit(qₛ, dt, 5))
        @. Sᵖ_snow = ifelse(
            Tₐ(thp, ts) < mp.ps.T_freeze,
            Sᵖ,
            FT(-1) * min(Sᵖ * α(thp, ts), limit(qₛ, dt, 5)),
        )

        @. Sqₛᵖ += Sᵖ_snow
        @. Sqₜᵖ -= Sᵖ
        @. Sqᵣᵖ += ifelse(Tₐ(thp, ts) < mp.ps.T_freeze, FT(0), Sᵖ - Sᵖ_snow)
        @. Seₜᵖ -= ifelse(
            Tₐ(thp, ts) < mp.ps.T_freeze,
            Sᵖ * (Iᵢ(thp, ts) + Φ),
            Sᵖ * (Iₗ(thp, ts) + Φ) - Sᵖ_snow * (Iₗ(thp, ts) - Iᵢ(thp, ts)),
        )
    end

    # Block 5: cloud ice + rain accretion into snow. The block then computes (but
    # does not apply) the corresponding rain sink; Block 6 applies it.
    @fused_direct begin
        # accretion: q_ice + q_rai -> q_sno
        @. Sᵖ = min(
            limit(qᵢ(thp, ts), dt, 5),
            CM1.accretion(mp.ci, mp.pr, mp.tv.rain, mp.ce, qᵢ(thp, ts), qᵣ, ρ),
        )
        @. Sqₜᵖ -= Sᵖ
        @. Sqₛᵖ += Sᵖ
        @. Seₜᵖ -= Sᵖ * (Iᵢ(thp, ts) + Φ)
        # sink of rain via accretion cloud ice - rain
        @. Sᵖ = min(
            limit(qᵣ, dt, 5),
            CM1.accretion_rain_sink(mp.pr, mp.ci, mp.tv.rain, mp.ce, qᵢ(thp, ts), qᵣ, ρ),
        )
    end

    # Block 6: apply the rain-sink rate left in `Sᵖ` by Block 5, then do
    # rain + snow accretion.
    @fused_direct begin
        @. Sqᵣᵖ -= Sᵖ
        @. Sqₛᵖ += Sᵖ
        @. Seₜᵖ += Sᵖ * Lf(thp, ts)

        # accretion: q_rai + q_sno -> q_rai or q_sno
        # Below T_freeze Sᵖ >= 0 (rain -> snow); otherwise Sᵖ <= 0 (snow -> rain).
        @. Sᵖ = ifelse(
            Tₐ(thp, ts) < mp.ps.T_freeze,
            min(
                limit(qᵣ, dt, 5),
                CM1.accretion_snow_rain(mp.ps, mp.pr, mp.tv.rain, mp.tv.snow, mp.ce, qₛ, qᵣ, ρ),
            ),
            -min(
                limit(qₛ, dt, 5),
                CM1.accretion_snow_rain(mp.pr, mp.ps, mp.tv.snow, mp.tv.rain, mp.ce, qᵣ, qₛ, ρ),
            ),
        )
        @. Sqₛᵖ += Sᵖ
        @. Sqᵣᵖ -= Sᵖ
        @. Seₜᵖ += Sᵖ * Lf(thp, ts)
    end
charleskawczynski commented 4 days ago

Having to figure out by hand how to partition these kernels makes working with MultiBroadcastFusion.jl somewhat brittle. This may be related to #24.