CliMA / ClimaCore.jl

CliMA model dycore
https://clima.github.io/ClimaCore.jl/dev
Apache License 2.0
79 stars 7 forks source link

Use shared/local memory in FD stencil kernels #1746

Open charleskawczynski opened 1 month ago

charleskawczynski commented 1 month ago

Doing this in general may first require adding Nv to the type parameter space (https://github.com/CliMA/ClimaCore.jl/issues/11).

charleskawczynski commented 1 month ago

Local notes:

# Apply `column_shmem_broadcasted` to every element of a tuple of broadcast
# arguments, preserving order. The work is split by tuple length (general,
# single-element, empty) and recurses on `Base.tail`, the same pattern as
# `Base.Broadcast.preprocess_args`, so the optimizer can reduce it statically.
@inline function column_shmem_broadcasted_args(args::Tuple, Nv, i, j, h, v)
    head = column_shmem_broadcasted(first(args), Nv, i, j, h, v)
    rest = column_shmem_broadcasted_args(Base.tail(args), Nv, i, j, h, v)
    return (head, rest...)
end
@inline function column_shmem_broadcasted_args(args::Tuple{Any}, Nv, i, j, h, v)
    return (column_shmem_broadcasted(only(args), Nv, i, j, h, v),)
end
@inline column_shmem_broadcasted_args(::Tuple{}, Nv, i, j, h, v) = ()

# Rebuild a column stencil-broadcast node with each argument passed through
# `column_shmem_broadcasted`; the operator and axes are carried over as-is.
@inline function column_shmem_broadcasted(
    bc::StencilBroadcasted{CUDAColumnStencilStyle},
    Nv, i, j, h, v,
)
    staged_args = column_shmem_broadcasted_args(bc.args, Nv, i, j, h, v)
    return StencilBroadcasted{CUDAColumnStencilStyle}(bc.op, staged_args, bc.axes)
end

# Rebuild a pointwise `Broadcasted` node (column stencil style) with each
# argument passed through `column_shmem_broadcasted`; `f` and axes are kept.
@inline function column_shmem_broadcasted(
    bc::Base.Broadcast.Broadcasted{CUDAColumnStencilStyle},
    Nv, i, j, h, v,
)
    staged_args = column_shmem_broadcasted_args(bc.args, Nv, i, j, h, v)
    return Base.Broadcast.Broadcasted{CUDAColumnStencilStyle}(
        bc.f, staged_args, bc.axes,
    )
end
import StaticArrays
"""
    column_shmem_broadcasted(data::DataLayouts.DataColumn, ::Val{Nv}, i, j, h, v)

Allocate an `Nv × Nf` CUDA static shared-memory buffer, copy this thread's
vertical level `v` of `data` into it, and return `data` rebuilt on top of
that buffer. Here `Nf = DataLayouts.number_of_fields(data)`.

The buffer's element type is `eltype(parent(data))`, the storage type, rather
than `eltype(data)`. The latter may be a composite struct that spans several
fields, which would not match an `(Nv, Nf)` parent array.

NOTE(review): each thread writes only its own level. A block-wide barrier
(e.g. `CUDA.sync_threads()`, reached by every thread of the block) is needed
before any thread reads neighbouring levels from the returned data. None is
issued here.

`i`, `j`, `h` are unused; they keep the call signature uniform across the
`column_shmem_broadcasted` methods.
"""
@inline function column_shmem_broadcasted(
    data::DataLayouts.DataColumn,
    ::Val{Nv},
    i, j, h, v,
) where {Nv}
    Nf = DataLayouts.number_of_fields(data)
    # Use the parent array's element type so the buffer matches the storage
    # layout that `rebuild` expects.
    FT = eltype(parent(data))
    shmem = CUDA.CuStaticSharedArray(FT, (Nv, Nf))
    rdata = DataLayouts.rebuild(data, shmem)
    src = parent(data)
    # NOTE(review): assumes parent(data) is indexed (level, field), as for a
    # VF layout; confirm for every DataColumn layout.
    # The buffer may have more levels than `data` (the caller sizes it one
    # larger so it fits both faces and centers), so skip levels `data` lacks.
    if v <= size(src, 1)
        @inbounds for f in 1:Nf
            shmem[v, f] = src[v, f]
        end
    end
    return rdata
end
# Take the (i, j, h) column of field `f` and pass its values through
# `column_shmem_broadcasted` (presumably the DataColumn method). Then re-wrap
# the result as a `Field` with the column's axes.
@inline function column_shmem_broadcasted(
    f::Fields.Field,
    ::Val{Nv},
    i, j, h, v,
) where {Nv}
    column = Fields.column(f, i, j, h)
    values = Fields.field_values(column)
    staged = column_shmem_broadcasted(values, Val(Nv), i, j, h, v)
    return Fields.Field(staged, axes(column))
end
# Fallback: arguments with no column data to stage (scalars, etc.) are
# returned unchanged.
@inline function column_shmem_broadcasted(x, Nv, i, j, h, v)
    return x
end

function copyto_stencil_kernel!(out, bc, space, bds, Nq, Nh, ::Val{Nv}) where {Nv}
    gid = threadIdx().x + (blockIdx().x - 1) * blockDim().x
    if gid ≤ Nv * Nq * Nq * Nh
        (li, lw, rw, ri) = bds
        (v, i, j, h) = Topologies._get_idx((Nv, Nq, Nq, Nh), gid)
        bc_shmem = column_shmem_broadcasted(bc, Val(Nv+1), i,j,h,v) # Extend shmem by one to work for cell faces and centers
charleskawczynski commented 1 month ago

This may depend on https://github.com/CliMA/ClimaCore.jl/issues/1754 being fixed.