Open-EO / openeo-geopyspark-driver

OpenEO driver for GeoPySpark (Geotrellis)
Apache License 2.0

biopar: speed up computation? #682

Closed: jdries closed this 4 months ago

jdries commented 4 months ago

This code is apparently relatively slow: https://git.vito.be/projects/GEOM/repos/biopar/browse/src/biopar/bioparnnw.py#78

It also logs this warning:

11 out of the last 11 calls to <function BioParNNW._compute_biopar at 0x7fad3f8fe4c0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.

The UDF that invokes this also does not cache the BioParNNW instance. Another option is to write a NumPy implementation alongside the TensorFlow one.

Reading the tf.function documentation again, it very much looks like we should try to avoid using a class method, and simply replace the variables that come from the class with TensorFlow constants.
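A minimal pure-Python sketch of the caching point: with functools.lru_cache, the expensive constructor runs once per parameter value instead of once per UDF call. ExpensiveModel, get_model and apply_chunk are hypothetical stand-ins for BioParNNW and the UDF, and the sketch assumes the Python worker process is reused between calls, since the cache only lives in that process's memory.

```python
from functools import lru_cache


class ExpensiveModel:
    """Hypothetical stand-in for BioParNNW: construction is the costly step."""

    instances_created = 0

    def __init__(self, parameter: str):
        ExpensiveModel.instances_created += 1
        self.parameter = parameter

    def run(self, values):
        return [v * 2 for v in values]  # placeholder computation


@lru_cache(maxsize=8)
def get_model(parameter: str) -> ExpensiveModel:
    # Cached per argument: repeated calls with the same parameter reuse the instance
    return ExpensiveModel(parameter)


def apply_chunk(chunk, parameter="FAPAR"):
    # Simulates a UDF that is invoked once per chunk of data
    return get_model(parameter).run(chunk)


for chunk in ([1, 2], [3], [4, 5, 6]):
    apply_chunk(chunk)
apply_chunk([7], parameter="LAI")
```

Three FAPAR chunks and one LAI chunk create only two model instances.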

jdries commented 4 months ago

The UDF that invokes this (job j-240206093f0f40e5a8af63554e900ec6):


import numpy as np
from typing import Dict
import xarray as xr
import tensorflow as tf
from openeo.udf.xarraydatacube import XarrayDataCube
from biopar.bioparnnw import BioParNNW

biopar_version = '3band'
valid_biopars = ['FAPAR', 'LAI', 'FCOVER', 'CWC', 'CCC']
scaling_bands = 0.0001


def apply_datacube(cube: XarrayDataCube, context: Dict) -> XarrayDataCube:
    # Fall back to FAPAR if an unsupported parameter is requested
    biopar = context.get('biopar', 'FAPAR')
    if biopar not in valid_biopars:
        biopar = 'FAPAR'

    ds_date = cube.get_array()

    # Sun and view geometry (degrees)
    saa = ds_date.sel(bands='sunAzimuthAngles')
    sza = ds_date.sel(bands='sunZenithAngles')
    vaa = ds_date.sel(bands='viewAzimuthMean')
    vza = ds_date.sel(bands='viewZenithMean')

    # Reflectances, rescaled from digital numbers
    B03 = ds_date.sel(bands='B03') * scaling_bands
    B04 = ds_date.sel(bands='B04') * scaling_bands
    B8 = ds_date.sel(bands='B08') * scaling_bands
    g1 = np.cos(np.radians(vza))
    g2 = np.cos(np.radians(sza))
    g3 = np.cos(np.radians(saa - vaa))

    #### FLATTEN THE ARRAY ####
    # One row per input feature, one column per pixel: shape (6, n_pixels)
    bands = np.array([arr.values.flatten() for arr in [B03, B04, B8, g1, g2, g3]])

    #### CALCULATE THE BIOPAR BASED ON THE BANDS #####
    # Note: a new BioParNNW instance is created on every call (no caching)
    image = BioParNNW(version=biopar_version, parameter=biopar, singleConfig=True).run(
        bands, output_scale=1, output_dtype=tf.dtypes.float32, minmax_flagging=False)  # netcdf algorithm
    as_image = image.reshape(g1.shape)

    # Set nodata to NaN
    as_image[np.isnan(B03.values)] = np.nan
    xr_biopar = vza.copy()
    xr_biopar.values = as_image / scaling_bands

    return XarrayDataCube(xr_biopar)  # xarray.DataArray(as_image, vza.dims, vza.coords)
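The flatten/reshape round trip in this UDF only works because flatten() and reshape() use the same (C) element order. A small NumPy sketch of that pattern, with fake_model as a hypothetical stand-in for the biopar network and random data instead of Sentinel-2 bands:

```python
import numpy as np


def fake_model(features: np.ndarray) -> np.ndarray:
    """Hypothetical per-pixel model: (n_features, n_pixels) -> (n_pixels,)."""
    return features.mean(axis=0)


rng = np.random.default_rng(0)
shape = (2, 3, 4)  # e.g. (t, y, x) of one chunk
b03 = rng.random(shape)
b03[0, 0, 0] = np.nan  # simulate a nodata pixel
others = [rng.random(shape) for _ in range(5)]  # stand-ins for B04, B08, g1, g2, g3

# One row per input feature, one column per pixel (C order)
bands = np.array([arr.flatten() for arr in [b03, *others]])

# Run the model on all pixels at once, then restore the chunk layout
image = fake_model(bands).reshape(shape)
image[np.isnan(b03)] = np.nan  # propagate nodata, as the UDF does
```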
VictorVerhaert commented 4 months ago

I have requested access to the "CGS_S2" folder in order to run the 3-band tests and hopefully reproduce the warnings.

jdries commented 4 months ago

@VictorVerhaert I see, these are in fact old paths that can simply be updated to more recent versions: /data/MTDA/TERRASCOPE_Sentinel2/TOC_V2/2017/06/02/S2A_20170602T104021_31UFS_TOC_V210. There should be a similar path to a corresponding FAPAR product to use as reference.

I would, however, also recommend the 8-band test (https://git.vito.be/projects/GEOM/repos/biopar/browse/tests/test_8_band.py), which uses the same code but does not have complex test dependencies.

VictorVerhaert commented 4 months ago

I implemented the NN in pure NumPy. Unfortunately, I can't push my code to the VITO Git due to insufficient permissions.
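For illustration only, a vectorized NumPy forward pass might look like the sketch below. It assumes a network with one tanh hidden layer, a linear output neuron and min/max scaling of the inputs to [-1, 1]; the function name nn_forward and the random weights are hypothetical and are not the BioParNp code or its coefficients. Because everything is plain array arithmetic over all pixels at once, there is no graph to trace.

```python
import numpy as np


def nn_forward(bands, w1, b1, w2, b2, in_min, in_max, out_min, out_max):
    """Hypothetical forward pass of a one-hidden-layer network over all pixels.

    bands: (n_features, n_pixels), w1: (n_hidden, n_features), b1: (n_hidden,),
    w2: (n_hidden,), b2: scalar; in_min/in_max: (n_features,) scaling bounds.
    """
    # Scale each input feature to [-1, 1]
    x = 2.0 * (bands - in_min[:, None]) / (in_max - in_min)[:, None] - 1.0
    hidden = np.tanh(w1 @ x + b1[:, None])  # (n_hidden, n_pixels)
    y = w2 @ hidden + b2  # linear output neuron, (n_pixels,)
    # Map the output from [-1, 1] back to [out_min, out_max]
    return 0.5 * (y + 1.0) * (out_max - out_min) + out_min


rng = np.random.default_rng(42)
n_features, n_hidden, n_pixels = 6, 5, 1000
w1 = rng.normal(size=(n_hidden, n_features))
b1 = rng.normal(size=n_hidden)
w2 = rng.normal(size=n_hidden)
b2 = 0.1
bands = rng.random((n_features, n_pixels))
result = nn_forward(bands, w1, b1, w2, b2,
                    in_min=np.zeros(n_features), in_max=np.ones(n_features),
                    out_min=0.0, out_max=1.0)
```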

I ran the 8-band tests 100 times and saw a significant speed improvement:

[Image not reproduced]
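A simple way to repeat a measurement many times and summarize it, as a sketch: benchmark and numpy_variant are hypothetical names, and wall-clock timing with time.perf_counter is an assumption about how such a comparison could be done, not the setup behind the results above.

```python
import statistics
import time

import numpy as np


def benchmark(fn, *args, repeats=100):
    """Run fn(*args) `repeats` times and return per-call wall-clock times in seconds."""
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        timings.append(time.perf_counter() - start)
    return timings


def numpy_variant(x):
    # Hypothetical stand-in for the implementation under test
    return np.tanh(x).sum()


x = np.random.default_rng(0).random((6, 10_000))
t = benchmark(numpy_variant, x, repeats=100)
print(f"median {statistics.median(t) * 1e3:.3f} ms over {len(t)} runs")
```

Reporting the median rather than the mean keeps one-off outliers (e.g. a first call that loads or compiles something) from dominating the comparison.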

VictorVerhaert commented 4 months ago

The pull request can be found here. For now I have only added a new class, so that both implementations can be run and compared in terms of speed, memory usage and results.

Where can I find a script that uses the biopar module? That way I can set up the comparison. (Using the job ID mentioned above, I retrieved a process graph, but running this graph gives me an error.)

VictorVerhaert commented 4 months ago

I have tested 3 different biopar implementations for calculating FAPAR on the following spatial and temporal extent:

bbox = {"west": 5.0, "south": 51.0, "east": 5.2, "north": 51.2}
date = ["2019-01-01", "2023-12-31"]
  1. Original bioparnnw without caching (current implementation and usage):

    image = BioParNNW(version='3band', parameter=biopar, singleConfig=True).run(bands, output_scale=1, output_dtype=tf.dtypes.float32, minmax_flagging=False)  # netcdf algorithm

    Because the BioParNNW class is instantiated on every call, the above-mentioned retracing warning is produced. This implementation ran out of memory (OOM) every time it was run on this large extent.

  2. Original bioparnnw with lru_cache caching:

    from functools import lru_cache

    @lru_cache(maxsize=6)
    def get_biopar() -> BioParNNW:
        # Created once per Python process, then reused across UDF calls
        return BioParNNW(version='3band', parameter='FAPAR', singleConfig=True)
    ...
    #### CALCULATE THE BIOPAR BASED ON THE BANDS #####
    image = get_biopar().run(bands, output_scale=1,
                             output_dtype=tf.dtypes.float32,
                             minmax_flagging=False)  # netcdf algorithm

    The retracing warning is no longer produced, and the batch job finishes with the following usage stats: cpu: 32.4k sec, duration: 2026 sec, memory: 188M MB-sec, max_executor_memory: 4.44 GB, credits: 66.

  3. bioparnp (NumPy implementation) without caching (caching gave similar results). Usage stats: cpu: 27.4k sec, duration: 2184 sec, memory: 166M MB-sec, max_executor_memory: 4.42 GB, credits: 57.

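My conclusion: the NumPy implementation performs slightly better than the TensorFlow implementation with caching (lower CPU time, memory and credits, although the wall-clock duration was somewhat longer), and it depends less on how users call it, since it does not need the caller to cache the instance.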
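The relative differences between the two runs that completed can be computed directly from the reported usage stats (values copied from the list above):

```python
# Usage stats reported for the two batch jobs that finished
tf_cached = {"cpu_s": 32_400, "duration_s": 2026, "memory_mb_s": 188e6,
             "max_executor_memory_gb": 4.44, "credits": 66}
numpy_impl = {"cpu_s": 27_400, "duration_s": 2184, "memory_mb_s": 166e6,
              "max_executor_memory_gb": 4.42, "credits": 57}

# Relative change of the NumPy run versus the cached TensorFlow run
rel = {k: (numpy_impl[k] - tf_cached[k]) / tf_cached[k] for k in tf_cached}
for k, v in rel.items():
    print(f"{k:>24}: {v:+.1%}")
```

This gives roughly 15% less CPU time, 12% less memory (MB-sec) and 14% fewer credits for the NumPy version, at the cost of about 8% longer wall-clock duration; peak executor memory is essentially unchanged.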

VictorVerhaert commented 4 months ago

Should we add @deprecated to the BioParNNW class and tell users to use the BioParNp class instead, or just leave it to the users? @jdries
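If deprecation is chosen, one minimal option is to emit a warning from the constructor. The class below is a hypothetical stand-in, not the real biopar code. FutureWarning could be used instead of DeprecationWarning if end users should see the message by default, since Python's default filters hide DeprecationWarning unless it is triggered from __main__; Python 3.13 also provides a warnings.deprecated decorator (PEP 702).

```python
import warnings


class BioParNNW:
    """Hypothetical stand-in for the TensorFlow-based class (not the real biopar code)."""

    def __init__(self, *args, **kwargs):
        # Warn on every instantiation; stacklevel=2 attributes the warning to the caller
        warnings.warn(
            "BioParNNW is deprecated, use BioParNp instead",
            DeprecationWarning,
            stacklevel=2,
        )
        # The existing initialisation would follow here


# Usage: record the warning explicitly, e.g. in a test
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    BioParNNW(version="3band", parameter="FAPAR")
```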