Open charleskawczynski opened 1 month ago
Local notes:
# Recursively call column_shmem_broadcasted() on broadcast arguments in a way that is statically reducible by the optimizer
# see Base.Broadcast.preprocess_args
# General case: apply `column_shmem_broadcasted` to the first broadcast argument, then
# recurse on the remaining ones (`Base.tail(args)`) and splat them into the result tuple.
# The recursion follows the same pattern as `Base.Broadcast.preprocess_args`.
@inline function column_shmem_broadcasted_args(args::Tuple, Nv, i, j, h, v)
    head = column_shmem_broadcasted(args[1], Nv, i, j, h, v)
    rest = column_shmem_broadcasted_args(Base.tail(args), Nv, i, j, h, v)
    return (head, rest...)
end
# Single-argument case: transform the only element and wrap it in a 1-tuple, without
# recursing on an empty tail.
@inline function column_shmem_broadcasted_args(args::Tuple{Any}, Nv, i, j, h, v)
    lone = column_shmem_broadcasted(args[1], Nv, i, j, h, v)
    return (lone,)
end
# Empty case: there are no broadcast arguments, so the result is the empty tuple.
@inline function column_shmem_broadcasted_args(::Tuple{}, Nv, i, j, h, v)
    return ()
end
# Rebuild a `StencilBroadcasted` node with the same operator (`bc.op`) and axes
# (`bc.axes`), but with every argument passed through `column_shmem_broadcasted_args`.
@inline function column_shmem_broadcasted(
    bc::StencilBroadcasted{CUDAColumnStencilStyle},
    Nv, i, j, h, v,
)
    shmem_args = column_shmem_broadcasted_args(bc.args, Nv, i, j, h, v)
    return StencilBroadcasted{CUDAColumnStencilStyle}(bc.op, shmem_args, bc.axes)
end
# Rebuild a `Base.Broadcast.Broadcasted` node with the same function (`bc.f`) and axes
# (`bc.axes`), but with every argument passed through `column_shmem_broadcasted_args`.
@inline function column_shmem_broadcasted(
    bc::Base.Broadcast.Broadcasted{CUDAColumnStencilStyle},
    Nv, i, j, h, v,
)
    shmem_args = column_shmem_broadcasted_args(bc.args, Nv, i, j, h, v)
    return Base.Broadcast.Broadcasted{CUDAColumnStencilStyle}(bc.f, shmem_args, bc.axes)
end
import StaticArrays
# Leaf rule for a single data column.
#
# Creates a `CUDA.CuStaticSharedArray` with element type `eltype(data)` and size
# `(Nv, Nf)`, where `Nf = DataLayouts.number_of_fields(data)`, rebuilds `data` on top of
# that array with `DataLayouts.rebuild`, and returns the rebuilt column.
#
# NOTE(review): only index 1 of `data` is copied into the rebuilt column, and the
# `i`, `j`, `h`, `v` arguments are not used here. Earlier drafts of this method tried a
# per-field copy through `parent(...)[v, f]` and a `StaticArrays.MArray` buffer instead of
# the CUDA array; presumably a full per-level copy is the intended end state -- confirm
# before relying on the contents of the returned column beyond index 1.
@inline function column_shmem_broadcasted(
    data::DataLayouts.DataColumn, ::Val{Nv}, i, j, h, v,
) where {Nv}
    nfields = DataLayouts.number_of_fields(data)
    buffer = CUDA.CuStaticSharedArray(eltype(data), (Nv, nfields))
    shared_col = DataLayouts.rebuild(data, buffer)
    @inbounds shared_col[1] = data[1]
    return shared_col
end
# Leaf rule for a field: take the column of `f` at `(i, j, h)` with `Fields.column`,
# transform its underlying values (`Fields.field_values`) with the data-level
# `column_shmem_broadcasted`, and wrap the result in a new `Fields.Field` that keeps the
# axes of the extracted column.
@inline function column_shmem_broadcasted(
    f::Fields.Field, ::Val{Nv}, i, j, h, v,
) where {Nv}
    column_field = Fields.column(f, i, j, h)
    shmem_values = column_shmem_broadcasted(
        Fields.field_values(column_field), Val(Nv), i, j, h, v,
    )
    return Fields.Field(shmem_values, axes(column_field))
end
# Fallback: any argument without a more specific `column_shmem_broadcasted` method is
# returned unchanged.
@inline function column_shmem_broadcasted(x, Nv, i, j, h, v)
    return x
end
function copyto_stencil_kernel!(out, bc, space, bds, Nq, Nh, ::Val{Nv}) where {Nv}
gid = threadIdx().x + (blockIdx().x - 1) * blockDim().x
if gid ≤ Nv * Nq * Nq * Nh
(li, lw, rw, ri) = bds
(v, i, j, h) = Topologies._get_idx((Nv, Nq, Nq, Nh), gid)
bc_shmem = column_shmem_broadcasted(bc, Val(Nv+1), i,j,h,v) # Extend shmem by one to work for cell faces and centers
This may depend on https://github.com/CliMA/ClimaCore.jl/issues/1754 being fixed.
Doing this generally may require first adding `Nv` to the type parameter space (https://github.com/CliMA/ClimaCore.jl/issues/11).