tsutterley / pyTMD

Python-based tidal prediction software
https://pytmd.readthedocs.io
MIT License

Optimization of `compute_tide_corrections` with `FES2014` for multiple lat/lons #91

Open robbibt opened 2 years ago

robbibt commented 2 years ago

First of all, congrats @tsutterley on an incredible package... such an amazing resource! I've been looking into using pyTMD for modelling tide heights from FES2014 for our DEA Coastlines coastline mapping work. Essentially, our current process is to:

I've been testing out compute_tide_corrections as a way to achieve this, passing in the lat/lon of a single point from a 2 x 2 km grid, along with all of the times from my satellite datasets:

import pandas as pd
from pyTMD import compute_tide_corrections

lat, lon = -32, 155
example_times = pd.date_range("2022-01-01", "2022-01-02", freq="1h").values

out = compute_tide_corrections(
    x=lon,
    y=lat,
    delta_time=example_times,
    DIRECTORY="FES2014",
    MODEL="FES2014",
    EPSG=4326,
    TYPE="time series",
    TIME="datetime",
    METHOD="bilinear",
)

This works great, but it's pretty slow: about 38.4 seconds in total for a single lat/lon point. Because I can have 100+ lat/lon points in a given study area, this will quickly blow out if I want to apply compute_tide_corrections to multiple points.

Using line_profiler, it appears that nearly all of this time (38.3 seconds, or over 99%) is taken up in the extract_FES_constants function:

Timer unit: 1e-06 s

Total time: 38.408 s
File: /env/lib/python3.8/site-packages/pyTMD/compute_tide_corrections.py
Function: compute_tide_corrections at line 125

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
...
   277         2   38292108.0 19146054.0     99.7          amp,ph = extract_FES_constants(lon, lat, model.model_file,
   278         1          4.0      4.0      0.0              TYPE=model.type, VERSION=model.version, METHOD=METHOD,
   279         1          4.0      4.0      0.0              EXTRAPOLATE=EXTRAPOLATE, CUTOFF=CUTOFF, SCALE=model.scale,
   280         1          4.0      4.0      0.0              GZIP=model.compressed)
...

Profiling extract_FES_constants, it seems that most of the time in that function (37.5 seconds) is taken up by read_netcdf_file:

Timer unit: 1e-06 s

Total time: 38.2932 s
File: /env/lib/python3.8/site-packages/pyTMD/read_FES_model.py
Function: extract_FES_constants at line 86

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
...
   158        68   37470465.0 551036.2     97.9              hc,lon,lat = read_netcdf_file(os.path.expanduser(fi),
   159        34         53.0      1.6      0.0                  GZIP=GZIP, TYPE=TYPE, VERSION=VERSION)
...

So essentially, loading the FES2014 files with read_netcdf_file occupies almost all of the time taken to run compute_tide_corrections. For analyses involving many timesteps for a single lat/lon this isn't a problem, as the files only have to be read once. However, for analyses where compute_tide_corrections needs to be called multiple times to model tides for multiple lat/lons, the FES2014 data has to be loaded again and again, leading to extremely long processing times.

Instead of loading the FES2014 files with read_netcdf_file every time compute_tide_corrections is called, could it be possible to give users the option to load the FES files themselves outside of the function, and then pass in the loaded data (i.e. hc, lon, lat) directly to the function via an optional parameter? This would allow users to greatly optimise processing time for analyses that include many lat/lon tide modelling locations.

robbibt commented 2 years ago

Since posting this issue, I have discovered that the drift method gets me closer to what I want, as I can pass in multiple lat/lons as well as multiple times, and avoid the multiple NetCDF reads:

import pandas as pd
from pyTMD import compute_tide_corrections

# Input data (multiple times per point)
example_times = pd.date_range("2022-01-01", "2022-01-30", freq="1D")
point1_df = pd.DataFrame({'lat': -32, 'lon': 155, 'time': example_times})
point2_df = pd.DataFrame({'lat': -33, 'lon': 157, 'time': example_times})
point3_df = pd.DataFrame({'lat': -34, 'lon': 161, 'time': example_times})

# Combine into a single dataframe
points_df = pd.concat([point1_df, point2_df, point3_df])

# Model tide heights using 'drift'
out = compute_tide_corrections(
    x=points_df.lon,
    y=points_df.lat,
    delta_time=points_df.time.values,
    DIRECTORY="FES2014",
    MODEL="FES2014",
    EPSG=4326,
    TYPE="drift",
    TIME="datetime",
    METHOD="bilinear",
)

# Add back into dataframe
points_df['tide_height'] = out

However, because I have static points with multiple timesteps at each, drift still ends up being less efficient than I want because it assumes each time also has a unique lat/lon, which causes the spatial interpolation step to be run for every individual lat/lon/time pair (rather than being interpolated once for each unique point location then re-used for each time given that the point coordinates are the same for all times).

I think for this application (many timesteps for a smaller set of static modelling point locations), the most efficient processing flow might be something like this?

(or alternatively, perhaps some method to detect duplicate/repeated lat/lons, then batch those together to reduce the number of required interpolations...)
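The duplicate-detection idea could look roughly like the sketch below: interpolate once per distinct lat/lon, then map the results back onto the full list of lat/lon/time rows. `toy_interpolate` is a hypothetical stand-in for the spatial interpolation step, not a pyTMD function, and the sketch assumes repeated points have exactly equal coordinates.

```python
from typing import Callable, Dict, List, Tuple

Point = Tuple[float, float]


def interpolate_unique(points: List[Point], interpolate: Callable[[Point], float]) -> List[float]:
    """Run `interpolate` once per distinct point and map results back to input order."""
    cache: Dict[Point, float] = {}
    for point in points:
        if point not in cache:
            cache[point] = interpolate(point)
    return [cache[point] for point in points]


calls: List[Point] = []  # records every real interpolation that was performed


def toy_interpolate(point: Point) -> float:
    """Hypothetical stand-in for the spatial interpolation step."""
    calls.append(point)
    lat, lon = point
    return lat + lon


# Three static points, each repeated for 30 timesteps (as in a 'drift'-style input)
static_points = [(-32.0, 155.0), (-33.0, 157.0), (-34.0, 161.0)]
drift_points = [p for p in static_points for _ in range(30)]
values = interpolate_unique(drift_points, toy_interpolate)
```

Here 90 input rows trigger only 3 interpolations. With array inputs, `numpy.unique` with `return_inverse=True` would be one way to do the same grouping.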

tsutterley commented 1 year ago

@robbibt still thinking about the best way to enact these changes. One idea I've been floating is to cache the interpolation objects for each constituent so that there won't have to be repeated reads. I'm worried about this being a bit memory intensive though, so I need to put in some tests. I've also been reorganizing the code structure lately in #132 and #135. Everything should still be backwards compatible, just with some additional warnings.
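One way to explore the caching idea while keeping memory bounded is a least-recently-used cache with a fixed size. This is only a sketch of the general technique using the standard library: `get_interpolator` and its placeholder "interpolator" are hypothetical, and the `maxsize` value is arbitrary.

```python
from functools import lru_cache
from typing import Callable, List

loads: List[str] = []  # records each time a constituent is actually (re)built


@lru_cache(maxsize=2)  # at most two cached objects are held in memory at once
def get_interpolator(constituent: str) -> Callable[[float, float], float]:
    """Hypothetical: build an interpolation object for one constituent."""
    loads.append(constituent)
    scale = float(len(constituent))  # placeholder for data read from disk
    return lambda lat, lon: scale * (lat + lon)


# Repeated requests for the same constituents reuse the cached objects
for name in ["m2", "s2", "m2", "s2"]:
    get_interpolator(name)(-32.0, 155.0)
```

A small `maxsize` trades repeated reads for lower memory use: requesting a third constituent here would evict the least recently used one.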

robbibt commented 12 months ago

Hey @tsutterley, am doing some further optimisations of our tide modelling code as we're moving towards a multi-tide modelling system where we choose the best tide model locally based on comparisons with our satellite data. Because of this, our modelling now takes a lot longer than previously, so I'm looking into trying to parallelise some of the underlying pyTMD code to improve performance.

Our two big bottlenecks are:

1) Loading the tide constituent NetCDF files (which we have largely addressed by clipping the files to a bounding box around Australia)
2) Extracting tide constituents from the NetCDFs

For number 2, I've been able to get a big speed-up by parallelising the pyTMD.io.*.extract_constants calls across smaller chunks of lat/lon points using concurrent.futures. However, I think there are still some gains to be made, as pyTMD.io.*.extract_constants includes the slow NetCDF read step itself, so we're effectively wasting time in each parallel run by loading the same data multiple times.
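A rough sketch of the chunked approach with the read step hoisted out of the workers. It assumes threads rather than processes so that the loaded data is shared instead of copied; `interpolate_chunk` and the toy `model` dictionary are hypothetical stand-ins, not pyTMD calls.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Tuple

Point = Tuple[float, float]


def chunked(points: List[Point], size: int) -> List[List[Point]]:
    """Split the points into consecutive chunks of at most `size` items."""
    return [points[i:i + size] for i in range(0, len(points), size)]


def interpolate_chunk(model: Dict[str, float], chunk: List[Point]) -> List[float]:
    """Hypothetical stand-in for interpolating already-loaded constituents."""
    total = sum(model.values())
    return [total * (lat + lon) for lat, lon in chunk]


model = {"m2": 1.0, "s2": 0.5}  # pretend this was read from disk once
points = [(-30.0 - i, 150.0 + i) for i in range(10)]

with ThreadPoolExecutor(max_workers=4) as pool:
    # map() returns results in the same order as the input chunks
    chunk_results = list(pool.map(lambda c: interpolate_chunk(model, c), chunked(points, 3)))

# Flatten the per-chunk lists back into one result per input point
results = [value for chunk in chunk_results for value in chunk]
```

Whether threads give a real speed-up depends on how much of the interpolation work releases the GIL. With processes, the loaded data would need to be sent to or re-read by each worker.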

I know you made some changes to address this last year when I first posted this issue, but I wanted to double check: are the newer pyTMD.io.*.read_constants and pyTMD.io.*.interpolate_constants functions intended to completely replicate the existing functionality in pyTMD.io.*.extract_constants? Or is there any functionality I'd lose by running those two functions instead of pyTMD.io.*.extract_constants?

Ideally, I'd love to do something like this:

tsutterley commented 12 months ago

Hey @robbibt, basically yes that was the plan. The new functions can completely replicate the prior functionality. The difference is that using the new read and interpolate method keeps all of the constituent data in memory. In some cases this may be slower, such as running on a small (possibly distributed) machine. So I've kept both methods.

In cases where you want to run for multiple points with the same data, there is a potential speed up with the new method since (as you mentioned) there's the io bottleneck.

I've thought about switching to dask arrays (probably using xarray) but need to do some testing. I'm completely open to suggestions for squeaking out performance.