alan-turing-institute / deepsensor

A Python package for tackling diverse environmental prediction tasks with NPs.
https://alan-turing-institute.github.io/deepsensor/
MIT License

Enable patchwise training and prediction #135

Open davidwilby opened 2 weeks ago

davidwilby commented 2 weeks ago

Hey @tom-andersson - at long last, here is the patchwise training and prediction feature that @nilsleh and @MartinSJRogers have been working on.

This PR adds patching capabilities to DeepSensor during training and inference.

Training

Optional args patching_strategy, patch_size, stride and num_samples_per_date are added to TaskLoader.__call__.

There are two patching strategies available: random_window and sliding_window. The random_window strategy randomly selects a point within the x1 and x2 extent to use as the centroid of each patch; the number of patches per date is set by the num_samples_per_date argument. The sliding_window strategy starts at the top left of the dataset and slides over the data from left to right and top to bottom, using the user-defined patch_size and stride.

TaskLoader.__call__ now contains additional conditional logic that depends on the selected patching strategy. If no patching strategy is selected, task_generator() runs exactly as before. If random_window is selected, the patch bounding boxes are generated with sample_random_window(); if sliding_window is selected, they are generated with sample_sliding_window(). The bounding boxes are appended to the list bboxes, which is then passed to task_generator().
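The branching can be summarised as a small dispatch function. This is a hypothetical sketch rather than the actual TaskLoader.__call__ code: the helper name is invented, the bbox generators are passed in as callables standing in for sample_random_window() and sample_sliding_window(), and returning None stands for "no patching".

```python
from typing import Callable, List, Optional, Sequence, Tuple

BBox = Tuple[float, float, float, float]


def bboxes_for_date(
    patching_strategy: Optional[str],
    random_window: Callable[[], BBox],
    sliding_window: Callable[[], Sequence[BBox]],
    num_samples_per_date: int = 1,
) -> Optional[List[BBox]]:
    """Hypothetical helper mirroring the conditional logic described above."""
    if patching_strategy is None:
        # No patching: task_generator() is called exactly as before.
        return None
    if patching_strategy == "random_window":
        # One randomly placed patch per requested sample.
        return [random_window() for _ in range(num_samples_per_date)]
    if patching_strategy == "sliding_window":
        # The window size and stride fully determine the patches.
        return list(sliding_window())
    raise ValueError(f"Unknown patching_strategy: {patching_strategy!r}")


bboxes = bboxes_for_date(
    "random_window",
    random_window=lambda: (0.0, 0.5, 0.0, 0.5),
    sliding_window=lambda: [],
    num_samples_per_date=3,
)
```

Keeping the "no patching" case as a distinct return value makes it easy to keep the original single-task path untouched.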

Within task_generator(), after the sampling strategies have been applied, the data is spatially sliced to each bbox in bboxes using the self.spatial_slice_variable() method.
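For gridded data, slicing to a bbox amounts to label-based selection on the two spatial dimensions. The sketch below is a simplified xarray-only stand-in for spatial_slice_variable(), whose real implementation may also handle off-grid pandas data; the helper name, the dimension names x1/x2 and ascending coordinates are assumptions.

```python
import numpy as np
import xarray as xr


def slice_to_bbox(ds: xr.Dataset, bbox) -> xr.Dataset:
    """Return the part of `ds` inside bbox = (x1_min, x1_max, x2_min, x2_max).

    Hypothetical helper: assumes dims named x1/x2 with ascending coordinates.
    Label-based slices in xarray include both end points.
    """
    x1_min, x1_max, x2_min, x2_max = bbox
    return ds.sel(x1=slice(x1_min, x1_max), x2=slice(x2_min, x2_max))


# Toy 10 x 10 grid; one sliced Dataset per bbox, i.e. one patch each.
ds = xr.Dataset(
    {"t2m": (("x1", "x2"), np.arange(100.0).reshape(10, 10))},
    coords={"x1": np.arange(10.0), "x2": np.arange(10.0)},
)
bboxes = [(0.0, 4.0, 0.0, 4.0), (3.0, 9.0, 5.0, 9.0)]
patches = [slice_to_bbox(ds, b) for b in bboxes]
print([p["t2m"].shape for p in patches])  # [(5, 5), (7, 5)]
```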

When using a patching strategy, TaskLoader produces a list of tasks per date, rather than an individual task per date. A small change has been made to Task's summarise_str method to avoid an error when printing patched Tasks and to output more meaningful information.

Inference

To run patchwise predictions, a new method called predict_patch() has been added to model.py. This method iterates over the patched tasks and applies the pre-existing predict() method to each one; predict() itself has not been changed. Within each iteration, before predict() is called, the bounding box of the patch is unnormalised so that the X_t of each patch can be passed to predict(). The patchwise predictions are stored in the list preds for subsequent stitching.
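The per-patch loop can be sketched without DeepSensor as follows. Everything here is hypothetical: the helper names, the linear normalisation x_norm = (x - x_min) / scale with one scale shared by both axes, passing bboxes alongside the tasks, representing X_t as a dict of coordinate ranges, and the stand-in predict callable. DeepSensor's DataProcessor and predict() will differ in detail.

```python
from typing import Any, Callable, Dict, List, Sequence, Tuple

BBox = Tuple[float, float, float, float]


def unnormalise_bbox(bbox: BBox, x1_min: float, x2_min: float, scale: float) -> BBox:
    """Map a normalised bbox back to raw coordinates.

    Assumes a hypothetical linear normalisation x_norm = (x - x_min) / scale
    with one scale shared by both axes.
    """
    a1, b1, a2, b2 = bbox
    return (a1 * scale + x1_min, b1 * scale + x1_min,
            a2 * scale + x2_min, b2 * scale + x2_min)


def predict_patches(
    tasks: Sequence[Any],
    bboxes: Sequence[BBox],
    predict: Callable[[Any, Dict[str, Tuple[float, float]]], Any],
    x1_min: float,
    x2_min: float,
    scale: float,
) -> List[Any]:
    """Apply an unchanged `predict` to each patched task and collect the results."""
    preds = []
    for task, bbox in zip(tasks, bboxes):
        r = unnormalise_bbox(bbox, x1_min, x2_min, scale)
        # Hypothetical X_t: the patch extent in raw (unnormalised) coordinates.
        X_t = {"x1": (r[0], r[1]), "x2": (r[2], r[3])}
        preds.append(predict(task, X_t))
    return preds  # kept as a list for a later stitching step


preds = predict_patches(
    tasks=["task_0", "task_1"],
    bboxes=[(0.0, 0.5, 0.0, 0.5), (0.5, 1.0, 0.5, 1.0)],
    predict=lambda task, X_t: (task, X_t),  # stand-in for model.predict()
    x1_min=40.0, x2_min=-10.0, scale=20.0,
)
```

Because predict() is only ever called with a task and a target extent, the existing method can be reused per patch without modification.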

Only the sliding_window patching strategy can be used during inference, and the stride and patch size are defined when the user generates the test tasks in the task_loader() call. The data_processor must also be passed to predict_patch() so that the coordinates of the bboxes can be unnormalised in model.py.

Once the list of patchwise predictions has been generated, stitch_clipped_predictions() is used to form a prediction over the original X_t extent. Currently, each patchwise prediction is subset (clipped) so that there is no overlap between adjacent patches, and the patches are then merged using xr.combine_by_coords(). Because the code is modular, further stitching strategies can be added after this PR, for example applying a weighting function to overlapping predictions. To ensure the patches are clipped by the correct amount, get_patch_overlap() calculates the overlap between adjacent patches. stitch_clipped_predictions() also handles patches at the right-hand and bottom edges of the dataset, where the overlap may be different.
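One way to clip and merge is sketched below. It is not the PR's stitch_clipped_predictions()/get_patch_overlap(): the helper names, keying patches by (row, col), and the rule of splitting each shared strip at its midpoint are assumptions. Overlaps are computed per pair of neighbours from their coordinates, so a final patch snapped to the edge, which overlaps more, is handled by the same code.

```python
import numpy as np
import xarray as xr


def shared_coords(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Coordinate values that two neighbouring patches have in common."""
    return np.intersect1d(a, b)


def split_overlap(first: np.ndarray, second: np.ndarray):
    """Split the strip shared by two neighbours at its midpoint (an assumption)."""
    shared = shared_coords(first, second)
    if shared.size == 0:
        return first, second
    cut = shared[shared.size // 2]
    return first[first < cut], second[second >= cut]


def stitch_clipped(patches: dict) -> xr.Dataset:
    """Clip a full (row, col) grid of patch predictions so they tile, then merge."""
    rows = sorted({r for r, _ in patches})
    cols = sorted({c for _, c in patches})
    # Coordinate labels each patch keeps; initially everything it covers.
    keep = {k: {"x1": p["x1"].values, "x2": p["x2"].values} for k, p in patches.items()}
    for r in rows:  # horizontal neighbours share a row and overlap along x2
        for c0, c1 in zip(cols, cols[1:]):
            keep[r, c0]["x2"], keep[r, c1]["x2"] = split_overlap(keep[r, c0]["x2"], keep[r, c1]["x2"])
    for c in cols:  # vertical neighbours share a column and overlap along x1
        for r0, r1 in zip(rows, rows[1:]):
            keep[r0, c]["x1"], keep[r1, c]["x1"] = split_overlap(keep[r0, c]["x1"], keep[r1, c]["x1"])
    clipped = [p.sel(x1=keep[k]["x1"], x2=keep[k]["x2"]) for k, p in patches.items()]
    # The clipped patches no longer overlap, so combine_by_coords can tile them.
    return xr.combine_by_coords(clipped)


truth = xr.Dataset(
    {"mean": (("x1", "x2"), np.arange(100.0).reshape(10, 10))},
    coords={"x1": np.arange(10), "x2": np.arange(10)},
)
# 6 x 6 patches with stride 4: starts 0 and 4 on each axis, 2-point overlaps.
patches = {
    (i, j): truth.isel(x1=slice(s1, s1 + 6), x2=slice(s2, s2 + 6))
    for i, s1 in enumerate((0, 4))
    for j, s2 in enumerate((0, 4))
}
stitched = stitch_clipped(patches)
print(stitched.equals(truth))  # True
```

A weighted blend of the overlapping strips would replace split_overlap() and the clipping step, while the final merge could stay the same.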

The output of predict_patch() is the same DeepSensor object as that produced by model.predict(), so DeepSensor's plotting functionality can be used on it in exactly the same way.

Documentation and Testing

New notebook(s) are added illustrating the usage of both patchwise training and prediction.

New tests are added to verify the new behaviour.

Limitations
