CliMA / ClimaCore.jl

CliMA model dycore
https://clima.github.io/ClimaCore.jl/dev
Apache License 2.0
86 stars 8 forks source link

With ClimaCore 0.14.6, the remapping module in ClimaCoupler produces incorrect results #1772

Closed Sbozzolo closed 3 months ago

Sbozzolo commented 3 months ago

The ClimaCoupler test suite passes with ClimaCore 0.14.5, but when updating to 0.14.6, the remapping tests now fail. See example https://github.com/CliMA/ClimaCoupler.jl/actions/runs/9325823164/job/25675049295

I don't know if this is an issue with ClimaCoupler doing something that it shouldn't, or something that was missed in recent changes in ClimaCore. I think the only change that could have affected this is the switch to using `Nv` (e.g., https://github.com/CliMA/ClimaCore.jl/commit/1b84475b2405b8e0d2014e32fda5a7914cf9794e).

charleskawczynski commented 3 months ago

Hm, that's unfortunate, not sure what could have gone wrong. It's possible that other updated dependencies are responsible. I'll try to make a reproducer. Here's what I have at the moment (still need to try it out):

#=
julia --project=.buildkite
julia --project=experiments/ClimaEarth
using Revise; include("../coupler_reproducer.jl")
=#
# Force the CPU device; this is set before `ClimaComms` is imported below.
# NOTE(review): presumably read by ClimaComms when selecting a device (e.g. by
# `ClimaComms.context()` in the test) — TODO confirm it must precede the import.
ENV["CLIMACOMMS_DEVICE"] = "CPU"
import ClimaComms
import NCDatasets
import ClimaCoreTempestRemap as CCTR
import ClimaCore: Domains, Meshes, Topologies, Spaces, Fields, Geometry, InputOutput
import Dates
import JLD2
"""
    strdate_to_datetime(strdate::String)

Convert a fixed-width `"yyyymmdd"` date string into a `Dates.DateTime` at midnight.
Only the first eight characters are read; anything after them is ignored.
"""
function strdate_to_datetime(strdate::String)
    year = parse(Int, strdate[1:4])
    month = parse(Int, strdate[5:6])
    day = parse(Int, strdate[7:8])
    return Dates.DateTime(year, month, day)
end

"""
    read_remapped_field(name::Symbol, datafile_latlon::String, lev_name = "z")

Read variable `name` and its coordinates from the NetCDF file `datafile_latlon`.

Returns `(var, coords)`, where `coords` is the named tuple `(; lon, lat, lev)`.
When `lev_name` is not among the dataset's keys, `lev` is the sentinel `Float64(-999)`.
"""
function read_remapped_field(name::Symbol, datafile_latlon::String, lev_name = "z")
    return NCDatasets.NCDataset(datafile_latlon, "r") do dataset
        longitudes = Array(dataset["lon"])
        latitudes = Array(dataset["lat"])
        levels = if lev_name in keys(dataset)
            Array(dataset[lev_name])
        else
            Float64(-999)
        end
        values = Array(dataset[name])
        (values, (; lon = longitudes, lat = latitudes, lev = levels))
    end
end

"""
    create_space(FT; comms_ctx, R, ne, polynomial_degree, nz, height)

Build an equiangular cubed-sphere spectral-element space.
- The sphere has radius `R` and `ne` elements per panel edge.
- The quadrature is GLL with `polynomial_degree + 1` points.

When `nz > 1`, the 2D space is extruded over `nz` vertical elements spanning `[0, height]`,
and the resulting `ExtrudedFiniteDifferenceSpace` (built on center points) is returned.
Otherwise the 2D `SpectralElementSpace2D` is returned.
"""
function create_space(
    FT;
    comms_ctx = ClimaComms.SingletonCommsContext(),
    R = FT(6371e3),
    ne = 4,
    polynomial_degree = 3,
    nz = 1,
    height = FT(100),
)
    mesh = Meshes.EquiangularCubedSphere(Domains.SphereDomain(R), ne)

    # singleton contexts use a plain 2D topology; any other context gets a distributed one
    topology = if comms_ctx isa ClimaComms.SingletonCommsContext
        Topologies.Topology2D(comms_ctx, mesh, Topologies.spacefillingcurve(mesh))
    else
        Topologies.DistributedTopology2D(comms_ctx, mesh, Topologies.spacefillingcurve(mesh))
    end

    horizontal = Spaces.SpectralElementSpace2D(topology, Spaces.Quadratures.GLL{polynomial_degree + 1}())

    nz > 1 || return horizontal

    column_domain = Domains.IntervalDomain(
        Geometry.ZPoint{FT}(0),
        Geometry.ZPoint{FT}(height);
        boundary_names = (:bottom, :top),
    )
    column_mesh = Meshes.IntervalMesh(column_domain, nelems = nz)
    column_topology = Topologies.IntervalTopology(comms_ctx, column_mesh)
    return Spaces.ExtrudedFiniteDifferenceSpace(
        horizontal,
        Spaces.CenterFiniteDifferenceSpace(column_topology),
    )
end

"""
    write_datafile_cc(datafile_cc, field, name)

Create the NetCDF file `datafile_cc`, define the CGLL coordinates of `axes(field)` in it,
and store `field` in a `Float64` variable called `name` (written as `var[:, 1]`).
Returns `nothing`.
"""
function write_datafile_cc(datafile_cc, field, name)
    field_space = axes(field)
    NCDatasets.NCDataset(datafile_cc, "c") do dataset
        CCTR.def_space_coord(dataset, field_space; type = "cgll")
        variable = NCDatasets.defVar(dataset, name, Float64, field_space)
        variable[:, 1] = field
        return nothing
    end
end

"""
    remap_field_cgll_to_rll(name, field, remap_tmpdir, datafile_rll; nlat = 90, nlon = 180)

Remap `field` from its CGLL nodes onto an `nlat` x `nlon` lat-lon grid with TempestRemap
(via ClimaCoreTempestRemap), writing the result as variable `string(name)` into `datafile_rll`.

These scratch files are written into `remap_tmpdir`:
- the cubed-sphere, lat-lon and overlap meshes;
- the remapping weights;
- the CGLL data file.

Returns whatever `CCTR.apply_remap` returns.
"""
function remap_field_cgll_to_rll(name, field::Fields.Field, remap_tmpdir, datafile_rll; nlat = 90, nlon = 180)
    space = axes(field)
    # spaces that expose a `topology` property are used as-is; otherwise take the
    # horizontal part of the (presumably extruded) space
    hspace = :topology in propertynames(space) ? space : Spaces.horizontal_space(space)
    Nq = Spaces.Quadratures.polynomial_degree(Spaces.quadrature_style(hspace)) + 1

    # write out our cubed sphere mesh
    meshfile_cc = joinpath(remap_tmpdir, "mesh_cubedsphere.g")
    CCTR.write_exodus(meshfile_cc, Spaces.topology(hspace))

    meshfile_rll = joinpath(remap_tmpdir, "mesh_rll.g")
    CCTR.rll_mesh(meshfile_rll; nlat = nlat, nlon = nlon)

    meshfile_overlap = joinpath(remap_tmpdir, "mesh_overlap.g")
    CCTR.overlap_mesh(meshfile_overlap, meshfile_cc, meshfile_rll)

    weightfile = joinpath(remap_tmpdir, "remap_weights.nc")
    CCTR.remap_weights(weightfile, meshfile_cc, meshfile_rll, meshfile_overlap; in_type = "cgll", in_np = Nq)

    datafile_cc = joinpath(remap_tmpdir, "datafile_cc.nc")
    write_datafile_cc(datafile_cc, field, name)

    return CCTR.apply_remap( # TODO: this can be done online
        datafile_rll,
        datafile_cc,
        weightfile,
        [string(name)],
    )
end

"""
    cgll2latlonz(field; DIR = "cgll2latlonz_dir", nlat = 360, nlon = 720, clean_dir = true)

Remap `field` from its CGLL nodes to an `nlat` x `nlon` lat-lon grid and return
`(data, coords)` as read back by `read_remapped_field`.

Intermediate files are written under `DIR`, which is created if missing. When `clean_dir`
is `true` they are removed afterwards, even if the remapping throws. A `DIR` that already
existed before the call is never deleted. In that case the work happens in a private
temporary subdirectory of `DIR`, and only that subdirectory is removed.
"""
function cgll2latlonz(field; DIR = "cgll2latlonz_dir", nlat = 360, nlon = 720, clean_dir = true)
    dir_existed = isdir(DIR)
    dir_existed || mkpath(DIR)
    # `rm(DIR; recursive = true)` on a directory we did not create can wipe unrelated
    # data (e.g. a git checkout), so only ever delete what this call created.
    workdir = (clean_dir && dir_existed) ? mktempdir(DIR) : DIR
    try
        datafile_latlon = joinpath(workdir, "remapped_unnamed.nc")
        remap_field_cgll_to_rll(:var, field, workdir, datafile_latlon, nlat = nlat, nlon = nlon)
        new_data, coords = read_remapped_field(:var, datafile_latlon)
        return new_data, coords
    finally
        clean_dir && rm(workdir; recursive = true, force = true)
    end
end

"""
    get_time(ds)

Return the dates stored in the dataset `ds` as a vector of `Dates.DateTime`.
- A `"time"` dimension takes precedence and is converted with `Dates.DateTime`.
- Otherwise a `"date"` dimension is read as `yyyymmdd` numbers via `strdate_to_datetime`.
- If neither exists, a warning is logged and `[Dates.DateTime(0)]` is returned as a placeholder.
"""
function get_time(ds)
    available = keys(ds.dim)
    "time" in available && return Dates.DateTime.(Array(ds["time"]))
    "date" in available && return strdate_to_datetime.(string.(Int.(Array(ds["date"]))))
    @warn "No dates available in input data file"
    return [Dates.DateTime(0)]
end

# Coordinates of data on an extruded (3D) space: its dates, then its "z" levels.
get_coords(ds, ::Spaces.ExtrudedFiniteDifferenceSpace) = (get_time(ds), Array(ds["z"]))

# Coordinates of data on a horizontal (2D) space: only its dates.
get_coords(ds, ::Spaces.SpectralElementSpace2D) = (get_time(ds),)

"""
    write_to_hdf5(REGRID_DIR, hd_outfile_root, time, field, varname, comms_ctx)

Write `field` under the name `string(varname)` to the HDF5 file
`joinpath(REGRID_DIR, hd_outfile_root * "_" * string(time) * ".hdf5")`.
The file gets a `"unix time"` attribute holding `Dates.datetime2unix(time)`.

The writer is closed even if writing fails. Returns `nothing`.
"""
function write_to_hdf5(REGRID_DIR, hd_outfile_root, time, field, varname, comms_ctx)
    t = Dates.datetime2unix.(time)
    hdfwriter =
        InputOutput.HDF5Writer(joinpath(REGRID_DIR, hd_outfile_root * "_" * string(time) * ".hdf5"), comms_ctx)
    try
        InputOutput.HDF5.write_attribute(hdfwriter.file, "unix time", t) # TODO: a better way to write metadata, CMIP convention
        InputOutput.write!(hdfwriter, field, string(varname))
    finally
        # release the file handle on both success and failure
        Base.close(hdfwriter)
    end
    return nothing
end

"""
    reshape_cgll_sparse_to_field!(field, in_array, R, ::Spaces.ExtrudedFiniteDifferenceSpace)

Scatter remapped CGLL values into `field`:
1. Zero the parent array of `field`.
2. For every level `z` and every entry `n` of `R.row_indices` (with `row = R.row_indices[n]`),
   copy `in_array[row, z]` into all components `f` of node
   `(R.target_idxs[1][n], R.target_idxs[2][n], R.target_idxs[3][n])`.
3. Apply `Topologies.dss!` to the field values, to fill nodes shared between elements.

The parent array is indexed as `[level, i, j, f, elem]`. Returns the result of `Topologies.dss!`.
"""
function reshape_cgll_sparse_to_field!(
    field::Fields.Field,
    in_array::SubArray,
    R,
    ::Spaces.ExtrudedFiniteDifferenceSpace,
)
    field_array = parent(field)

    fill!(field_array, zero(eltype(field_array)))
    Nf = size(field_array, 4)
    Nz = size(field_array, 1)

    target_i, target_j, target_e = R.target_idxs # cgll_x, cgll_y, elem
    # Iterate rows in the same order as before for every level, so that if several rows
    # hit the same node, the last one still wins.
    for (n, row) in enumerate(R.row_indices)
        it, jt, et = target_i[n], target_j[n], target_e[n] # invariant across levels
        for z in 1:Nz, f in 1:Nf
            field_array[z, it, jt, f, et] = in_array[row, z]
        end
    end
    # broadcast to the redundant nodes using unweighted dss
    return Topologies.dss!(Fields.field_values(field), Spaces.topology(axes(field)))
end

"""
    hdwrite_regridfile_rll_to_cgll(FT, REGRID_DIR, datafile_rll, varname, space; hd_outfile_root = "data_cgll", mono = false)

Regrid variable `varname` from the lat-lon NetCDF file `datafile_rll` onto the CGLL nodes
of `space`, using TempestRemap through ClimaCoreTempestRemap.

Output in `REGRID_DIR`:
- one HDF5 file per time entry, named `hd_outfile_root * "_" * string(time) * ".hdf5"`;
- a `hd_outfile_root * "_times.jld2"` file containing the times.

Intermediate files also go into `REGRID_DIR`: the meshes, the weights and the remapped
`.g` data file. Their names carry a `_mono` suffix when `mono = true`. An existing remapped
data file is reused, with a warning. Regridding is always done on an undistributed
single-threaded CPU copy of `space`.
"""
function hdwrite_regridfile_rll_to_cgll(
    FT,
    REGRID_DIR,
    datafile_rll,
    varname,
    space;
    hd_outfile_root = "data_cgll",
    mono = false,
)
    out_type = "cgll"

    # Use the root directly instead of slicing ".nc" off a concatenated name: slicing by
    # byte offset throws for roots ending in a multi-byte character.
    outfile_root = mono ? hd_outfile_root * "_mono" : hd_outfile_root
    datafile_cgll = joinpath(REGRID_DIR, outfile_root * ".g")

    meshfile_rll = joinpath(REGRID_DIR, outfile_root * "_mesh_rll.g")
    meshfile_cgll = joinpath(REGRID_DIR, outfile_root * "_mesh_cgll.g")
    meshfile_overlap = joinpath(REGRID_DIR, outfile_root * "_mesh_overlap.g")
    weightfile = joinpath(REGRID_DIR, outfile_root * "_remap_weights.nc")

    if space isa Spaces.ExtrudedFiniteDifferenceSpace
        space2d = Spaces.horizontal_space(space)
    else
        space2d = space
    end

    # It doesn't make sense to regrid with GPUs/MPI processes, so rebuild the target
    # space on a single-threaded CPU context
    cpu_singleton_context = ClimaComms.SingletonCommsContext(ClimaComms.CPUSingleThreaded())

    topology = Topologies.Topology2D(
        cpu_singleton_context,
        Spaces.topology(space2d).mesh,
        Topologies.spacefillingcurve(Spaces.topology(space2d).mesh),
    )
    Nq = Spaces.Quadratures.polynomial_degree(Spaces.quadrature_style(space2d)) + 1

    space2d_undistributed = Spaces.SpectralElementSpace2D(topology, Spaces.Quadratures.GLL{Nq}())

    if space isa Spaces.ExtrudedFiniteDifferenceSpace
        vert_center_space = Spaces.CenterFiniteDifferenceSpace(Spaces.vertical_topology(space).mesh)
        space_undistributed = Spaces.ExtrudedFiniteDifferenceSpace(space2d_undistributed, vert_center_space)
    else
        space_undistributed = space2d_undistributed
    end
    if !isfile(datafile_cgll)
        isdir(REGRID_DIR) || mkpath(REGRID_DIR)

        nlat, nlon = NCDatasets.NCDataset(datafile_rll) do ds
            (ds.dim["lat"], ds.dim["lon"])
        end
        # write lat-lon mesh
        CCTR.rll_mesh(meshfile_rll; nlat = nlat, nlon = nlon)

        # write cgll mesh, overlap mesh and weight file
        CCTR.write_exodus(meshfile_cgll, topology)
        CCTR.overlap_mesh(meshfile_overlap, meshfile_rll, meshfile_cgll)

        # 'in_np = 1' and 'mono = true' arguments ensure mapping is conservative and monotone
        # Note: for a kwarg not followed by a value, set it to true here (i.e. pass 'mono = true' to produce '--mono')
        # Note: out_np = degrees of freedom = polynomial degree + 1
        kwargs = (; out_type = out_type, out_np = Nq)
        kwargs = mono ? (; (kwargs)..., in_np = 1, mono = mono) : kwargs
        CCTR.remap_weights(weightfile, meshfile_rll, meshfile_cgll, meshfile_overlap; kwargs...)
        CCTR.apply_remap(datafile_cgll, datafile_rll, weightfile, [varname])
    else
        @warn "Using the existing $datafile_cgll : check topology is consistent"
    end

    # read the remapped file with sparse matrices
    offline_outvector, coords = NCDatasets.NCDataset(datafile_cgll, "r") do ds_wt
        (
            # read the data in, and remove missing type (will error if missing data is present)
            offline_outvector = NCDatasets.nomissing(Array(ds_wt[varname])[:, :, :]), # ncol, z, times
            coords = get_coords(ds_wt, space),
        )
    end

    times = coords[1]

    # weightfile info needed to populate all nodes and save into fields with
    #  sparse matrices (only the target row indices are used)
    _, _, row_indices = NCDatasets.NCDataset(weightfile, "r") do ds_wt
        (Array(ds_wt["S"]), Array(ds_wt["col"]), Array(ds_wt["row"]))
    end

    # map each 1-based target row of the weight file to its ((i, j), element) CGLL node
    target_unique_idxs =
        out_type == "cgll" ? collect(Spaces.unique_nodes(space2d_undistributed)) :
        collect(Spaces.all_nodes(space2d_undistributed))
    target_unique_idxs_i = map(row -> target_unique_idxs[row][1][1], row_indices)
    target_unique_idxs_j = map(row -> target_unique_idxs[row][1][2], row_indices)
    target_unique_idxs_e = map(row -> target_unique_idxs[row][2], row_indices)
    target_unique_idxs = (target_unique_idxs_i, target_unique_idxs_j, target_unique_idxs_e)

    R = (; target_idxs = target_unique_idxs, row_indices = row_indices)

    offline_field = Fields.zeros(FT, space_undistributed)

    offline_fields = ntuple(x -> similar(offline_field), length(times))

    # fill one field per time entry; the time axis is the last dimension of the data
    foreach(eachindex(times)) do x
        reshape_cgll_sparse_to_field!(
            offline_fields[x],
            selectdim(offline_outvector, length(coords) + 1, x),
            R,
            space,
        )
    end

    foreach(eachindex(times)) do x
        write_to_hdf5(REGRID_DIR, hd_outfile_root, times[x], offline_fields[x], varname, cpu_singleton_context)
    end
    JLD2.jldsave(joinpath(REGRID_DIR, hd_outfile_root * "_times.jld2"); times = times)
end

using Test

# Round-trip regression test for CliMA/ClimaCore.jl#1772:
# 1. regrid a lat-lon field onto CGLL nodes;
# 2. regrid it back to lat-lon;
# 3. compare the result against the original data.
# All scratch files live in a temporary directory.
mktempdir() do REGRID_DIR
    for FT in (Float32, Float64)
        @testset "test hdwrite_regridfile_rll_to_cgll 3d space for FT=$FT" begin
            comms_ctx = ClimaComms.context()
            # Test setup
            R = FT(6371e3)
            space = create_space(FT, nz = 2, ne = 16, R = R)
            # lat-lon dataset
            data = ones(720, 360, 2, 3) # (lon, lat, z, time)
            time = [19000101.0, 19000201.0, 19000301.0]
            lats = collect(range(-90, 90, length = 360))
            lons = collect(range(-180, 180, length = 720))
            z = [1000.0, 2000.0]
            # make the data vary with latitude only (constant in lon, z and time)
            data = reshape(sin.(lats * π / 90)[:], 1, :, 1, 1) .* data
            varname = "sinlat"

            # save the lat-lon data to a netcdf file in the required format for TempestRemap
            datafile_rll = joinpath(REGRID_DIR, "lat_lon_data.nc")
            NCDatasets.NCDataset(datafile_rll, "c") do ds
                NCDatasets.defDim(ds, "lat", size(lats)...)
                NCDatasets.defDim(ds, "lon", size(lons)...)
                NCDatasets.defDim(ds, "z", size(z)...)
                NCDatasets.defDim(ds, "date", size(time)...)

                NCDatasets.defVar(ds, "lon", lons, ("lon",))
                NCDatasets.defVar(ds, "lat", lats, ("lat",))
                NCDatasets.defVar(ds, "z", z, ("z",))
                NCDatasets.defVar(ds, "date", time, ("date",))

                NCDatasets.defVar(ds, varname, data, ("lon", "lat", "z", "date"))
            end

            hd_outfile_root = "data_cgll_test"
            hdwrite_regridfile_rll_to_cgll(
                FT,
                REGRID_DIR,
                datafile_rll,
                varname,
                space,
                mono = true,
                hd_outfile_root = hd_outfile_root,
            )

            # read in data on CGLL grid from the last saved date
            date1 = strdate_to_datetime(string(Int(time[end])))
            cgll_path = joinpath(REGRID_DIR, "$(hd_outfile_root)_$date1.hdf5")
            hdfreader = InputOutput.HDF5Reader(cgll_path, comms_ctx)
            T_cgll = InputOutput.read_field(hdfreader, varname)
            Base.close(hdfreader)

            # regrid back to lat-lon; keep the scratch directory inside REGRID_DIR so the
            # recursive cleanup in cgll2latlonz never touches the working directory
            T_rll, _ = cgll2latlonz(T_cgll; DIR = joinpath(REGRID_DIR, "cgll2latlonz_dir"))

            # check consistency across z-levels
            @test T_rll[:, :, 1] == T_rll[:, :, 2]

            # check consistency of CGLL remapped data with original data
            @test all(isapprox.(extrema(data), extrema(parent(T_cgll)), atol = 1e-2))

            # check consistency of lat-lon remapped data with original data
            @test all(isapprox.(extrema(data), extrema(T_rll), atol = 1e-3))

            # visual inspection
            # Plots.plot(T_cgll) # using ClimaCorePlots
            # Plots.contourf(Array(T_rll)[:,1])
        end
    end
end
charleskawczynski commented 3 months ago

And I just deleted my local git folder because of the `rm(DIR; recursive = true)` call in the reproducer's `cgll2latlonz`.

charleskawczynski commented 3 months ago

I should be able to recover my previous states with dropbox. Just opened https://github.com/CliMA/ClimaCoupler.jl/issues/832

Sbozzolo commented 3 months ago

Hope you managed to recover your folder!

Note that Regridder is a ClimaCoupler module (using ClimaCoreTempestRemap internally), so unfortunately there are more layers to it than what you have in the previous message.

Sbozzolo commented 3 months ago

This is probably a problem with ClimaCoupler, could it be due to this function? https://github.com/CliMA/ClimaCoupler.jl/blob/6266a0fe0ecee4d08af81f3d975cbb99441ab21e/src/Regridder.jl#L87-L113

charleskawczynski commented 3 months ago

Ah, bummer. And perhaps that is the cause: that function does seem to make a lot of assumptions and to use a lot of internals.

> Hope you managed to recover your folder!

I was, thanks to Dropbox's rewind feature.

charleskawczynski commented 3 months ago

Alright, the reproducer now works

charleskawczynski commented 3 months ago

I git bisected to https://github.com/CliMA/ClimaCore.jl/commit/1b84475b2405b8e0d2014e32fda5a7914cf9794e, so that is indeed the culprit. I'm not sure I'll have the time today to track it down, but it shouldn't take too long with the reproducer and the specific commit.

charleskawczynski commented 3 months ago

I did find that I missed using Nv in a few places:

# Generic fallback: builds the whole size tuple, then indexes into it
Base.size(data::AbstractData, i::Integer) = size(data)[i] # add Base.@propagate_inbounds to inline
# A column's length is the first dimension of its parent array (presumably Nv — TODO confirm)
Base.length(data::DataColumn) = size(parent(data), 1)
# A column reports a 5-tuple size with its length in the 4th slot (presumably the
# vertical slot of the layout's size convention — TODO confirm)
Base.size(data::DataColumn) = (1, 1, 1, length(data), 1)

(these are performance optimizations, not bugs)

charleskawczynski commented 3 months ago

Alright, I found it. I missed adding a type parameter in dss!. I'm surprised our tests didn't catch it. I'll fix it and add a patch release. I need to think about how we can add a test.