earthdaily / earthdaily-python-client

EarthDaily Python client
https://earthdaily.github.io/earthdaily-python-client/
MIT License

Register images in time series cube #12

Open · robmarkcole opened this issue 10 months ago

robmarkcole commented 10 months ago

I've observed pixel shifts on the order of a couple of pixels in Sentinel-2 imagery and wrote the following solution, which could be integrated or adapted:

import numpy as np
import xarray as xr
from scipy import fftpack
from scipy.ndimage import shift as nd_shift
from skimage.registration import phase_cross_correlation

def low_pass_filter_fourier(image, cutoff_radius):
    """
    Apply a low-pass filter to an image in the Fourier domain.

    The function performs the Fourier transform on the input image, multiplies 
    it with a low-pass filter mask, and then performs an inverse Fourier 
    transform to get the filtered image.

    Parameters
    ----------
    image : np.ndarray
        The input 2D image to be filtered.

    cutoff_radius : float
        The cutoff radius of the low-pass filter, in cycles per pixel.
        Values should lie between 0 and 0.5, where 0.5 corresponds to
        the Nyquist frequency along each axis.

    Returns
    -------
    np.ndarray
        The filtered image in the spatial domain.

    Examples
    --------
    >>> image = np.random.rand(100, 100)
    >>> filtered_image = low_pass_filter_fourier(image, 0.2)
    """
    # Fourier Transform
    F = fftpack.fftshift(fftpack.fft2(image))

    # Generate low-pass filter mask
    rows, cols = image.shape
    x = np.fft.fftfreq(rows)
    y = np.fft.fftfreq(cols)
    x, y = np.meshgrid(x, y, indexing='ij')
    radius = np.sqrt(x ** 2 + y ** 2)

    # Apply low-pass filter
    filter_mask = np.fft.fftshift(radius <= cutoff_radius)
    F_filtered = F * filter_mask

    # Inverse Fourier Transform
    image_filtered = np.fft.ifft2(np.fft.ifftshift(F_filtered)).real

    return image_filtered

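def register_xarrays(xarray1: xr.DataArray, xarray2: xr.DataArray) -> xr.DataArray:
    """Register xarray2 onto xarray1.

    The shift is estimated by phase cross-correlation on the low-pass filtered
    'red' band of both arrays and then applied to every band of xarray2.
    """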

    array1 = xarray1.sel(band='red').values
    array2 = xarray2.sel(band='red').values

    array1 = low_pass_filter_fourier(array1, 0.2)
    array2 = low_pass_filter_fourier(array2, 0.2)

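    # upsample_factor=100 estimates the shift to within 1/100 of a pixel
    shift, _, _ = phase_cross_correlation(array1, array2, upsample_factor=100)
    print(f"Estimated (row, col) shift: {shift}")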

    corrected_xarray2 = xarray2.copy()

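    # Apply the same shift to all bands. nd_shift uses cubic spline
    # interpolation and fills pixels shifted in from outside the image with 0.
    for band in xarray2['band'].values:
        corrected_xarray2.loc[dict(band=band)] = nd_shift(xarray2.sel(band=band).values, shift[:2])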

    return corrected_xarray2
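
The cutoff_radius used by low_pass_filter_fourier is expressed in cycles per pixel. As a quick sanity check of that interpretation, here is a minimal NumPy-only sketch that builds the same circular mask directly in the unshifted FFT layout (so the fftshift round trip is not needed; the name low_pass is hypothetical) and applies it to two sinusoids placed exactly on FFT bins: one at 0.125 cycles/pixel, inside a 0.2 cutoff, and one at 0.375 cycles/pixel, outside it.

```python
import numpy as np


def low_pass(image: np.ndarray, cutoff_radius: float) -> np.ndarray:
    """Keep only spatial frequencies with radius <= cutoff_radius (cycles/pixel)."""
    # Frequency grids in the unshifted FFT layout, so no fftshift is needed.
    fy = np.fft.fftfreq(image.shape[0])[:, None]
    fx = np.fft.fftfreq(image.shape[1])[None, :]
    mask = np.sqrt(fy**2 + fx**2) <= cutoff_radius
    return np.fft.ifft2(np.fft.fft2(image) * mask).real


# Horizontal sinusoids on exact FFT bins for a 128-pixel width (16/128 and 48/128).
cols = np.arange(128)
slow = np.tile(np.sin(2 * np.pi * 0.125 * cols), (128, 1))  # inside the cutoff
fast = np.tile(np.sin(2 * np.pi * 0.375 * cols), (128, 1))  # outside the cutoff
filtered = low_pass(slow + fast, cutoff_radius=0.2)
print(np.abs(filtered - slow).max())  # tiny: floating-point round-off only
```

Only the slower sinusoid survives, so the cutoff of 0.2 used in register_xarrays keeps spatial frequencies below 0.2 cycles/pixel (periods longer than 5 pixels) and suppresses finer texture and noise before the phase correlation step.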

nkarasiak commented 10 months ago

Wow, really nice coregistration function, definitely one to implement!
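
The issue title asks for registration across a whole time-series cube, whereas register_xarrays aligns a single pair of images. One possible extension is to pick a reference date and align every other date onto it, estimating the shift on one band and applying it to all bands as register_xarrays does. The sketch below is NumPy only and deliberately simpler than the code above: it assumes a plain (time, band, y, x) array rather than an xarray cube, estimates integer-pixel shifts only (there is no equivalent of upsample_factor), fills NaNs with the image mean before estimating, and fills pixels shifted in from outside the image with NaN instead of 0. The helper names estimate_shift, shift_no_wrap and register_cube are hypothetical and not part of the earthdaily client.

```python
import numpy as np


def estimate_shift(reference: np.ndarray, moving: np.ndarray) -> tuple[int, int]:
    """Integer (row, col) shift that registers `moving` onto `reference`.

    Plain phase correlation: normalised cross-power spectrum, inverse FFT,
    peak location. Unlike skimage's phase_cross_correlation, there is no
    sub-pixel upsampling. NaNs (e.g. masked clouds) are filled with the mean.
    """
    ref = np.nan_to_num(reference, nan=np.nanmean(reference))
    mov = np.nan_to_num(moving, nan=np.nanmean(moving))
    cross_power = np.fft.fft2(ref) * np.conj(np.fft.fft2(mov))
    cross_power /= np.abs(cross_power) + 1e-12  # keep phase information only
    correlation = np.fft.ifft2(cross_power).real
    peak = np.unravel_index(np.argmax(correlation), correlation.shape)
    # Peak indices past the midpoint wrap around to negative shifts.
    dy, dx = (int(p) - n if p > n // 2 else int(p) for p, n in zip(peak, correlation.shape))
    return dy, dx


def _slices(n: int, d: int) -> tuple[slice, slice]:
    """Source/destination slices that move n samples by d without wrap-around."""
    d = max(-n, min(n, d))  # clamp so slice bounds never go negative
    return slice(max(0, -d), n - max(0, d)), slice(max(0, d), n - max(0, -d))


def shift_no_wrap(image: np.ndarray, dy: int, dx: int) -> np.ndarray:
    """Shift the last two axes by (dy, dx); pixels shifted in become NaN."""
    out = np.full(image.shape, np.nan, dtype=float)
    src_y, dst_y = _slices(image.shape[-2], dy)
    src_x, dst_x = _slices(image.shape[-1], dx)
    out[..., dst_y, dst_x] = image[..., src_y, src_x]
    return out


def register_cube(cube: np.ndarray, ref_time: int = 0, ref_band: int = 0):
    """Align every date of a (time, band, y, x) cube onto cube[ref_time].

    The shift is estimated on a single band (ref_band) and applied to all
    bands of that date. Returns the aligned cube and the per-date shifts.
    """
    reference = cube[ref_time, ref_band]
    aligned = np.empty(cube.shape, dtype=float)
    shifts = []
    for t in range(cube.shape[0]):
        dy, dx = estimate_shift(reference, cube[t, ref_band])
        aligned[t] = shift_no_wrap(cube[t], dy, dx)
        shifts.append((dy, dx))
    return aligned, shifts


# Synthetic check: 2 bands, 3 dates, dates 1 and 2 offset by known amounts.
rng = np.random.default_rng(0)
base = rng.random((2, 64, 64))
offsets = [(0, 0), (2, -3), (-4, 1)]
cube = np.stack([np.roll(base, off, axis=(1, 2)) for off in offsets])
aligned, shifts = register_cube(cube)
print(shifts)  # [(0, 0), (-2, 3), (4, -1)]: the shifts that undo each offset
```

Aligning every date to one fixed reference avoids the error accumulation of chaining each date to its predecessor, at the cost of weaker estimates for dates that look very different from the reference (for example under heavy cloud cover).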