pytroll / pyresample

Geospatial image resampling in Python
http://pyresample.readthedocs.org
GNU Lesser General Public License v3.0

Dask-ify Elliptical Weighted Averaging (EWA) resampling #281

Closed: djhoese closed this issue 3 years ago

djhoese commented 4 years ago

Overview

This is a continuation of the discussion started during the Pytroll Contributor Week Spring 2020 between me, @pnuu, @mraspaud, and @sfinkens (feel free to unsubscribe). This issue is me recording some of what was discussed and my thoughts going forward as I start to work on this. The solution developed for this issue may result in utilities that could be used for improvements to gradient search and other resampling algorithms in pyresample or even as a general dask utility.

This is mostly a brain dump. You've been warned.

Dask-ified resampling is hard

Dask brings a lot of benefits to the type of processing that the core devs generally have to do. Most of our work involves per-pixel operations which work really well with the chunked arrays that dask uses/creates. Resampling data from one geographic grid to another is not one of the operations that fits well with dask. It breaks with what dask would prefer: "input chunk -> output chunk". Even in the complicated cases you have "these arrays with these dimensions IJK -> these arrays with these dimensions XY"; some sort of predictable relationship, preferably 1:1.

Resampling in pyresample generally involves going from one 2D image array to another 2D image array with completely different dimensions. There isn't really anything that tells us input chunk 5 will go into output chunk 8. Very often these arrays are on completely different coordinate systems. The new implementation of EWA resampling discussed here is how I/we plan to overcome this challenge.

Gradient Search

@mraspaud has developed a dask-friendly version of the "gradient search" algorithm. It does this by generating a "stack" of chunks to process. For example, if we had an input array with 2 chunks in the row dimension and 3 chunks in the column dimension and we had an output array with 2 chunks in the row dimension and 2 chunks in the column, our stack would be the combination of every input chunk with every output chunk.

    stack_0_0 = resample_alg(in_0, out_0)
    stack_1_0 = resample_alg(in_1, out_0)
    stack_2_0 = resample_alg(in_2, out_0)
    ...
    stack_5_3 = resample_alg(in_5, out_3)

So to get the result for one output chunk (chunk 0) we combine all the stack_X_0 variables (ex. the maximum value for each pixel). This explanation works better as a drawing, but I don't have one available right now. @mraspaud correct me if I'm wrong.
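In lieu of a drawing, here is a minimal NumPy sketch of the stack-and-combine idea with the 6 input chunks and 4 output chunks from the example. `resample_alg` is a hypothetical toy (not gradient search): it pretends input chunk `i` only lands on output chunk `i % 4` and returns NaN for every other pair. The combine step is the per-pixel maximum, done with `np.fmax`, which ignores NaN.

```python
import functools
import itertools

import numpy as np

N_IN, N_OUT = 6, 4    # 2x3 input chunks, 2x2 output chunks
OUT_SHAPE = (2, 2)    # toy output chunk shape


def resample_alg(in_idx, out_idx):
    """Toy stand-in for a per-chunk resampler (hypothetical).

    Input chunk ``in_idx`` only "overlaps" output chunk ``in_idx % N_OUT``
    and fills it with ``in_idx``; every other pair comes back all-NaN.
    """
    result = np.full(OUT_SHAPE, np.nan)
    if in_idx % N_OUT == out_idx:
        result[:] = in_idx
    return result


# The "stack": one segment for every (input chunk, output chunk) pair.
stack = {(i, o): resample_alg(i, o)
         for i, o in itertools.product(range(N_IN), range(N_OUT))}

# Combine along the stack for each output chunk (per-pixel max, NaN ignored).
output_chunks = [
    functools.reduce(np.fmax, (stack[i, o] for i in range(N_IN)))
    for o in range(N_OUT)
]
```

All 24 segments get computed even though only 6 of them contain any data, which is the waste described in the next section.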

The Good and Bad of Gradient Search

Gradient search's implementation is great because we are operating per-chunk and can limit how much memory is used by limiting the number of workers and size of chunks. It also shouldn't use anything that isn't serializable (something that can't be sent between multiprocess workers). This isn't true for KDTree-based nearest neighbor resampling right now.

Storing this stack is currently done as a 3D dask array, which means that every stack "segment" has to have the same shape. It also means that dask is going to compute every one of these stack arrays, because there isn't currently a way to say "based on some math, we know stack_2_0 isn't actually needed because in_2 doesn't geographically overlap with out_0". This is wasted processing and wasted space.

Lastly, I'm not sure this can be worked around, but this type of algorithm means a lot of communication between dask workers on a multi-node system. In the two best cases, either all input chunks are processed on the same node, in which case the output chunks have to be sent to other workers to be processed, or all output chunks are processed on the same node, which means the input chunks have to be sent to all workers.

I left the PCW discussion hoping that something better (even if only slightly) could be developed, especially since the algorithm I want to apply this type of work to has some slightly different constraints.

Current EWA

The current implementation of the Elliptical Weighted Averaging (EWA) algorithm was designed for MODIS data in the ms2gt (MODIS Swath to Grid Toolbox) tool. It was made available through the combination of the ll2cr (lat/lon to col/row) and fornav (forward navigation, I think) tools. I rewrote it for the CSPP Polar2Grid project in a combination of C++ and cython and then later ported this code to pyresample. It can be used from Satpy currently, but uses a quick hack of wrapping all the code in dask delayed functions instead of operating on data chunk by chunk.

EWA is only meant to operate on scan-based polar-orbiter instruments like MODIS and VIIRS where there are multiple rows of data per scan. So each scan is the same number of rows (ex. 10 rows) and the same number of columns (ex. 3200 columns) where the scan spans the entire width of the image array (num_cols == scan_cols). EWA has been shown to produce "OK" results for AVHRR (1 row per scan?) when an entire swath is treated as one giant scan but this is a special case I'd say. The algorithm, without any extra flags, goes through the following operations:

  1. Take one scan of navigation data and compute "EWA Parameters" based on the geometry of this scan for every column of the scan.
  2. Iterate over the input pixels' locations (these are actually the target AreaDefinition row/col indexes):
     a. Find nearby pixels within an ellipse.
     b. Calculate some factors based on the EWA parameters for this column and the distance of the nearby pixels when mapped to the output grid.
     c. Iterate over one or more image arrays:
        i. For each image array, accumulate two sums the size of the target grid: one of the weights and one of weight * image value.
  3. Iterate over each image's accumulation and weights and compute output = accumulation / weights.
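The steps above can be sketched in plain NumPy. This is a deliberately simplified, EWA-style forward navigation, not the pyresample/ms2gt implementation: the per-column ellipse from step 1 is replaced by a fixed circular Gaussian (`radius` and `alpha` are made-up parameters), and `rows`/`cols` are assumed to already be the fractional target-grid positions of each input pixel (what ll2cr would produce).

```python
import numpy as np


def ewa_like_resample(rows, cols, images, out_shape, radius=1.5, alpha=1.0):
    """Simplified EWA-style resampling (hypothetical helper, circular weights).

    rows, cols: float target-grid positions of every input pixel.
    images: list of input arrays with the same shape as rows/cols.
    """
    accums = [np.zeros(out_shape) for _ in images]
    weights = [np.zeros(out_shape) for _ in images]
    r_max = int(np.ceil(radius))
    for idx in np.ndindex(rows.shape):
        r0, c0 = rows[idx], cols[idx]
        if not (np.isfinite(r0) and np.isfinite(c0)):
            continue
        # 2a: output pixels near this input pixel (circle instead of ellipse)
        for r in range(int(np.floor(r0)) - r_max, int(np.ceil(r0)) + r_max + 1):
            for c in range(int(np.floor(c0)) - r_max, int(np.ceil(c0)) + r_max + 1):
                if not (0 <= r < out_shape[0] and 0 <= c < out_shape[1]):
                    continue
                d2 = (r - r0) ** 2 + (c - c0) ** 2
                if d2 > radius ** 2:
                    continue
                # 2b: weight from the distance, computed once for all images
                w = np.exp(-alpha * d2)
                # 2c: accumulate weight and weight * value for every image
                for img, acc, wsum in zip(images, accums, weights):
                    val = img[idx]
                    if np.isfinite(val):
                        acc[r, c] += w * val
                        wsum[r, c] += w
    # 3: output = accumulation / weights; untouched pixels become NaN
    outputs = []
    for acc, wsum in zip(accums, weights):
        out = np.full(out_shape, np.nan)
        valid = wsum > 0
        out[valid] = acc[valid] / wsum[valid]
        outputs.append(out)
    return outputs
```

The weight computed in 2b is reused for every image in 2c, which is where the "compute once, apply to several channels" saving comes from.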

The Good and Bad of EWA

So some nice things about this algorithm are steps 2a-2b, since we only have to compute those once per source/target area combination. The downside is that step 2c works with the input image array (possibly large), the output accumulation array (possibly large), and the output weights array (possibly large). Then in step 3 we need another array the size of the output area for the final result; another thing that is possibly large. So the thing I've never really profiled/tested very well is where the tipping point lies between the processing saved in 2a/2b and the performance lost and memory used by having to swap between image arrays (2c and 3). I'd guess that processing one image array at a time through EWA will almost always be faster than running multiple channels at once. I could be wrong.
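A back-of-the-envelope way to see the memory side of that trade-off, assuming (as a simplification) that every array shares one dtype and that each image needs its own input, accumulation, weights, and output arrays. The shapes in the example are illustrative, not measured.

```python
def ewa_memory_bytes(in_shape, out_shape, n_images, itemsize=4):
    """Rough memory held during steps 2c and 3 (hypothetical estimate).

    Per image: the input image plus three arrays the size of the target
    grid (accumulation, weights, final output). ``itemsize`` of 4 assumes
    float32 everywhere.
    """
    n_in = in_shape[0] * in_shape[1]
    n_out = out_shape[0] * out_shape[1]
    return n_images * (n_in + 3 * n_out) * itemsize


# Illustrative 2030 x 1354 swath onto a 4000 x 4000 grid
one_image = ewa_memory_bytes((2030, 1354), (4000, 4000), n_images=1)
three_images = ewa_memory_bytes((2030, 1354), (4000, 4000), n_images=3)
```

Here `one_image` is about 203 MB and `three_images` about 609 MB, so batching channels multiplies the target-grid arrays while only saving the 2a/2b work once.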

Proposed solution

Without doing much research, our discussion led to wondering if the gradient search technique could be modified to build a series of dask tasks instead of building the stack. This would mean things don't all have to be the same shape (technically) and that we could have more control over how these are created. One element of control would be some sort of pre-check to determine if an input/output chunk combination should produce a task. For example "this input chunk's bounding box does not intersect this output chunk's bounding box, don't create a task to work on them".
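One way the pre-check could look, as a minimal sketch: both sets of chunk bounding boxes are assumed to already be in the same (target) coordinate system as `(xmin, ymin, xmax, ymax)`, and the helper names are hypothetical, not existing pyresample API. Only the surviving pairs would become tasks (e.g. one `dask.delayed` call or one key in a custom graph per pair).

```python
import itertools


def bboxes_intersect(a, b):
    """True if two (xmin, ymin, xmax, ymax) boxes overlap; touching counts."""
    return a[0] <= b[2] and b[0] <= a[2] and a[1] <= b[3] and b[1] <= a[3]


def chunk_pairs_to_process(in_bboxes, out_bboxes):
    """Return the (in_idx, out_idx) pairs worth creating a task for."""
    return [
        (i, o)
        for (i, in_bb), (o, out_bb) in itertools.product(enumerate(in_bboxes),
                                                         enumerate(out_bboxes))
        if bboxes_intersect(in_bb, out_bb)
    ]


# 2x2 output chunks covering 0..100 in x and y, and three input chunks
out_bboxes = [(0, 50, 50, 100), (50, 50, 100, 100), (0, 0, 50, 50), (50, 0, 100, 50)]
in_bboxes = [(10, 60, 40, 90), (-50, -50, -10, -10), (40, 40, 60, 60)]
pairs = chunk_pairs_to_process(in_bboxes, out_bboxes)
```

Only 5 of the 12 pairs survive here. Touching edges count as overlap so the check errs on the side of creating a task; a plain bounding box also ignores the antimeridian, which a real implementation would have to handle.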

How this would get coded...I don't know.

Other Considerations

As mentioned EWA does things per-scan. This could result in a lot of chunks if we naively treated each scan as a separate dask task to be processed. Martin brought up that there shouldn't be a problem with processing multiple scans as larger chunks as long as the chunks are scan aligned.
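Scan-aligned chunking is easy to express as a row-chunk tuple. A small sketch with a hypothetical helper: it groups a fixed number of whole scans per chunk so no chunk boundary ever splits a scan; the result could be used as the row part of a dask `chunks=` argument (for example in `rechunk`).

```python
def scan_aligned_row_chunks(n_rows, rows_per_scan, scans_per_chunk):
    """Row chunk sizes that never split a scan (hypothetical helper).

    Every chunk holds ``scans_per_chunk`` whole scans; the last chunk
    holds whatever whole scans remain.
    """
    if n_rows % rows_per_scan:
        raise ValueError("number of rows is not a whole number of scans")
    chunk_rows = rows_per_scan * scans_per_chunk
    full, rest = divmod(n_rows, chunk_rows)
    return (chunk_rows,) * full + ((rest,) if rest else ())


# 2030 rows of 10-row scans, 50 scans per chunk
row_chunks = scan_aligned_row_chunks(2030, 10, 50)
```

`row_chunks` is `(500, 500, 500, 500, 30)`: every boundary falls on a multiple of 10 rows.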

pnuu commented 4 years ago

It [gradient search] also shouldn't use anything that isn't serializable (something that can't be sent between multiprocess workers).

The gradient search itself doesn't have anything that isn't serializable, but the readers might have, and thus the data can't be serialized. @mraspaud had some ideas on this.

It also means that dask is going to compute every one of these stack arrays because there isn't currently a way to say "based on some math we know stack_2_0 isn't actually needed because we know in_2 doesn't geographically overlap with out_0.

During the Spring 2020 PCW I was looking at gradient search and made some tests on filtering out extra input chunks (chunks that don't cover the output area at all). The gain wasn't huge. I didn't consider the output chunking, as that happens completely within the Cython code. My work is here (no PR, as the gains weren't great). And here's a link to my commentary (in Pytroll Slack) that I wrote as I was progressing.

djhoese commented 4 years ago

Ah good point on the reader serialization.

Thanks for the link to the code. I mentioned this during the previously mentioned discussion, but I think your solution can be improved to not use dask for the boundary decisions. Mainly this chunk of code:

https://github.com/pnuu/pyresample/blob/2113827111fed67547a7ac7f054197e86526faf7/pyresample/gradient/__init__.py#L183-L214

(of course github can't render the link so...)

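    import dask.array as da

    # x_coords, y_coords: full-resolution 2D dask arrays of target-projection
    # coordinates; walk the four sides clockwise, keeping every x/y_stride-th pixel
    up_x = x_coords[0, ::x_stride]
    right_x = x_coords[::y_stride, -1]
    down_x = x_coords[-1, ::-x_stride]
    left_x = x_coords[::-y_stride, 0]
    up_y = y_coords[0, ::x_stride]
    right_y = y_coords[::y_stride, -1]
    down_y = y_coords[-1, ::-x_stride]
    left_y = y_coords[::-y_stride, 0]
    res = da.compute(up_x, right_x, down_x, left_x, up_y, right_y, down_y, left_y)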

These coords are all simple arrays that don't need to exist beforehand. By creating the dask array at full resolution initially and then striding it, you still require dask to compute the full-resolution version of the array (every chunk) and then slice/stride it. I think this is probably where the majority of the time is being spent. If we instead didn't use dask for this and did the calculations ourselves (or added/used a method on the AreaDefinition), we could do:

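    import numpy as np
    # area_extent[0] is the left edge of the area, so pixel centers sit half a pixel in
    up_x = (np.arange(0, area_def.width, x_stride) + 0.5) * area_def.pixel_size_x + area_def.area_extent[0]
    ...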

This way we don't involve dask (no overhead to send these computations to workers) and we never compute anything we won't be using.
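Extending that one-liner to all four sides, here is a minimal sketch (hypothetical helper, not a pyresample method) that reproduces the index patterns of the quoted dask code using only the extent and shape. It takes plain values so it runs without pyresample; with an AreaDefinition you would pass `area_def.area_extent`, `area_def.width`, and `area_def.height`. It assumes `area_extent` is `(xmin, ymin, xmax, ymax)` of the outer pixel edges and that row 0 is the top of the area.

```python
import numpy as np


def area_boundary_sides(area_extent, width, height, x_stride=1, y_stride=1):
    """Strided pixel-center coordinates along the up, right, down, left sides."""
    xmin, ymin, xmax, ymax = area_extent
    pixel_size_x = (xmax - xmin) / width
    pixel_size_y = (ymax - ymin) / height

    def x_center(col):
        # pixel-center x for column index (or array of indices) ``col``
        return xmin + (col + 0.5) * pixel_size_x

    def y_center(row):
        # pixel-center y for row index ``row``; row 0 is the top edge
        return ymax - (row + 0.5) * pixel_size_y

    # Same indices as [0, ::xs], [::ys, -1], [-1, ::-xs], [::-ys, 0]
    up_cols = np.arange(0, width, x_stride)
    right_rows = np.arange(0, height, y_stride)
    down_cols = np.arange(width - 1, -1, -x_stride)
    left_rows = np.arange(height - 1, -1, -y_stride)

    xs = (x_center(up_cols), np.full(right_rows.size, x_center(width - 1)),
          x_center(down_cols), np.full(left_rows.size, x_center(0)))
    ys = (np.full(up_cols.size, y_center(0)), y_center(right_rows),
          np.full(down_cols.size, y_center(height - 1)), y_center(left_rows))
    return xs, ys


xs, ys = area_boundary_sides((0.0, 0.0, 10.0, 10.0), 10, 10, x_stride=3, y_stride=3)
```

Only the strided boundary pixels are ever computed. This covers a target area in its own projection; a source area in a different projection still has to have these boundary points transformed.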

pnuu commented 4 years ago

Oh yeah, that's true! I'll have another look at the code tomorrow and figure out how to get the sides/corners directly without involving dask. One gotcha is that the source coordinates also need to be in the target coordinate system, so it'll be a bit more complicated than the linear equation you showed.

djhoese commented 4 years ago

Proof of concept for what I'm thinking:

https://gist.github.com/djhoese/655cd7e1f3a26ed972b5e811c5ccb8f4

mraspaud commented 4 years ago

Ok, so I thought more about the parallelization we make and how to avoid computing data when there is no overlap. A solution would be to make the resampling algorithms output arrays of indices and weights instead of the actual resampled data. As the resampling algorithm probably knows that the input and output data isn't overlapping, it could flag it so that in the next step, when we want to actually get the resampled data from the indices and weights from the input data, we can skip the chunks that would be empty.

Now, getting indices and weights is also interesting for resampling multiple datasets at the same time. At the moment for example, the gradient search just reruns the resampling for every dataset, so this could be optimized.
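A minimal sketch of the two-stage split: stage 1 (expensive) produces flat source indices and weights per output pixel once; stage 2 (cheap) applies them to as many datasets as needed. The `(out_rows, out_cols, k)` layout and the toy values are assumptions for illustration, not what gradient search produces today.

```python
import numpy as np


def apply_indices_weights(data, indices, weights):
    """Weighted average of ``data`` at precomputed flat ``indices``.

    indices, weights: arrays of shape (out_rows, out_cols, k), as a
    resampler's first stage might produce them (hypothetical layout).
    """
    values = data.ravel()[indices]            # (out_rows, out_cols, k)
    return (values * weights).sum(axis=-1) / weights.sum(axis=-1)


# Stage 1, done once per source/target pair: a toy 2x2 output where each
# output pixel averages 2 pixels of a 2x3 input grid.
indices = np.array([[[0, 1], [1, 2]],
                    [[3, 4], [4, 5]]])
weights = np.array([[[0.75, 0.25], [0.5, 0.5]],
                    [[1.0, 1.0], [0.25, 0.75]]])

# Stage 2, repeated per dataset: reuse the same indices/weights.
datasets = {"ch1": np.arange(6.0).reshape(2, 3), "ch2": np.full((2, 3), 10.0)}
resampled = {name: apply_indices_weights(arr, indices, weights)
             for name, arr in datasets.items()}
```

`resampled["ch1"]` is `[[0.25, 1.5], [3.5, 4.75]]` and `resampled["ch2"]` is all 10, without rerunning stage 1 for the second channel.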

djhoese commented 4 years ago

As the resampling algorithm probably knows that the input and output data isn't overlapping, it could flag it so that in the next step, when we want to actually get the resampled data from the indices and weights from the input data, we can skip the chunks that would be empty.

How would the flagging be done? Would resampling step 1 (indexes and weights) be producing dask arrays? Based on your slack comments, even though this is doing unnecessary computation it should still be fast for gradient search. Right? I'm a little concerned with that being applied to most other algorithms. For example, generating indexes with regular nearest neighbor with kdtrees is the hard part, after that is just indexing the array which is fast.

mraspaud commented 4 years ago

I think if we build our own dask graph/array, we can skip the chunks where all indices are e.g. NaN. Indeed, building the indices array is the slow part, but some of the data generation can be costly to do, and being able to discard chunks would help performance, I think.

pnuu commented 4 years ago

being able to discard chunks would help performance I think

My draft-PR for discarding unused chunks in gradient search resampling: https://github.com/pytroll/pyresample/pull/282

djhoese commented 4 years ago

I think if we build our own dask graph/array, we can skip the chunks where all indices are eg NaN.

I think since they are indexes they would have to be ints. Probably borrow from the KDTree implementation and make invalid indexes equal to data.size (1 above the last index).
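A small sketch of the `data.size` sentinel convention (the same convention scipy's KDTree query uses for missing neighbours): appending one fill element makes the sentinel a valid index, so a single fancy-indexing call handles valid and invalid pixels. The chunk-level check is a hypothetical helper showing how an "empty" output chunk could be detected.

```python
import numpy as np


def take_with_fill(data, indices, fill_value=np.nan):
    """Nearest-neighbour lookup where index == data.size means "no source pixel"."""
    padded = np.append(data.ravel().astype(float), fill_value)
    return padded[indices]


def chunk_is_empty(index_chunk, data_size):
    """True when every index in an output chunk is the sentinel (skippable)."""
    return bool(np.all(index_chunk == data_size))


data = np.array([[1.0, 2.0], [3.0, 4.0]])
indices = np.array([[0, 3], [4, 4]])      # 4 == data.size -> invalid
result = take_with_fill(data, indices)
```

`result` is `[[1.0, 4.0], [nan, nan]]`, and `chunk_is_empty(indices[1], data.size)` is True. Note that the check still has to read every index in the chunk.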

but some of the data generation can be costly do, and being able to discard chunks would help performance I think.

We'd still be checking if they were all fill values though. I could see this requiring custom cython code or maybe depend on functions like nanmax I suppose. At that point we really haven't skipped any chunks have we? Also keep in mind that part of the benefit of cutting out some of the chunks is to reduce the number of tasks that dask has to schedule.

djhoese commented 4 years ago

I was going to post this in slack but realized it should probably go here:

Ok thinking about this more (it keeps distracting me from what I tell myself I should be working on): We want (or at least I do) to represent a stack of chunks and perform an operation along those chunks. The chunks may be different sizes/shapes. Because of this the only easy way to represent this stack that I can think of is with the custom dask graph.

Taking a step back, we want to filter what chunks get put on this stack. Dask doesn't (as far as I know) provide anything that operates on something like the above stack (a series of different shaped things). There may be a way to use dask bags to do this, but I'm not sure it is worth the effort. In the best case we have an AreaDefinition -> AreaDefinition which could do the chunk filtering without dask (no dask arrays needed). In the worst case we have a SwathDefinition with no bounding polygon information so we need to analyze/load the entire geolocation array. In this case any map_blocks-like implementation of this mask generating function would be operating on the chunk-level (reducing MxN arrays to a scalar) and that doesn't fit with dask.array's expectations of dealing with things on the numpy shape space; differently sized chunks don't reduce predictably in a way dask.array understands. This means that in the SwathDefinition case you can either:

  1. .compute() the filter array (per-chunk booleans) to then make a list of the chunks to add to the custom graph. After the .compute() this is all in normal python space (non-dask).
  2. Make a dask map_blocks function that takes the mask array (1 boolean scalar at a time) and produces the list of chunk information to process as a list (?). You might have to throw dask delayed in there somewhere to make this actually work since the list of chunk info to process is not a data array but is instead a list of arguments (or maybe chunk indexes). You'd then need to .compute() this result to actually get the full list of chunk information to make your custom graph and continue on processing. This would mean that the list of chunks would be generated in dask space.
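Option 1 could look roughly like this in plain NumPy: each MxN chunk of the swath geolocation is reduced to one boolean ("does any pixel fall inside the target bounding box?"), and the result is an ordinary Python list of chunk indices for building the custom graph. With dask arrays, the per-chunk booleans would be the part you `.compute()`. The helpers and the lon/lat bounding-box test are assumptions for illustration (no antimeridian handling); `chunks` follows dask's tuple-of-tuples format.

```python
import itertools

import numpy as np


def chunk_slices(chunks):
    """Yield (chunk_index, slices) for a dask-style ``chunks`` tuple of tuples."""
    bounds = [np.cumsum((0,) + tuple(c)) for c in chunks]
    for idx in itertools.product(*(range(len(c)) for c in chunks)):
        yield idx, tuple(slice(b[i], b[i + 1]) for b, i in zip(bounds, idx))


def chunks_to_process(lons, lats, chunks, target_bbox):
    """Reduce each chunk to one boolean and keep the indices of the True ones.

    target_bbox: (lon_min, lat_min, lon_max, lat_max).
    """
    lon_min, lat_min, lon_max, lat_max = target_bbox
    keep = []
    for idx, slc in chunk_slices(chunks):
        lo, la = lons[slc], lats[slc]
        inside = (lo >= lon_min) & (lo <= lon_max) & (la >= lat_min) & (la <= lat_max)
        if inside.any():
            keep.append(idx)
    return keep


lons = np.tile(np.arange(4.0), (4, 1))                   # each row: 0, 1, 2, 3
lats = np.repeat(np.arange(4.0), 4).reshape(4, 4)        # row r is all r
selected = chunks_to_process(lons, lats, ((2, 2), (2, 2)), (1.5, -1.0, 10.0, 0.5))
```

Only chunk `(0, 1)` is selected here: its first row has latitude 0 and longitudes 2 and 3.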

I also mentioned on slack that I realized that our resampling problems have evolved from a 2 stage process in numpy-only land (ex. kdtree -> index) to a 3 stage process in dask land (ex. precheck -> kdtree -> index). At least that is one way to justify this complexity.