SpeedyWeather / SpeedyWeather.jl

Play atmospheric modelling like it's LEGO.
https://speedyweather.github.io/SpeedyWeather.jl/dev

AbstractGridArray N-dimensional and GPU ready #520

Closed milankl closed 2 months ago

milankl commented 2 months ago

Accompanying PR to #503 but for RingGrids

TODO

This works already

julia> G = rand(FullGaussianGrid{Float16},1,2)
8×2 FullGaussianArray{Float16, 2, Matrix{Float16}}:
 0.06494  0.6514
 0.9834   0.147
 0.0845   0.536
 0.1787   0.85
 0.0635   0.2065
 0.1011   0.4233
 0.9565   0.7886
 0.673    0.991

julia> G[1] = 1
1

julia> G[2,1] = 0
0

julia> G
8×2 FullGaussianArray{Float16, 2, Matrix{Float16}}:
 1.0     0.6514
 0.0     0.147
 0.0845  0.536
 0.1787  0.85
 0.0635  0.2065
 0.1011  0.4233
 0.9565  0.7886
 0.673   0.991

julia> G[:,1]
8-element, 2-ring FullGaussianGrid{Float16}:
 1.0
 0.0
 0.0845
 0.1787
 0.0635
 0.1011
 0.9565
 0.673

and of course the unicodeplots 😉

[image: unicodeplot of the grid data]
milankl commented 2 months ago

With

@inline eachgrid(grid::AbstractGridArray) = CartesianIndices(size(grid)[2:end])

looping over all grid points can be done as follows

function _scale_lat!(grid::AbstractGridArray{T}, v::AbstractVector) where T
    @boundscheck get_nlat(grid) == length(v) || throw(BoundsError)

    for k in eachgrid(grid)                         # loop over all non-horizontal dimensions
        for (j, ring) in enumerate(eachring(grid))  # latitude rings, indices precomputed in grid.rings
            vj = convert(T, v[j])                   # one scaling factor per ring, converted once
            for ij in ring                          # grid points within that ring
                grid[ij, k] *= vj
            end
        end
    end

    return grid
end

The ring indices are precomputed in grid.rings, and eachgrid is non-allocating, so the whole loop is allocation-free:

julia> @btime RingGrids._scale_lat!($grid, $coslat);
  535.558 μs (0 allocations: 0 bytes)
milankl commented 2 months ago

Broadcasting works now too

julia> G = rand(OctaHEALPixGrid{Float16},1,4)
4×4, 1-ring OctaHEALPixArray{Float16, 2, Matrix{Float16}}:
 0.063   0.6636   0.4683  0.188
 0.1118  0.04248  0.3584  0.9688
 0.2769  0.8354   0.3257  0.10547
 0.988   0.6245   0.545   0.5293

julia> H = rand(OctaHEALPixGrid{Float32},1,4)
4×4, 1-ring OctaHEALPixArray{Float32, 2, Matrix{Float32}}:
 0.495338  0.510038   0.579926  0.504566
 0.550702  0.761179   0.196999  0.333867
 0.328203  0.0446362  0.817386  0.791339
 0.303796  0.884892   0.414292  0.308456

julia> G + H
4×4, 1-ring OctaHEALPixArray{Float32, 2, Matrix{Float32}}:
 0.558326  1.17361   1.04819   0.692554
 0.662518  0.80366   0.555398  1.30262
 0.605059  0.880085  1.14307   0.896808
 1.29159   1.5094    0.959214  0.837753
maximilian-gelbrecht commented 2 months ago

I see "tests" is now checked off in the to-do list, but I'd say we should add some tests with JLArrays to make sure everything also works with other array types and GPU arrays.
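
A minimal sketch of such a test (JLArrays as a CPU stand-in for GPU arrays; the grid constructors, adapt behaviour and the .data field are assumed from the examples in this thread, not the final test suite):

using Test, Adapt, JLArrays
using SpeedyWeather.RingGrids

@testset "RingGrids with JLArray" begin
    G  = rand(OctaHEALPixGrid{Float32}, 1, 2)   # CPU grid, as in the examples in this thread
    JL = adapt(JLArray, G)                      # swap the underlying data for a JLArray
    @test JL isa OctaHEALPixArray               # the grid wrapper survives adapt
    @test JL + JL isa OctaHEALPixArray          # broadcasting must not escape the grid type
    @test Array(JL.data) == G.data              # values are unchanged
end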

milankl commented 2 months ago

Currently the type of the data overrides the type parameters of a grid, i.e.

julia> OctaHEALPixArray(jl(ones(4,2)))
4×2, 1-ring OctaHEALPixArray{Float64, 2, JLArray{Float64, 2}}:
 1.0  1.0
 1.0  1.0
 1.0  1.0
 1.0  1.0

julia> OctaHEALPixArray{Float32, 2, Matrix{Float32}}(jl(ones(4,2)))
4×2, 1-ring OctaHEALPixArray{Float64, 2, JLArray{Float64, 2}}:
 1.0  1.0
 1.0  1.0
 1.0  1.0
 1.0  1.0

whereas Julia arrays would include a conversion:

julia> Matrix{Float32}(zeros(Float64,2,2))
2×2 Matrix{Float32}:
 0.0  0.0
 0.0  0.0

Should we do the same, @maximilian-gelbrecht?
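
For reference, a rough sketch of the conversion option (a hypothetical outer constructor, not the current implementation; it would need to be checked against the existing constructors for ambiguities):

# hypothetical: let the fully parameterised constructor convert its input,
# analogous to Matrix{Float32}(zeros(Float64, 2, 2)) above
function OctaHEALPixArray{T, N, ArrayType}(data::AbstractArray) where {T, N, ArrayType <: AbstractArray{T, N}}
    return OctaHEALPixArray(ArrayType(data))    # ArrayType(data) converts eltype and array type
end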

milankl commented 2 months ago

I keep having problems with this bit

function Base.similar(bc::Broadcasted{AbstractGridArrayStyle{N, Grid}}, ::Type{T}) where {N, Grid, T}
    # this escapes for Array and JLArray
    # if isstructurepreserving(bc) || fzeropreserving(bc) 
    #     return Grid(Array{T}(undef, size(bc)...))
    # end
    # return similar(convert(Broadcasted{DefaultArrayStyle{ndims(bc)}}, bc), T)

    # this doesn't escape for Array but for JLArray
    return Grid(Array{T}(undef, size(bc)...))
end

If I use the commented-out version (as adapted from #503) then both Array and JLArray escape (i.e. grid + grid returns an array, not a grid); the currently uncommented version at least works for Array:

julia> G = rand(OctaHEALPixGrid{Float16}, 1, 2)
4×2, 1-ring OctaHEALPixArray{Float16, 2, Matrix{Float16}}:
 0.3022  0.1201
 0.7876  0.5757
 0.6616  0.4717
 0.439   0.1323

julia> G + G   # doesn't escape, good!
4×2, 1-ring OctaHEALPixArray{Float16, 2, Matrix{Float16}}:
 0.6045  0.2402
 1.575   1.151
 1.323   0.9434
 0.878   0.2646

julia> JL = adapt(JLArray, G)
4×2, 1-ring OctaHEALPixArray{Float16, 2, JLArray{Float16, 2}}:
 0.3022  0.1201
 0.7876  0.5757
 0.6616  0.4717
 0.439   0.1323

julia> JL + JL     # escapes, baaad!
4×2 JLArray{Float16, 2}:
 0.6045  0.2402
 1.575   1.151
 1.323   0.9434
 0.878   0.2646
maximilian-gelbrecht commented 2 months ago

I'll have a look at the broadcast (and the rest of the updated PR) tomorrow. The isstructurepreserving(bc) || fzeropreserving(bc) check is something that Base.Broadcast defines specifically for structured matrices, like LinearAlgebra.jl's lower triangular matrices, so we definitely don't need it here.

maximilian-gelbrecht commented 2 months ago

Ok, I already did the broadcast fix now. I didn't test it properly though. But your code from above works and JL + JL returns a RingGrid. For the rest I'll do a review tomorrow from the train and then I am on vacation.

milankl commented 2 months ago

> Ok, I already did the broadcast fix now. I didn't test it properly though. But your code from above works and JL + JL returns a RingGrid. For the rest I'll do a review tomorrow from the train and then I am on vacation.

Awesome!!! Thanks so much!! I see you have "simply" defined a new broadcasting style for GPUs, which makes sense, but I don't understand why that's not necessary for LowerTriangularArrays??

This is perfect timing then, because once we're happy with these pull requests I can

and once you're back we can then discuss how to make the kernels ready for GPU.

maximilian-gelbrecht commented 2 months ago

With LowerTriangularArrays we have + and * explicitly defined and not via broadcast, and all the other operations that we test are of the type LT .= A .+ B or something like that, so I guess the similar is never hit there. Maybe we could still add it, but we tested all situations that appear in the dynamical core, and things like + etc. work as mentioned, so we don't really seem to need it.
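
For illustration, a sketch of that pattern (using the grid types from this thread as a stand-in, not the actual LowerTriangularArrays code): an explicitly defined + acts on the wrapped data and rewraps the result, so broadcast and its similar are never involved for that operation.

# sketch only: + forwards to the underlying arrays and rewraps the result,
# bypassing broadcast (and therefore Base.similar(::Broadcasted, ...)) entirely
Base.:+(A::OctaHEALPixArray, B::OctaHEALPixArray) = OctaHEALPixArray(A.data + B.data)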

milankl commented 2 months ago

I hit a few scalar indexing errors that I didn't know how to fix generically via broadcasting, so I've just defined those methods to propagate to grid.data.

Now one can do grid == grid or grid .== grid even with JLArray without scalar indexing. I don't know what's required to solve these more generally without a new method per function, but that's all we need right now for testing.
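
A sketch of the kind of method meant here (assuming the .data field as above; the actual set of forwarded methods may differ):

# sketch: forward == to the underlying arrays so that JLArray (and real GPU arrays)
# are compared array-wise instead of element-wise via scalar indexing;
# a real method would also check that both grids are of the same grid type
Base.:(==)(G1::AbstractGridArray, G2::AbstractGridArray) = G1.data == G2.data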