geopandas / dask-geopandas

Parallel GeoPandas with Dask
https://dask-geopandas.readthedocs.io/
BSD 3-Clause "New" or "Revised" License

ValueError: 'left_df' should be GeoDataFrame, got <class 'tuple'> #308

Closed. ManuelPopp closed this issue 2 months ago.

ManuelPopp commented 2 months ago

Hey there,

I ran into some issues/weird behaviour.

The issue

I get a ValueError when running sjoin with two dask_geopandas.expr.GeoDataFrames. There are no tuples involved, and I also checked the data types: everything looks fine, but the join does not run.

The code I run is as follows (full version at the end):

points = coordinates.set_crs("EPSG:4326", allow_override = True)
assert isinstance(
    points, dg.GeoDataFrame
    ), f"Expected GeoDataFrame, got {type(points)}"

areas = polygons.set_crs("EPSG:4326", allow_override = True)
assert isinstance(
    areas, dg.GeoDataFrame
    ), f"Expected GeoDataFrame, got {type(polygons)}"

result = dg.sjoin(
    points, areas, how = "inner", predicate = "intersects"
    )
print(f"{type(points)=}, {points.crs=}")
print(points.head())

print(f"{type(areas)=}, {areas.crs=}")
print(areas.head())

Here, coordinates and polygons are two dask_geopandas.expr.GeoDataFrames. I can share example data if required.

This is what I get when I print out the data type, crs, and .head() of the input data sets:

type(points)=<class 'dask_geopandas.expr.GeoDataFrame'>, points.crs=<Geographic 2D CRS: EPSG:4326>
Name: WGS 84
Axis Info [ellipsoidal]:
- Lat[north]: Geodetic latitude (degree)
- Lon[east]: Geodetic longitude (degree)
Area of Use:
- name: World.
- bounds: (-180.0, -90.0, 180.0, 90.0)
Datum: World Geodetic System 1984 ensemble
- Ellipsoid: WGS 84
- Prime Meridian: Greenwich

         lat       lon                   geometry
0  24.213465  45.50688  POINT (45.50688 24.21347)
1  24.213465  45.50688  POINT (45.50688 24.21347)
2  24.213465  45.50688  POINT (45.50688 24.21347)
3  24.213465  45.50688  POINT (45.50688 24.21347)
4  24.213465  45.50688  POINT (45.50688 24.21347)

type(areas)=<class 'dask_geopandas.expr.GeoDataFrame'>, areas.crs=<Geographic 2D CRS: EPSG:4326>
Name: WGS 84
Axis Info [ellipsoidal]:
- Lat[north]: Geodetic latitude (degree)
- Lon[east]: Geodetic longitude (degree)
Area of Use:
- name: World.
- bounds: (-180.0, -90.0, 180.0, 90.0)
Datum: World Geodetic System 1984 ensemble
- Ellipsoid: WGS 84
- Prime Meridian: Greenwich

  DamageType   EventDate  EventMonth  EventYear  preciseDate      Source                                           geometry
0      Storm  2017-11-10        11.0     2017.0         True  FORWIND_v2  MULTIPOLYGON (((11.41737 53.26798, 11.41775 53...
1      Storm  2017-11-10        11.0     2017.0         True  FORWIND_v2  MULTIPOLYGON (((11.48742 53.26792, 11.48761 53...
2      Storm  2017-11-10        11.0     2017.0         True  FORWIND_v2  MULTIPOLYGON (((11.47358 53.26693, 11.47372 53...
3      Storm  2017-11-10        11.0     2017.0         True  FORWIND_v2  MULTIPOLYGON (((11.49743 53.26749, 11.4977 53....
4      Storm  2017-11-10        11.0     2017.0         True  FORWIND_v2  MULTIPOLYGON (((11.45444 53.2995, 11.45446 53....

What I tried/found out

The error does not occur if I apply .compute() to the input variables before running sjoin. This is even the case when I turn the computed geopandas.GeoDataFrames back into dask-geopandas objects, i.e., into the same dask_geopandas.expr.GeoDataFrame type as before.

While this workaround may be suitable for small data sets (in which case the question would be: why use Dask in the first place?), computing intermediate results takes quite a long time for a huge data set.
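
Condensed, the workaround looks roughly like this (a sketch of the relevant part of the full code below; npartitions = 4 is just an arbitrary example value):

p = points.compute()                         # geopandas.GeoDataFrame
a = areas.compute()                          # geopandas.GeoDataFrame

p = dg.from_geopandas(p, npartitions = 4)    # back to dask_geopandas.expr.GeoDataFrame
a = dg.from_geopandas(a, npartitions = 4)

result = dg.sjoin(p, a, how = "inner", predicate = "intersects")  # runs without the ValueError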

Environment

Full code

#=============================================================================|
# Imports
#-----------------------------------------------------------------------------|
# General
import os
import json
import pickle as pk
from pathlib import Path
from tqdm import tqdm
import warnings
from dataclasses import dataclass

#-----------------------------------------------------------------------------|
# Data handling
import pandas as pd  # Maybe use vaex once it is released, to be scalable for huge data sets
from pyproj.crs import CRS
import geopandas as gpd
import dask.dataframe as dd
from dask import config
from dask.diagnostics import ProgressBar

#-----------------------------------------------------------------------------|
# Geo
import dask_geopandas as dg
from shapely.geometry import Point

#=============================================================================|
# Functions (separated for this performance critical application)
#-----------------------------------------------------------------------------|
def lower_median(values):
    sorted_values = sorted(values)
    n = len(sorted_values)
    return sorted_values[n // 2] if n % 2 == 1 else sorted_values[n // 2 - 1]

def _extract_coordinates(row):
    try:
        data = json.loads(row[".geo"])
        lon, lat = data["coordinates"]
        return pd.Series([lat, lon], index = ["lat", "lon"])

    except (json.JSONDecodeError, KeyError):
        return pd.Series([None, None], index = ["lat", "lon"])

def _create_geometry(df):
    df = df.copy()
    df.loc[:,"geometry"] = df.apply(
        lambda row: Point(row["lon"], row["lat"]), axis = 1
        )
    return df

def _filter_median_by_ndvi(group):
    ndvi = (group["B8"] - group["B4"]) / (group["B8"] + group["B4"])
    median_ndvi = lower_median(ndvi)
    group_median = group[ndvi == median_ndvi].iloc[[0]]
    group_median.columns = group.columns

    return group_median.reset_index(drop = True)

#=============================================================================|
# Classes
#-----------------------------------------------------------------------------|
# Settings
@dataclass
class Settings:
    db_options = {
            "parquet": ".parquet",
            "hdf5": ".h5",
            "persist": ".dask",
            "csv": ".csv",
            "sql": ".sqlite"
        }
    partition_size = 25_000_000  # target partition size in bytes

#config.set({"dataframe.query-planning": False})

#-----------------------------------------------------------------------------|
# Main
class DataBase:
    def __init__(self):
        self.version = "0.0.1"
        self.dataX_dir = self.datay_dir = os.getcwd()
        self.dataX = pd.DataFrame()
        self.datay = pd.DataFrame()
        self._data_y_geodb = None
        self.db_options = Settings.db_options
        self._drefX = self.DataBaseReference(None)
        self._drefy = self.DataBaseReference(None)

    def __repr__(self):
        return f"DataBase(version='{self.version}', size='{self.size}')"

    class DataBaseReference:
        def __init__(self, path):
            self.path = "" if path is None else path
            self.db_options = Settings.db_options

        @property
        def memory_usage(self):
            if os.path.isfile(self.path):
                return os.path.getsize(self.path)
            else:
                return None

        @property
        def is_valid(self):
            return os.path.isfile(self.path)

        def load(self, db_type = None, **kwargs):
            db_types = {value: key for key, value in self.db_options.items()}
            db_types.update({".hdf5": "hdf5"})

            db_ext = os.path.splitext(self.path)[1]
            if db_ext not in db_types.keys() and db_type is None:
                raise ValueError(
                    "Cannot infer data base type from file name " +
                    f"{os.path.basename(self.path)}. " +
                    "Please provide a data base type."
                    )

            if db_type is None:
                db_type = db_types[db_ext]

            if db_type == "parquet":
                return dd.read_parquet(self.path, **kwargs)
            elif db_type == "hdf5":
                return dd.read_hdf(self.path, key = "df", **kwargs)
            elif db_type == "dask":
                return dd.from_disk(self.path, **kwargs)
            elif db_type == "csv":
                if "columns" in kwargs.keys():
                    usecols = kwargs.pop("columns")
                    kwargs["usecols"] = usecols

                return dd.read_csv(self.path, **kwargs)
            elif db_type == "sql":
                import sqlite3
                return dd.read_sql_table(
                    "data_table", "sqlite:///" + self.path,
                    index_col = "id", **kwargs
                    )
            else:
                raise ValueError(f"Invalid data base type {db_type}.")
    @property
    def size(self):
        if isinstance(self.dataX, self.DataBaseReference):
            mem_x, loc_x = self.dataX.memory_usage, "disk"
        else:
            mem_x, loc_x = self.dataX.memory_usage(deep = True).sum(), "memory"

        if isinstance(self.datay, self.DataBaseReference):
            mem_y, loc_y = self.datay.memory_usage, "disk"
        else:
            mem_y, loc_y = self.datay.memory_usage(deep = True).sum(), "memory"

        return {
            "X": {
                "size": mem_x,
                "location": loc_x
            },
            "y": {
                "size": mem_y,
                "location": loc_y
            }
        }

    def save(self, path = os.getcwd()):
        with open(path, "wb") as f:
            pk.dump(self.__dict__, file = f)

    def load(self, path = os.getcwd()):
        with open(path, "rb") as f:
            self.__dict__.update(pk.load(f))

    def read_csv_from_directory(self, directory, single = False):
        self.dataX_dir = Path(directory)
        csv_files = [
            file.resolve() for file in self.dataX_dir.rglob(
                "*.csv"
                ) if file.is_file()
            ]

        files = []
        for file in csv_files:
            try:
                pd.read_csv(file, nrows = 1)
                files.append(file)
            except pd.errors.EmptyDataError:
                warnings.warn(
                    f"Empty or malformed file {file} encountered and skipped."
                    )
            except pd.errors.ParserError as e:
                warnings.warn(f"Failed to parse file {file}. Message: {e}")

        if single:
            # Accumulate chunks into a single in-memory DataFrame
            dataX = pd.DataFrame()
            memory_usage, unit = 0.0, "MB"
            pbar = tqdm(files, desc = "Reading data chunks from directory")
            for file in pbar:
                try:
                    df = pd.read_csv(file)
                    try:
                        dataX = pd.concat(
                            [dataX, df],
                            ignore_index = True
                            )
                        memory_usage = dataX.memory_usage(
                            deep = True
                            ).sum() / (1024 ** 2)

                        if memory_usage < 1024:
                            unit = "MB"
                        else:
                            memory_usage = memory_usage / 1024
                            unit = "GB"
                    except Exception as e:
                        warnings.warn(f"\n{e}")

                except pd.errors.EmptyDataError:
                    warnings.warn(
                        f"\nEmpty file {file} encountered and skipped."
                        )
                except pd.errors.ParserError as e:
                    warnings.warn(
                        f"\nFailed to parse file {file}. Message: {e}"
                        )
                except Exception as e:
                    warnings.warn(
                        f"\nFailed to process file {file}. Message: {e}"
                        )

                pbar.set_postfix(mem_used = f"{memory_usage:.2f} {unit}")
            pbar.close()
        else:
            dataX = dd.read_csv(
                files,
                parse_dates = ["date"],
                assume_missing = False
                )

            dataX.columns = dataX.columns.str.strip()

        dataX = dataX[dataX["QA60"] == 0]

        meta = pd.DataFrame({
            "lat": pd.Series(dtype = "float64"),
            "lon": pd.Series(dtype = "float64")
            })

        dataX[["lat", "lon"]] = dataX.apply(
            _extract_coordinates,
            axis = 1,
            meta = meta
            )

        dataX["year_month"] = dataX["date"].dt.to_period("M")

        meta = dataX.head(1)

        dataXg = dataX.groupby(
            ["lat", "lon", "year_month"]
            ).apply(_filter_median_by_ndvi, meta = meta)

        self.dataX = dataXg.reset_index(drop = True)

    def materialise_X(self, inplace = False):
        if isinstance(self.dataX, self.DataBaseReference):
            self.dataX = self.dataX.load()

        if isinstance(self.dataX, dd.DataFrame):
            with ProgressBar():
                dataX = self.dataX.compute()
        elif isinstance(self.dataX, pd.DataFrame):
            dataX = self.dataX
        else:
            raise ValueError(f"Cannot handle type {type(self.dataX)=}.")

        if inplace:
            self.dataX = dataX
        else:
            return dataX

    def _get_database_y(self, disturbance_gpkg, coverage_gpkg):
        self.datay_dir = Path(disturbance_gpkg).parent
        self.disturbance = gpd.read_file(
            disturbance_gpkg
            )
        self.coverage = gpd.read_file(
            coverage_gpkg, layer = "Union"
            ).explode(index_parts = False)
        tmp = gpd.sjoin(
            self.disturbance,
            self.coverage,
            how = "inner", op = "intersects",
            lsuffix = "left", rsuffix = "right"
            )
        tmp["overlap"] = tmp.geometry.intersection(tmp["geometry_right"]).area
        tmp = tmp.sort_values(
            by = ["index_left", "overlap"], ascending = [True, False]
            )
        tmp = tmp.drop_duplicates(subset = ["index_left"], keep = "first")

        self.disturbance["polygon_id"] = self.disturbance.index.map(
            tmp.set_index("index_left")["index_right"]
            )

    def link_database_X(self, path):
        self._drefX = self.DataBaseReference(path)

    def link_database_y(self, path):
        self._drefy = self.DataBaseReference(path)

    def load_database_X(self, path = None):
        if path is not None:
            self.link_database_X(path)

        self.dataX = self._drefX.load()

    def load_database_y(self, path = None):
        if path is not None:
            self.link_database_y(path)

        self.datay = self._drefy.load()

    def memory_to_disk(self,
                       directory = None, db_format = "parquet", ds = "both",
                       free_mem = False
                       ):
        '''Save data to disk.

        Args:
            directory (str, optional): Output directory.
            Defaults to None.
            db_format (str, optional): Output data base type. Defaults
            to "parquet".
            ds (str, optional): _description_. Defaults to "both".
            free_mem (bool, optional): Free up memory (store data as a
            reference to the exported data base). Defaults to False.

        Raises:
            ValueError: Invalid input parameter.
        '''
        if db_format.lower() not in self.db_options.keys():
            opts = ", ".join(self.db_options.keys())
            raise ValueError(
                f"Invalid option {db_format=}. Valid options are: {opts}."
                )

        if directory is None:
            out_dirX, out_diry = self.dataX_dir, self.datay_dir
        else:
            out_dirX = out_diry = Path(directory)

        out_fileX = out_dirX.joinpath("dataX" + self.db_options[db_format])
        out_filey = out_diry.joinpath("datay" + self.db_options[db_format])

        # Export X data set
        if isinstance(
            self.dataX, dd.DataFrame
            ) and ds.lower() in ["x", "both"]:
            with ProgressBar():
                if db_format.lower() == "parquet":
                    self.dataX.to_parquet(out_fileX, write_index = False)
                elif db_format.lower() == "hdf5":
                    self.dataX.to_hdf(out_fileX, key = "df", mode = "w")
                elif db_format.lower() == "persist":
                    self.dataX = self.dataX.to_disk(out_fileX)
                elif db_format.lower() == "csv":
                    self.dataX.to_csv(
                        out_fileX, single_file = True, index = False
                        )
                else:
                    import sqlite3
                    conn = sqlite3.connect(out_fileX)
                    self.dataX.to_sql(
                        "data_table", conn, if_exists = "replace", index = False
                        )

            self._drefX = self.DataBaseReference(out_fileX)

            if free_mem:
                self.to_reference("X")

    def to_reference(self, ds):
        '''Replace data set in memory by a reference to an exported data set
        on disk.

        Args:
            ds str: Either "x", "y", or "both" for x, y, or both data sets.

        Raises:
            ValueError: No data set on disk.
        '''
        if ds.lower() in ["x", "both"]:
            if not os.path.isfile(self._drefX.path):
                raise ValueError(
                    f"No valid reference for X data set. Export data first."
                    )
            else:
                self.dataX = self._drefX

        if ds.lower() in ["y", "both"]:
            if not os.path.isfile(self._drefy.path):
                raise ValueError(
                    f"No valid reference for y data set. Export data first."
                    )
            else:
                self.datay = self._drefy

    def sample_coordinates(self):
        if isinstance(self.dataX, self.DataBaseReference):
            coordinates = self.dataX.load(columns = ["lat", "lon"])
        else:
            coordinates = self.dataX[["lat", "lon"]]

        if isinstance(coordinates, pd.DataFrame):
            size = coordinates.memory_usage(deep = True).sum()
            coordinates = dd.from_pandas(
                coordinates,
                npartitions = max(int(size // Settings.partition_size), 1)
                )

        return coordinates

    def get_disturbances(self, file_y = None):
        coordinates = self.sample_coordinates()

        if self._data_y_geodb is None:
            self._data_y_geodb = file_y

        polygons = dg.read_file(
            self._data_y_geodb,
            npartitions = max(
                os.path.getsize(
                    self._data_y_geodb
                    ) // Settings.partition_size,
                1
                ),
            )

        meta = pd.DataFrame({
            "lat": pd.Series(dtype = "float64"),
            "lon": pd.Series(dtype = "float64"),
            "geometry": gpd.GeoSeries(dtype = "geometry")
        })

        coordinates = coordinates.map_partitions(
            _create_geometry, meta = meta
            )

        coordinates = dg.from_dask_dataframe(
            coordinates, geometry = "geometry"
            )

        points = coordinates.set_crs(CRS.from_epsg(4326), allow_override = True)
        assert isinstance(
            points, dg.GeoDataFrame
            ), f"Expected GeoDataFrame, got {type(points)}"

        areas = polygons.set_crs(CRS.from_epsg(4326), allow_override = True)
        assert isinstance(
            areas, dg.GeoDataFrame
            ), f"Expected GeoDataFrame, got {type(polygons)}"

        result = dg.sjoin(
            points, areas, how = "inner", predicate = "intersects"
            )
        print(f"{type(points)=}, {points.crs=}")
        print(points.head())

        print(f"{type(areas)=}, {areas.crs=}")
        print(areas.head())

        p = points.compute()
        a = areas.compute()
        p = dg.from_geopandas(p)
        a = dg.from_geopandas(a)
        #result = dg.sjoin(
        #    p, a, how = "inner", predicate = "intersects"
        #    )
        return result

db = DataBase()
os.chdir("D:/tmp")
db.read_csv_from_directory("D:/tmp2")
r = db.get_disturbances("C:/Users/poppman/switchdrive/PhD/prj/rsm/dat/Disturbances.gpkg")
r.head()
TomAugspurger commented 2 months ago

This may have been fixed by https://github.com/geopandas/dask-geopandas/pull/307. Can you try installing geopandas from GitHub and report whether you still see the error?

ManuelPopp commented 2 months ago

I ran

pip install git+https://github.com/geopandas/geopandas.git

Successfully installed geopandas-1.0.1+18.g3d0ff15

This alone did not resolve the error. However, combined with

pip install git+https://github.com/geopandas/dask-geopandas

Successfully installed dask-geopandas-0.4.1+3.g069f0c1

it seems the issue is gone! Thank you for the quick response to this issue!