weecology / DeepForest

Python Package for Airborne RGB machine learning
https://deepforest.readthedocs.io/
MIT License

Add an out of memory dataset #831

Closed: bw4sz closed this 1 week ago

bw4sz commented 3 weeks ago

In response to #830 and #829, I think it was time to make an out-of-memory prediction tool for rasters that are too large to read into memory. This is for cases where the raster itself is so large that

import rasterio as rio
image = rio.open("<path>").read()  # reads every band of the full raster into one array

would raise a MemoryError. It doesn't (yet) address memory load within the existing TileDataset class. I expect that reading blocks with rasterio will be slower than reading from memory, and rasterio is not thread safe, so workers must be set to 0. Not that we have generally seen much difference between workers = 0 and workers = 1 with multiprocessing.
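To make the block-reading idea concrete, here is a minimal sketch, not part of this PR, that uses `numpy.memmap` over a raw binary file as a stand-in for a block-readable raster. It compares the peak memory traced by the standard-library `tracemalloc` module when copying the whole array versus a single window. The file layout, the sizes, and the helper names (`write_raw_raster`, `open_raster`, `read_full`, `read_window`, `peak_bytes`) are hypothetical; rasterio's `read(window=...)` on a tiled GeoTIFF is assumed to behave analogously, although it also decodes compressed blocks, which this sketch does not model.

```python
import os
import tempfile
import tracemalloc

import numpy as np

# Hypothetical raster size: 3 bands x 2000 rows x 3000 columns, uint8
BANDS, HEIGHT, WIDTH = 3, 2000, 3000


def write_raw_raster(path):
    """Write a synthetic band-first raster to a raw binary file."""
    data = np.memmap(path, dtype=np.uint8, mode="w+", shape=(BANDS, HEIGHT, WIDTH))
    data[:] = (np.arange(WIDTH) % 256).astype(np.uint8)  # same column pattern in every row
    data.flush()
    del data  # release the mapping


def open_raster(path):
    """Map the file without reading pixels into RAM (the analogue of rio.open)."""
    return np.memmap(path, dtype=np.uint8, mode="r", shape=(BANDS, HEIGHT, WIDTH))


def read_full(path):
    """Analogue of rio.open(path).read(): copy every pixel into RAM."""
    return np.array(open_raster(path))


def read_window(path, x, y, w, h):
    """Analogue of src.read(window=Window(x, y, w, h)): copy one window only."""
    return np.array(open_raster(path)[:, y:y + h, x:x + w])


def peak_bytes(fn):
    """Run fn() and return (result, peak bytes traced by tracemalloc)."""
    tracemalloc.start()
    try:
        result = fn()
        _, peak = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()
    return result, peak


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "raster.bin")
    write_raw_raster(path)
    full, full_peak = peak_bytes(lambda: read_full(path))
    window, window_peak = peak_bytes(lambda: read_window(path, 400, 400, 400, 400))

print(f"full read peak:   {full_peak / 2**20:.1f} MiB")
print(f"window read peak: {window_peak / 2**20:.1f} MiB")
```

The full read peaks at roughly the size of the whole array, while the windowed read peaks at roughly the size of one window, which is the property the out-of-memory dataset relies on.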

bw4sz commented 2 weeks ago

This PR closes #830

  1. I've written an out-of-memory dataset class.

    import numpy as np
    import rasterio as rio
    import slidingwindow
    from rasterio.windows import Window
    from torch.utils.data import Dataset

    from deepforest import preprocess


    class RasterDataset(Dataset):
        """Dataset for predicting on raster windows without loading the full raster

        Args:
            raster_path (str): Path to raster file
            patch_size (int): Size of windows to predict on, in pixels
            patch_overlap (float): Overlap between windows as fraction (0-1)
        """

        def __init__(self, raster_path, patch_size, patch_overlap):
            # Keep only an open file handle; pixels are read lazily per window
            self.raster = rio.open(raster_path)
            self.patch_size = patch_size
            self.patch_overlap = patch_overlap

            # rasterio reports shape as (rows, cols), i.e. (height, width)
            height, width = self.raster.shape

            # Generate sliding windows over the raster (generateForSize takes width first)
            self.windows = slidingwindow.generateForSize(
                width,
                height,
                dimOrder=slidingwindow.DimOrder.ChannelHeightWidth,
                maxWindowSize=patch_size,
                overlapPercent=patch_overlap
            )
            self.n_windows = len(self.windows)

        def __len__(self):
            return self.n_windows

        def __getitem__(self, idx):
            """Get a window of the raster

            Args:
                idx (int): Index of window to get

            Returns:
                crop (torch.Tensor): A tensor of shape (3, height, width)
            """
            window = self.windows[idx]

            # Read only this window: Window(col_off, row_off, width, height)
            window_data = self.raster.read(
                window=Window(window.x, window.y, window.w, window.h)
            )

            # rasterio returns (bands, rows, cols); move bands last for preprocessing
            window_data = np.moveaxis(window_data, 0, 2)
            crop = preprocess.preprocess_image(window_data)

            return crop

        def close(self):
            """Close the raster dataset"""
            self.raster.close()

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            self.close()
  2. I've confirmed that it produces the same results as an in-memory dataset class.
import time

from deepforest import main

# Local test image on the author's machine
raster_path = "/Users/benweinstein/Downloads/DJI_20211215102728_0027.jpeg"

m = main.deepforest()
m.load_model("Weecology/deepforest-tree")

# The out-of-memory dataset requires workers = 0
m.config["workers"] = 0
m.create_trainer()

start_time = time.perf_counter()
result = m.predict_tile(raster_path=raster_path, patch_size=2000, patch_overlap=0, in_memory=False)
elapsed_time = time.perf_counter() - start_time
print(f"Time elapsed: {elapsed_time:.2f} seconds not in memory")

start_time = time.perf_counter()
result_in_memory = m.predict_tile(raster_path=raster_path, patch_size=2000, patch_overlap=0, in_memory=True)
elapsed_time = time.perf_counter() - start_time
print(f"Time elapsed: {elapsed_time:.2f} seconds in memory")

# Assert that the results are the same
assert result.equals(result_in_memory)
  3. I've added a warning for when rasterio reports that the .tif is not internally tiled, since the out-of-memory path will not work properly in that case. For an image without internal blocks, it actually takes more memory, since it reads the entire image for every window.
[Screenshot, 2024-11-04 at 1:57:32 PM]

~4. Confirm out-of-memory takes less memory!~
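The `RasterDataset` in item 1 delegates window placement to `slidingwindow.generateForSize`. As a rough sketch of what such a generator computes (an assumption about the general approach, not slidingwindow's source; its window ordering and edge handling may differ), each axis is stepped by `patch_size - floor(patch_size * patch_overlap)` pixels, and a final window is shifted back so it ends exactly at the raster edge. The `Window` tuple and both functions are hypothetical names.

```python
import math
from typing import List, NamedTuple


class Window(NamedTuple):
    """Pixel window: column offset x, row offset y, width w, height h."""
    x: int
    y: int
    w: int
    h: int


def axis_offsets(length: int, size: int, overlap: float) -> List[int]:
    """Start offsets along one axis so windows of `size` cover `length`."""
    size = min(size, length)                          # never exceed the raster
    step = max(1, size - int(math.floor(size * overlap)))
    last = length - size                              # last valid start offset
    offsets = list(range(0, last + 1, step))
    if offsets[-1] != last:                           # shift a final window to the edge
        offsets.append(last)
    return offsets


def sliding_windows(width: int, height: int, patch_size: int,
                    patch_overlap: float) -> List[Window]:
    """Row-major list of windows covering a width x height raster."""
    w = min(patch_size, width)
    h = min(patch_size, height)
    return [Window(x, y, w, h)
            for y in axis_offsets(height, patch_size, patch_overlap)
            for x in axis_offsets(width, patch_size, patch_overlap)]


windows = sliding_windows(width=1000, height=500, patch_size=400, patch_overlap=0.05)
print(len(windows), windows[0], windows[-1])
# 6 Window(x=0, y=0, w=400, h=400) Window(x=600, y=100, w=400, h=400)
```

Because the edge windows are shifted inward rather than padded, they overlap their neighbours more than `patch_overlap`, which is one reason a cross-window suppression step is needed when mosaicking.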

For a properly tiled .tif with internal block sizes, we see what looks like a memory leak within trainer.predict: in the profile below, that line alone adds 879.0 MiB.
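Item 3 warns when a .tif is not internally tiled. One possible check, sketched here as a heuristic and not necessarily what the PR implements, uses the block layout rasterio reports through `DatasetReader.block_shapes` (one `(rows, cols)` tuple per band): if blocks span the full raster width, the file is striped or untiled, and even a small window read has to decode full-width rows. The function name `is_internally_tiled` is hypothetical, and it takes plain values so it runs without rasterio.

```python
from typing import Sequence, Tuple


def is_internally_tiled(block_shapes: Sequence[Tuple[int, int]],
                        height: int, width: int) -> bool:
    """Heuristic: True if every band stores blocks narrower than the raster.

    `block_shapes` mirrors rasterio's DatasetReader.block_shapes:
    one (rows, cols) tuple per band. `height` is unused by this heuristic
    but kept so the call mirrors the raster's full shape.
    """
    if not block_shapes:
        return False
    for _rows, cols in block_shapes:
        # Blocks as wide as the raster are strips (or one whole-image block),
        # so reading a small window still decodes full-width rows.
        if cols >= width:
            return False
    return True


# Typical layouts: 256x256 tiles versus one-row strips
print(is_internally_tiled([(256, 256)] * 3, height=8000, width=9000))
print(is_internally_tiled([(1, 9000)] * 3, height=8000, width=9000))
```

With rasterio installed, it could be called as `is_internally_tiled(src.block_shapes, src.height, src.width)` on an open dataset `src`. Rasters narrower than one tile would be reported as untiled, so the result is better used to trigger a warning than a hard error.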

Memory usage of out-of-memory dataset


Filename: /Users/benweinstein/Documents/DeepForest/deepforest/main.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   466    643.5 MiB    643.5 MiB           1       @memory_profiler.profile
   467                                             def predict_tile(self,
   468                                                              raster_path=None,
   469                                                              image=None,
   470                                                              patch_size=400,
   471                                                              patch_overlap=0.05,
   472                                                              iou_threshold=0.15,
   473                                                              in_memory=True,
   474                                                              return_plot=False,
   475                                                              mosaic=True,
   476                                                              sigma=0.5,
   477                                                              thresh=0.001,
   478                                                              color=None,
   479                                                              thickness=1,
   480                                                              crop_model=None,
   481                                                              crop_transform=None,
   482                                                              crop_augment=False):
   483                                                 """For images too large to input into the model, predict_tile cuts the
   484                                                 image into overlapping windows, predicts trees on each window and
   485                                                 reassambles into a single array.
   486                                         
   487                                                 Args:
   488                                                     raster_path: Path to image on disk
   489                                                     image (array): Numpy image array in BGR channel order
   490                                                         following openCV convention
   491                                                     patch_size: patch size for each window.
   492                                                     patch_overlap: patch overlap among windows.
   493                                                     iou_threshold: Minimum iou overlap among predictions between
   494                                                         windows to be suppressed.
   495                                                         Lower values suppress more boxes at edges.
   496                                                     in_memory: If true, the entire dataset is loaded into memory. This is useful for small datasets, but not recommended for large datasets since both the tile and the crops are stored in memory.
   497                                                     mosaic: Return a single prediction dataframe (True) or a tuple of image crops and predictions (False)
   498                                                     sigma: variance of Gaussian function used in Gaussian Soft NMS
   499                                                     thresh: the score thresh used to filter bboxes after soft-nms performed
   500                                                     cropModel: a deepforest.model.CropModel object to predict on crops
   501                                                     crop_transform: a torchvision.transforms object to apply to crops
   502                                                     crop_augment: a boolean to apply augmentations to crops
   503                                                     (deprecated) return_plot: return a plot of the image with predictions overlaid
   504                                                     (deprecated) color: color of the bounding box as a tuple of BGR color, e.g. orange annotations is (0, 165, 255)
   505                                                     (deprecated) thickness: thickness of the rectangle border line in px
   506                                         
   507                                                 Deprecated Args:
   508                                                     - return_plot: Deprecated in favor of using `visualize.plot_results` for
   509                                                       rendering predictions. Will be removed in version 2.0.
   510                                                     - color: Deprecated bounding box color for visualizations.
   511                                                     - thickness: Deprecated bounding box thickness for visualizations.
   512                                         
   513                                                 Returns:
   514                                                     - If `return_plot` is True, returns an image with predictions overlaid (deprecated).
   515                                                     - If `mosaic` is True, returns a Pandas DataFrame containing the predicted
   516                                                       bounding boxes, scores, and labels.
   517                                                     - If `mosaic` is False, returns a list of tuples where each tuple contains
   518                                                       a DataFrame of predictions and its corresponding image crop.
   519                                                     - Returns None if no predictions are made.
   520                                                 """
   521    643.5 MiB      0.0 MiB           1           self.model.eval()
   522    643.5 MiB      0.0 MiB           1           self.model.nms_thresh = self.config["nms_thresh"]
   523                                         
   524                                                 # if more than one GPU present, use only a the first available gpu
   525    643.5 MiB      0.0 MiB           1           if torch.cuda.device_count() > 1:
   526                                                     # Get available gpus and regenerate trainer
   527                                                     warnings.warn(
   528                                                         "More than one GPU detected. Using only the first GPU for predict_tile.")
   529                                                     self.config["devices"] = 1
   530                                                     self.create_trainer()
   531                                         
   532    643.5 MiB      0.0 MiB           1           if (raster_path is None) and (image is None):
   533                                                     raise ValueError(
   534                                                         "Both tile and tile_path are None. Either supply a path to a tile on disk, or read one into memory!"
   535                                                     )
   536                                         
   537                                         
   538    643.5 MiB      0.0 MiB           1           if in_memory:
   539                                                     if raster_path is None:
   540                                                         image = image
   541                                                     else:
   542                                                         image = rio.open(raster_path).read()
   543                                                         image = np.moveaxis(image, 0, 2)
   544                                         
   545                                                     ds = dataset.TileDataset(tile=image,
   546                                                                          patch_overlap=patch_overlap,
   547                                                                          patch_size=patch_size)
   548                                                 else:
   549    643.5 MiB      0.0 MiB           1               if raster_path is None:
   550                                                         raise ValueError("raster_path is required if in_memory is False")
   551                                                     
   552                                                     # Check for workers config when using out of memory dataset
   553    643.5 MiB      0.0 MiB           1               if self.config["workers"] > 0:
   554                                                         raise ValueError("workers must be 0 when using out-of-memory dataset (in_memory=False). Set config['workers']=0 and recreate trainer self.create_trainer().")
   555                                                     
   556    650.3 MiB      6.8 MiB           2               ds = dataset.RasterDataset(raster_path=raster_path,
   557    643.5 MiB      0.0 MiB           1                                           patch_overlap=patch_overlap,
   558    643.5 MiB      0.0 MiB           1                                           patch_size=patch_size)
   559                                         
   560   1529.3 MiB    879.0 MiB           1           batched_results = self.trainer.predict(self, self.predict_dataloader(ds))
   561                                         
   562                                                 # Flatten list from batched prediction
   563   1529.3 MiB      0.0 MiB           1           results = []
   564   1529.3 MiB      0.0 MiB         144           for batch in batched_results:
   565   1529.3 MiB      0.0 MiB         286               for boxes in batch:
   566   1529.3 MiB      0.0 MiB         143                   results.append(boxes)
   567                                         
   568   1529.3 MiB      0.0 MiB           1           if mosaic:
   569   1530.4 MiB      1.0 MiB           2               results = predict.mosiac(results,
   570   1529.3 MiB      0.0 MiB           1                                        ds.windows,
   571   1529.3 MiB      0.0 MiB           1                                        sigma=sigma,
   572   1529.3 MiB      0.0 MiB           1                                        thresh=thresh,
   573   1529.3 MiB      0.0 MiB           1                                        iou_threshold=iou_threshold)
   574   1530.4 MiB      0.0 MiB           2               results["label"] = results.label.apply(
   575   1530.4 MiB      0.0 MiB        2207                   lambda x: self.numeric_to_label_dict[x])
   576   1530.4 MiB      0.0 MiB           1               if raster_path:
   577   1530.4 MiB      0.0 MiB           1                   results["image_path"] = os.path.basename(raster_path)
   578   1530.4 MiB      0.0 MiB           1               if return_plot:
   579                                                         # Add deprecated warning
   580                                                         warnings.warn("return_plot is deprecated and will be removed in 2.0. "
   581                                                                       "Use visualize.plot_results on the result instead.")
   582                                                         # Draw predictions on BGR
   583                                                         if raster_path:
   584                                                             tile = rio.open(raster_path).read()
   585                                                         else:
   586                                                             tile = image
   587                                                         drawn_plot = tile[:, :, ::-1]
   588                                                         drawn_plot = visualize.plot_predictions(tile,
   589                                                                                                 results,
   590                                                                                                 color=color,
   591                                                                                                 thickness=thickness)
   592                                                         return drawn_plot
   593                                                 else:
   594                                                     for df in results:
   595                                                         df["label"] = df.label.apply(lambda x: self.numeric_to_label_dict[x])
   596                                         
   597                                                     # TODO this is the 2nd time the crops are generated? Could be more efficient, but memory intensive
   598                                                     self.crops = []
   599                                                     if raster_path is None:
   600                                                         image = image
   601                                                     else:
   602                                                         image = rio.open(raster_path).read()
   603                                                         image = np.moveaxis(image, 0, 2)
   604                                         
   605                                                     for window in ds.windows:
   606                                                         crop = image[window.indices()]
   607                                                         self.crops.append(crop)
   608                                         
   609                                                     return list(zip(results, self.crops))
   610                                         
   611   1530.4 MiB      0.0 MiB           1           if crop_model:
   612                                                     # If a crop model is provided, predict on each crop
   613                                                     results = predict._predict_crop_model_(crop_model=crop_model,
   614                                                                                            results=results,
   615                                                                                            raster_path=raster_path,
   616                                                                                            trainer=self.trainer,
   617                                                                                            transform=crop_transform,
   618                                                                                            augment=crop_augment)
   619   1530.4 MiB      0.0 MiB           1           if results.empty:
   620                                                     warnings.warn("No predictions made, returning None")
   621                                                     return None
   622                                         
   623   1528.7 MiB     -1.7 MiB           1           results = utilities.read_file(results)
   624                                         
   625   1528.7 MiB      0.0 MiB           1           if raster_path is None:
   626                                                     warnings.warn(
   627                                                         "An image was passed directly to predict_tile, the root_dir will be None in the output dataframe, to use visualize.plot_results, please assign results.root_dir = <directory name>"
   628                                                     )
   629                                                 else:
   630   1528.7 MiB      0.0 MiB           1               results.root_dir = os.path.dirname(raster_path)
   631                                         
   632   1528.7 MiB      0.0 MiB           1           return results

Time elapsed: 236.62 seconds not in memory
Predicting DataLoader 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 143/143 [03:54<00:00,  0.61it/s]
1240 predictions in overlapping windows, applying non-max supression
1103 predictions kept after non-max suppression
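The log shows 1240 boxes from overlapping windows reduced to 1103 by non-max suppression while mosaicking. As a simplified illustration only (`predict.mosiac` may rely on a library NMS routine or on the Gaussian Soft-NMS controlled by `sigma` and `thresh`, as the docstring suggests, rather than this loop), greedy hard NMS keeps the highest-scoring box and drops any later box whose IoU with an already-kept box exceeds `iou_threshold`. The names `iou` and `greedy_nms` are hypothetical.

```python
from typing import List, Tuple

Box = Tuple[float, float, float, float]  # (xmin, ymin, xmax, ymax) in pixels


def iou(a: Box, b: Box) -> float:
    """Intersection over union of two axis-aligned boxes."""
    ix = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    iy = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = ix * iy
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def greedy_nms(boxes: List[Box], scores: List[float],
               iou_threshold: float = 0.15) -> List[int]:
    """Return indices of kept boxes, highest score first."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep: List[int] = []
    for i in order:
        # Keep box i only if it does not overlap any already-kept box too much
        if all(iou(boxes[i], boxes[k]) <= iou_threshold for k in keep):
            keep.append(i)
    return keep


# The same tree detected in two neighbouring windows, plus a separate tree
boxes = [(10, 10, 50, 50), (12, 11, 52, 49), (100, 100, 140, 140)]
scores = [0.9, 0.8, 0.7]
print(greedy_nms(boxes, scores))  # [0, 2]
```

A lower `iou_threshold` suppresses more boxes, matching the docstring's note that lower values suppress more boxes at window edges.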

Memory usage with in-memory dataset

Filename: /Users/benweinstein/Documents/DeepForest/deepforest/main.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   466   1528.7 MiB   1528.7 MiB           1       @memory_profiler.profile
   467                                             def predict_tile(self,
   468                                                              raster_path=None,
   469                                                              image=None,
   470                                                              patch_size=400,
   471                                                              patch_overlap=0.05,
   472                                                              iou_threshold=0.15,
   473                                                              in_memory=True,
   474                                                              return_plot=False,
   475                                                              mosaic=True,
   476                                                              sigma=0.5,
   477                                                              thresh=0.001,
   478                                                              color=None,
   479                                                              thickness=1,
   480                                                              crop_model=None,
   481                                                              crop_transform=None,
   482                                                              crop_augment=False):
   483                                                 """For images too large to input into the model, predict_tile cuts the
   484                                                 image into overlapping windows, predicts trees on each window and
   485                                                 reassambles into a single array.
   486                                         
   487                                                 Args:
   488                                                     raster_path: Path to image on disk
   489                                                     image (array): Numpy image array in BGR channel order
   490                                                         following openCV convention
   491                                                     patch_size: patch size for each window.
   492                                                     patch_overlap: patch overlap among windows.
   493                                                     iou_threshold: Minimum iou overlap among predictions between
   494                                                         windows to be suppressed.
   495                                                         Lower values suppress more boxes at edges.
   496                                                     in_memory: If true, the entire dataset is loaded into memory. This is useful for small datasets, but not recommended for large datasets since both the tile and the crops are stored in memory.
   497                                                     mosaic: Return a single prediction dataframe (True) or a tuple of image crops and predictions (False)
   498                                                     sigma: variance of Gaussian function used in Gaussian Soft NMS
   499                                                     thresh: the score thresh used to filter bboxes after soft-nms performed
   500                                                     cropModel: a deepforest.model.CropModel object to predict on crops
   501                                                     crop_transform: a torchvision.transforms object to apply to crops
   502                                                     crop_augment: a boolean to apply augmentations to crops
   503                                                     (deprecated) return_plot: return a plot of the image with predictions overlaid
   504                                                     (deprecated) color: color of the bounding box as a tuple of BGR color, e.g. orange annotations is (0, 165, 255)
   505                                                     (deprecated) thickness: thickness of the rectangle border line in px
   506                                         
   507                                                 Deprecated Args:
   508                                                     - return_plot: Deprecated in favor of using `visualize.plot_results` for
   509                                                       rendering predictions. Will be removed in version 2.0.
   510                                                     - color: Deprecated bounding box color for visualizations.
   511                                                     - thickness: Deprecated bounding box thickness for visualizations.
   512                                         
   513                                                 Returns:
   514                                                     - If `return_plot` is True, returns an image with predictions overlaid (deprecated).
   515                                                     - If `mosaic` is True, returns a Pandas DataFrame containing the predicted
   516                                                       bounding boxes, scores, and labels.
   517                                                     - If `mosaic` is False, returns a list of tuples where each tuple contains
   518                                                       a DataFrame of predictions and its corresponding image crop.
   519                                                     - Returns None if no predictions are made.
   520                                                 """
   521   1528.7 MiB      0.0 MiB           1           self.model.eval()
   522   1528.7 MiB      0.0 MiB           1           self.model.nms_thresh = self.config["nms_thresh"]
   523                                         
   524                                                 # if more than one GPU present, use only a the first available gpu
   525   1528.7 MiB      0.0 MiB           1           if torch.cuda.device_count() > 1:
   526                                                     # Get available gpus and regenerate trainer
   527                                                     warnings.warn(
   528                                                         "More than one GPU detected. Using only the first GPU for predict_tile.")
   529                                                     self.config["devices"] = 1
   530                                                     self.create_trainer()
   531                                         
   532   1528.7 MiB      0.0 MiB           1           if (raster_path is None) and (image is None):
   533                                                     raise ValueError(
   534                                                         "Both tile and tile_path are None. Either supply a path to a tile on disk, or read one into memory!"
   535                                                     )
   536                                         
   537                                         
   538   1528.7 MiB      0.0 MiB           1           if in_memory:
   539   1528.7 MiB      0.0 MiB           1               if raster_path is None:
   540                                                         image = image
   541                                                     else:
   542   2068.1 MiB    539.4 MiB           1                   image = rio.open(raster_path).read()
   543   2068.1 MiB      0.0 MiB           1                   image = np.moveaxis(image, 0, 2)
   544                                         
   545   2068.1 MiB      0.0 MiB           2               ds = dataset.TileDataset(tile=image,
   546   2068.1 MiB      0.0 MiB           1                                    patch_overlap=patch_overlap,
   547   2068.1 MiB      0.0 MiB           1                                    patch_size=patch_size)
   548                                                 else:
   549                                                     if raster_path is None:
   550                                                         raise ValueError("raster_path is required if in_memory is False")
   551                                                     
   552                                                     # Check for workers config when using out of memory dataset
   553                                                     if self.config["workers"] > 0:
   554                                                         raise ValueError("workers must be 0 when using out-of-memory dataset (in_memory=False). Set config['workers']=0 and recreate trainer self.create_trainer().")
   555                                                     
   556                                                     ds = dataset.RasterDataset(raster_path=raster_path,
   557                                                                                 patch_overlap=patch_overlap,
   558                                                                                 patch_size=patch_size)
   559                                         
   560   1765.1 MiB   -303.0 MiB           1           batched_results = self.trainer.predict(self, self.predict_dataloader(ds))
   561                                         
   562                                                 # Flatten list from batched prediction
   563   1765.1 MiB      0.0 MiB           1           results = []
   564   1765.1 MiB      0.0 MiB         144           for batch in batched_results:
   565   1765.1 MiB      0.0 MiB         286               for boxes in batch:
   566   1765.1 MiB      0.0 MiB         143                   results.append(boxes)
   567                                         
   568   1765.1 MiB      0.0 MiB           1           if mosaic:
   569   1765.4 MiB      0.3 MiB           2               results = predict.mosiac(results,
   570   1765.1 MiB      0.0 MiB           1                                        ds.windows,
   571   1765.1 MiB      0.0 MiB           1                                        sigma=sigma,
   572   1765.1 MiB      0.0 MiB           1                                        thresh=thresh,
   573   1765.1 MiB      0.0 MiB           1                                        iou_threshold=iou_threshold)
   574   1765.4 MiB      0.0 MiB           2               results["label"] = results.label.apply(
   575   1765.4 MiB      0.0 MiB        2207                   lambda x: self.numeric_to_label_dict[x])
   576   1765.4 MiB      0.0 MiB           1               if raster_path:
   577   1765.4 MiB      0.0 MiB           1                   results["image_path"] = os.path.basename(raster_path)
   578   1765.4 MiB      0.0 MiB           1               if return_plot:
   579                                                         # Add deprecated warning
   580                                                         warnings.warn("return_plot is deprecated and will be removed in 2.0. "
   581                                                                       "Use visualize.plot_results on the result instead.")
   582                                                         # Draw predictions on BGR
   583                                                         if raster_path:
   584                                                             tile = rio.open(raster_path).read()
   585                                                         else:
   586                                                             tile = image
   587                                                         drawn_plot = tile[:, :, ::-1]
   588                                                         drawn_plot = visualize.plot_predictions(tile,
   589                                                                                                 results,
   590                                                                                                 color=color,
   591                                                                                                 thickness=thickness)
   592                                                         return drawn_plot
   593                                                 else:
   594                                                     for df in results:
   595                                                         df["label"] = df.label.apply(lambda x: self.numeric_to_label_dict[x])
   596                                         
   597                                                     # TODO this is the 2nd time the crops are generated? Could be more efficient, but memory intensive
   598                                                     self.crops = []
   599                                                     if raster_path is None:
   600                                                         image = image
   601                                                     else:
   602                                                         image = rio.open(raster_path).read()
   603                                                         image = np.moveaxis(image, 0, 2)
   604                                         
   605                                                     for window in ds.windows:
   606                                                         crop = image[window.indices()]
   607                                                         self.crops.append(crop)
   608                                         
   609                                                     return list(zip(results, self.crops))
   610                                         
   611   1765.4 MiB      0.0 MiB           1           if crop_model:
   612                                                     # If a crop model is provided, predict on each crop
   613                                                     results = predict._predict_crop_model_(crop_model=crop_model,
   614                                                                                            results=results,
   615                                                                                            raster_path=raster_path,
   616                                                                                            trainer=self.trainer,
   617                                                                                            transform=crop_transform,
   618                                                                                            augment=crop_augment)
   619   1765.4 MiB      0.0 MiB           1           if results.empty:
   620                                                     warnings.warn("No predictions made, returning None")
   621                                                     return None
   622                                         
   623   1763.6 MiB     -1.9 MiB           1           results = utilities.read_file(results)
   624                                         
   625   1763.6 MiB      0.0 MiB           1           if raster_path is None:
   626                                                     warnings.warn(
   627                                                         "An image was passed directly to predict_tile, the root_dir will be None in the output dataframe, to use visualize.plot_results, please assign results.root_dir = <directory name>"
   628                                                     )
   629                                                 else:
   630   1763.6 MiB      0.0 MiB           1               results.root_dir = os.path.dirname(raster_path)
   631                                         
   632   1763.6 MiB      0.0 MiB           1           return results

Time elapsed: 251.71 seconds in memory

These results show that the in-memory dataset releases memory during trainer.predict, whereas the out-of-memory dataset accumulates a large amount of memory.

bw4sz commented 2 weeks ago
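Success! Closing the rasterio dataset on each step fixes the memory growth and only slows prediction down by about 4% (213.80 s vs. 205.27 s).

| In-memory | Total memory | Time (s) |
|-----------|--------------|----------|
| Yes       | 1.39 GB      | 205.27   |
| No        | 118 MB       | 213.80   |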

There are two reasons not to use only the out-of-memory dataset:

1. It doesn't allow parallelization, because rasterio is not thread-safe.
2. It only works well on tiled objects. For untiled objects such as .JPG files it is slower, since it loads the entire image every time.