alan-turing-institute / deepsensor

A Python package for tackling diverse environmental prediction tasks with NPs.
https://alan-turing-institute.github.io/deepsensor/
MIT License

Density channel stripes in `ConvNP` SetConv context encoding when lat/lon resolutions differ in gridded data #77

Closed: tom-andersson closed this 8 months ago

tom-andersson commented 8 months ago

See title. The issue is that the size (length scale) of the SetConv Gaussian blobs is currently set to the finer of the two grid resolutions by the `compute_xarray_resolution` method. See the example code and encoding figure below. These stripe artefacts could potentially filter through to the ConvNP predictions.

```python
# Importing deepsensor.torch selects the PyTorch backend
import deepsensor.torch
import deepsensor.plot
from deepsensor.data.processor import DataProcessor
from deepsensor.data.loader import TaskLoader
from deepsensor.model.convnp import ConvNP

import xarray as xr

# Load raw data
ds_raw = xr.tutorial.open_dataset("air_temperature")

# Normalise data (coordinates and variables)
data_processor = DataProcessor(x1_name="lat", x2_name="lon")
ds = data_processor(ds_raw)

# Set up task loader
task_loader = TaskLoader(context=ds, target=ds)

# Set up model with a large internal resolution to sample the SetConv context encoding finely
model = ConvNP(data_processor, task_loader, points_per_unit=300)

# Generate a Task with all data passed as context (and target)
task = task_loader("2014-01-01", "all", "all")

# Plot the context encoding and save it to disk
fig = deepsensor.plot.context_encoding(model, task, task_loader, size=5)
fig.savefig("encoding.png")
```

[Figure: `encoding.png`, the ConvNP SetConv context encoding for the task above]
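To illustrate why tying the blob size to the finer of the two resolutions produces stripes, here is a minimal NumPy sketch of a SetConv-style density channel. It assumes, as a simplification, that the density channel is an unnormalised sum of isotropic Gaussian bumps centred on every grid cell, with the length scale set equal to the finer grid spacing. The helpers `density_channel` and `ripple` and the spacings used are hypothetical, not DeepSensor's implementation.

```python
import numpy as np


def density_channel(x_ctx, x_eval, length_scale):
    """Sum of 1D Gaussian bumps centred on context points, evaluated at x_eval.

    Simplified, hypothetical stand-in for a SetConv density channel.
    """
    d2 = (x_eval[:, None] - x_ctx[None, :]) ** 2
    return np.exp(-0.5 * d2 / length_scale**2).sum(axis=1)


def ripple(spacing, length_scale, n_cells=50):
    """Relative peak-to-trough variation of the density between two grid cells."""
    x_ctx = np.arange(n_cells) * spacing
    # Evaluate across one cell in the middle of the grid to avoid edge effects
    centre = x_ctx[n_cells // 2]
    x_eval = centre + np.linspace(0.0, spacing, 101)
    dens = density_channel(x_ctx, x_eval, length_scale)
    return (dens.max() - dens.min()) / dens.max()


# Illustrative grid spacings in the two spatial dimensions
fine_spacing, coarse_spacing = 1.0, 2.2
length_scale = min(fine_spacing, coarse_spacing)  # "finer of the two resolutions"

# A 2D isotropic Gaussian summed over a rectangular grid factorises into a
# product of 1D sums, so the ripple along each axis can be examined separately.
print(f"ripple along fine axis:   {ripple(fine_spacing, length_scale):.2e}")
print(f"ripple along coarse axis: {ripple(coarse_spacing, length_scale):.2e}")
```

Along the fine axis the bumps overlap enough that the density is essentially flat, whereas along the coarse axis it dips by several percent between cells, which shows up as stripes running across the coarse dimension.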

tom-andersson commented 8 months ago

Using the coarser of the two resolutions isn't ideal either, because this overly blurs the data:

[Figure: context encoding with the SetConv length scale set to the coarser resolution]

I suggest we add a new feature to the DataProcessor which checks whether gridded xarray data has differing resolutions in its two spatial dimensions and, if so, resamples it to a common resolution. This would ensure the Gaussian blobs at grid cell locations are spaced uniformly, removing the stripes.
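A minimal sketch of the resampling idea proposed here, using NumPy linear interpolation applied separably along each axis. The helper names, the choice of the finer spacing as the common resolution, and linear interpolation are all assumptions for illustration; with real xarray data, xarray's own interpolation methods would be the more natural tool.

```python
import numpy as np


def _regular_axis(x, res):
    """Regularly spaced axis starting at x[0] with step res, not extending past x[-1]."""
    n = int(np.floor((x[-1] - x[0]) / res + 1e-9)) + 1
    return x[0] + res * np.arange(n)


def resample_to_common_resolution(values, x1, x2):
    """Linearly resample a 2D grid so both axes share the finer grid spacing.

    Hypothetical helper (not part of DeepSensor). `values` has shape
    (len(x1), len(x2)); x1 and x2 are ascending, regularly spaced coordinates.
    """
    res = min(np.diff(x1).mean(), np.diff(x2).mean())
    new_x1 = _regular_axis(x1, res)
    new_x2 = _regular_axis(x2, res)
    # Interpolate each row along x2, then each resulting column along x1
    rows = np.stack([np.interp(new_x2, x2, row) for row in values])
    out = np.stack([np.interp(new_x1, x1, col) for col in rows.T], axis=1)
    return out, new_x1, new_x2


# Toy grid: 5.0 spacing in x1, 2.5 spacing in x2
x1 = np.array([0.0, 5.0, 10.0])
x2 = np.arange(0.0, 10.0 + 1e-9, 2.5)
values = x1[:, None] + x2[None, :]
out, new_x1, new_x2 = resample_to_common_resolution(values, x1, x2)
print(out.shape)  # (5, 5)
```

Because the toy field is linear, linear interpolation reproduces it exactly. For real data, upsampling the coarser axis adds no new information; it only evens out the spacing of the Gaussian blobs.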

tom-andersson commented 8 months ago

Ah, I just realised the xr.tutorial.open_dataset('air_temperature') dataset actually has equal resolutions in lat & lon: 2.5 degrees in both cases. The stripes were actually caused by the DataProcessor squashing the data more in the lon direction than in the lat direction when normalising the coordinates to [0, 1]. See the normalisation parameters stored in the DataProcessor:

```
>>> print(data_processor)
DataProcessor with normalisation params:
{'air': {'method': 'mean_std',
         'params': {'mean': 281.255126953125, 'std': 16.320409774780273}},
 'coords': {'time': {'name': 'time'},
            'x1': {'map': (15.0, 75.0), 'name': 'lat'},
            'x2': {'map': (200.0, 330.0), 'name': 'lon'}}}
```

So, in fact, the stripes are caused by the DataProcessor inducing differing spatial resolutions in normalised space. In other words, the DataProcessor maps data spanning a rectangular region onto a square region, so grid cells in the 'longer' dimension end up closer together after normalisation (here lon, which spans 130 degrees versus 60 degrees for lat).
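The arithmetic behind this is short. Assuming each coordinate map `(a, b)` sends raw value `a` to 0 and `b` to 1, a raw grid spacing is divided by the width `b - a` of the map. A plain-Python sketch using the values printed above (the helper name is hypothetical):

```python
# Coordinate maps printed by the DataProcessor: raw value a -> 0, b -> 1
x1_map = (15.0, 75.0)    # lat
x2_map = (200.0, 330.0)  # lon
raw_res = 2.5  # degrees, identical in lat and lon for this dataset


def normalised_spacing(raw_spacing, coord_map):
    """Grid spacing after the linear map x -> (x - a) / (b - a)."""
    a, b = coord_map
    return raw_spacing / (b - a)


lat_res = normalised_spacing(raw_res, x1_map)
lon_res = normalised_spacing(raw_res, x2_map)
print(f"lat: {lat_res:.4f}, lon: {lon_res:.4f}, ratio: {lat_res / lon_res:.2f}")
# lat: 0.0417, lon: 0.0192, ratio: 2.17
```

So in normalised space the lat spacing is about 2.17 times the lon spacing, even though both are 2.5 degrees in the raw data.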

Note: The stripe artefacts would also occur if the raw data itself had differing resolutions in its two spatial dimensions, as assumed in the OP, but this will rarely be the case because gridded data typically has the same resolution in both spatial dimensions.

If we instead explicitly set the DataProcessor's linear coordinate maps so that they apply the same scaling in both spatial dimensions (i.e. both maps have the same width), the stripes disappear: data_processor = DataProcessor(x1_name="lat", x2_name="lon", x1_map=(15, 75), x2_map=(200, 260)). Encoding below:

[Figure: context encoding with equal-scaling coordinate maps]

I'm revising my suggestion about resampling gridded data in the DataProcessor: instead, we should ensure the coordinate maps use the same scaling in both spatial dimensions, and raise a warning if the user explicitly provides maps that do otherwise.
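A minimal sketch of the proposed warning, assuming (as above) that each map `(a, b)` sends `a` to 0 and `b` to 1, so two maps scale their dimensions equally exactly when their widths match. The function name and tolerance are hypothetical, not DeepSensor's API.

```python
import math
import warnings


def check_coord_maps(x1_map, x2_map, rel_tol=1e-6):
    """Warn if two linear coordinate maps scale x1 and x2 differently.

    Hypothetical helper: each map (a, b) sends a -> 0 and b -> 1, so its
    scale factor is 1 / (b - a). Returns True if the scalings match.
    """
    span1 = x1_map[1] - x1_map[0]
    span2 = x2_map[1] - x2_map[0]
    if not math.isclose(span1, span2, rel_tol=rel_tol):
        warnings.warn(
            f"x1_map spans {span1} but x2_map spans {span2}; this distorts the "
            "aspect ratio and can produce stripes in the ConvNP context encoding.",
            UserWarning,
        )
        return False
    return True


# The auto-derived maps from this issue scale lat and lon differently
check_coord_maps((15.0, 75.0), (200.0, 330.0))  # emits a UserWarning
```

Comparing widths rather than endpoints means the maps may have different offsets; only their scale factors need to agree.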

tom-andersson commented 8 months ago

Closed by https://github.com/tom-andersson/deepsensor/commit/1c267197a9794302c1f1bb33dd3aa9b44360c8ed. Now, the aspect ratio of the data is preserved when auto-normalising coordinates with the DataProcessor. If the user manually provides x1_map and x2_map to DataProcessor, a warning is raised if the two mappings would warp the aspect ratio (thus producing ConvNP encoding stripes).
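One way to preserve the aspect ratio when auto-normalising is to give both coordinate maps the width of the larger spatial extent, so the longer dimension maps to [0, 1] and the shorter one to a sub-interval of it. The sketch below assumes this approach and uses a hypothetical helper name; the linked commit may choose the shared scale differently.

```python
import numpy as np


def aspect_preserving_maps(x1, x2):
    """Return (x1_map, x2_map) that share a single scale factor.

    Hypothetical sketch: the larger coordinate extent is mapped to [0, 1] and
    the smaller one to [0, extent_ratio], so distances shrink equally in both
    spatial dimensions. Not necessarily how DeepSensor implements it.
    """
    x1_min, x1_max = float(np.min(x1)), float(np.max(x1))
    x2_min, x2_max = float(np.min(x2)), float(np.max(x2))
    span = max(x1_max - x1_min, x2_max - x2_min)
    return (x1_min, x1_min + span), (x2_min, x2_min + span)


# Coordinates matching the air_temperature grid discussed above
lat = np.arange(15.0, 75.0 + 1e-9, 2.5)
lon = np.arange(200.0, 330.0 + 1e-9, 2.5)
x1_map, x2_map = aspect_preserving_maps(lat, lon)
print(x1_map, x2_map)  # (15.0, 145.0) (200.0, 330.0)
```

Using the larger extent keeps all normalised coordinates within [0, 1]; using the smaller extent instead would also preserve the aspect ratio, but the longer dimension would then extend beyond 1.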