jgrss / geowombat

GeoWombat: Utilities for geospatial data
https://geowombat.readthedocs.io
MIT License

Ml no data #270

Closed. mmann1123 closed this 2 months ago.

mmann1123 commented 1 year ago

For any classifier other than Gaussian, the current predict workflow requires passing nodata=0 to gw.open. Not sure why this is.

For instance:

import matplotlib.pyplot as plt
from sklearn.pipeline import Pipeline
from sklearn.ensemble import RandomForestClassifier
import geowombat as gw
from geowombat.data import l8_224078_20200518, l8_224078_20200518_polygons
from geowombat.ml import fit_predict
from sklearn.preprocessing import LabelEncoder
import geopandas as gpd

le = LabelEncoder()

# The labels are string names, so here we convert them to integers
labels = gpd.read_file(l8_224078_20200518_polygons)
labels['lc'] = le.fit_transform(labels.name)

# Use a data pipeline
pl = Pipeline([('clf', RandomForestClassifier())])

fig, ax = plt.subplots(dpi=200, figsize=(5, 5))

# Fit the classifier and predict
with gw.config.update(ref_res=150):
    with gw.open(l8_224078_20200518, nodata=0) as src:
        y = fit_predict(src, pl, labels, col="lc")
        y.plot(robust=True, ax=ax)
plt.tight_layout(pad=1)

This works only if nodata=0 is set.

jgrss commented 1 year ago

@mmann1123, in this line, NaN values are dropped from the predictors based on the DataArray's 'no data' metadata value. geowombat first tries to read the nodata value from the raster file metadata. However, if the file has no internal nodata value and the user has not specified one, the nodata value is set to NaN.
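A minimal sketch of the fallback just described, using a hypothetical helper `resolve_nodata` (not part of geowombat's API). Checking the user value before the file value is an assumption; the comment above only pins down the case where neither is available.

```python
import math
from typing import Optional


def resolve_nodata(user_nodata: Optional[float], file_nodata: Optional[float]) -> float:
    """Hypothetical sketch of the nodata fallback (not geowombat's code).

    Assumed priority: a user-supplied value, then the raster's own
    metadata, and finally NaN when neither is available.
    """
    if user_nodata is not None:
        return user_nodata
    if file_nodata is not None:
        return file_nodata
    # Neither source provides a value, so fall back to NaN
    return math.nan


print(resolve_nodata(None, None))  # nan: no tag in the file, nothing passed to gw.open
print(resolve_nodata(0, None))     # 0: nodata=0 passed to gw.open
```

The first call mirrors opening this Landsat example without `nodata`, which is why the DataArray ends up with a NaN nodata value.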

As a result, in the 'ml' module, zeros are left in the predictor array here when the true nodata value is 0 but the DataArray nodata value is NaN. If you do not set nodata, and this image has no nodata metadata tag, you get:

<xarray.DataArray (band: 3, y: 372, x: 408)>
dask.array<open_rasterio-f620f650fb98ea6ce82cafe95ee46b55<this-array>, shape=(3, 372, 408), dtype=uint16, chunksize=(3, 256, 256), chunktype=numpy.ndarray>
Coordinates:
  * band     (band) int64 1 2 3
  * x        (x) float64 7.174e+05 7.176e+05 7.177e+05 ... 7.783e+05 7.785e+05
  * y        (y) float64 -2.777e+06 -2.777e+06 ... -2.833e+06 -2.833e+06
Attributes: (12/13)
    transform:           (150.0, 0.0, 717345.0, 0.0, -150.0, -2776995.0)
    crs:                 32621
    res:                 (150.0, 150.0)
    is_tiled:            0
    nodatavals:          (nan, nan, nan)
    _FillValue:          nan
    ...                  ...
    offsets:             (0.0, 0.0, 0.0)
    filename:            /mnt/c/Users/jbgra/Documents/code/geowombat/src/geow...
    resampling:          nearest
    AREA_OR_POINT:       Area
    _data_are_separate:  0
    _data_are_stacked:   0

When you set nodata=0, you get:

<xarray.DataArray (band: 3, y: 372, x: 408)>
dask.array<open_rasterio-f620f650fb98ea6ce82cafe95ee46b55<this-array>, shape=(3, 372, 408), dtype=uint16, chunksize=(3, 256, 256), chunktype=numpy.ndarray>
Coordinates:
  * band     (band) int64 1 2 3
  * x        (x) float64 7.174e+05 7.176e+05 7.177e+05 ... 7.783e+05 7.785e+05
  * y        (y) float64 -2.777e+06 -2.777e+06 ... -2.833e+06 -2.833e+06
Attributes: (12/13)
    transform:           (150.0, 0.0, 717345.0, 0.0, -150.0, -2776995.0)
    crs:                 32621
    res:                 (150.0, 150.0)
    is_tiled:            0
    nodatavals:          (0, 0, 0)
    _FillValue:          0
    ...                  ...
    offsets:             (0.0, 0.0, 0.0)
    filename:            /mnt/c/Users/jbgra/Documents/code/geowombat/src/geow...
    resampling:          nearest
    AREA_OR_POINT:       Area
    _data_are_separate:  0
    _data_are_stacked:   0

With this metadata, the 'ml' masking works as intended and drops the zeros.
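The difference between the two cases can be reproduced with plain NumPy. This is a sketch only: `drop_nodata_rows` is a hypothetical stand-in for the 'ml' masking step, and the rule "drop a pixel if any band is NaN or equals nodata" is an assumption, not geowombat's exact logic. The toy array is illustrative, not taken from the Landsat scene.

```python
import numpy as np


def drop_nodata_rows(X: np.ndarray, nodata: float) -> np.ndarray:
    """Hypothetical stand-in for the 'ml' masking step.

    X has shape (n_samples, n_bands). Assumed rule: a row is dropped if
    any band is NaN or, when nodata is a number, if any band equals nodata.
    """
    X = X.astype("float64")
    bad = np.isnan(X).any(axis=1)
    if not np.isnan(nodata):
        bad |= (X == nodata).any(axis=1)
    return X[~bad]


# Toy predictors: three pixels x three bands; the first pixel is zero-filled nodata
X = np.array([[0, 0, 0],
              [812, 901, 1004],
              [655, 740, 880]], dtype="uint16")

print(drop_nodata_rows(X, np.nan).shape)  # (3, 3): NaN nodata, zeros survive
print(drop_nodata_rows(X, 0).shape)       # (2, 3): nodata=0, zeros dropped
```

With a NaN nodata value the mask only removes NaNs, and a uint16 array cannot contain NaN, so the zero-filled pixel passes straight through to the classifier.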

jgrss commented 1 year ago

At first, I thought that I could not replicate your issue, but then I saw that the issue is in the predictions. After examining the predictor variables with and without a proper nodata value, I think that the model is training on a large set of zeros whenever the nodata value actually present in the data does not match the DataArray's nodata attribute/metadata value.
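One way to catch this mismatch before calling `fit_predict` is to inspect the `_FillValue` attribute (as shown in the reprs above) together with the data type. The helper below, `nodata_mismatch`, is hypothetical and works on a plain attrs dict and NumPy array so it runs without geowombat. The heuristic (NaN nodata on integer data plus pixels that are zero in every band) is an assumption about what signals the problem, not a geowombat check.

```python
import numpy as np


def nodata_mismatch(attrs: dict, data: np.ndarray) -> bool:
    """Hypothetical check for a nodata metadata/data mismatch.

    `data` has shape (band, y, x). Returns True when the metadata nodata
    is missing or NaN, the data are an integer type (which cannot store
    NaN, so a NaN mask removes nothing), and at least one pixel is zero
    in every band.
    """
    fill = attrs.get("_FillValue")
    fill_is_nan = fill is None or (isinstance(fill, (float, np.floating)) and bool(np.isnan(fill)))
    integer_data = bool(np.issubdtype(data.dtype, np.integer))
    has_zero_pixels = bool((data == 0).all(axis=0).any())
    return fill_is_nan and integer_data and has_zero_pixels


# Toy uint16 image: one valid pixel, three zero-filled pixels
data = np.zeros((3, 2, 2), dtype="uint16")
data[:, 0, 0] = [812, 901, 1004]

print(nodata_mismatch({"_FillValue": float("nan")}, data))  # True: zeros would be kept
print(nodata_mismatch({"_FillValue": 0}, data))             # False: zeros would be masked
```

In practice, a `True` result suggests reopening the image with an explicit `nodata` value, as in the original example.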