weecology / DeepForest

Python Package for Airborne RGB machine learning
https://deepforest.readthedocs.io/
MIT License

predict_tile takes more memory than needed. #830

Closed: bw4sz closed this issue 5 days ago

bw4sz commented 3 weeks ago

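Current memory usage of predict_tile, profiled by adding import memory_profiler to deepforest/main.py, decorating predict_tile with @memory_profiler.profile, and running:

    from deepforest import main
    from deepforest import get_data

    m = main.deepforest()
    m.predict_tile(raster_path=get_data("OSBS_029.png"))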
Filename: /Users/benweinstein/Documents/DeepForest/deepforest/main.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   465    429.6 MiB    429.6 MiB           1       @memory_profiler.profile
   466                                             def predict_tile(self,
   467                                                              raster_path=None,
   468                                                              image=None,
   469                                                              patch_size=400,
   470                                                              patch_overlap=0.05,
   471                                                              iou_threshold=0.15,
   472                                                              return_plot=False,
   473                                                              mosaic=True,
   474                                                              sigma=0.5,
   475                                                              thresh=0.001,
   476                                                              color=None,
   477                                                              thickness=1,
   478                                                              crop_model=None,
   479                                                              crop_transform=None,
   480                                                              crop_augment=False):
   481                                                 """For images too large to input into the model, predict_tile cuts the
   482                                                 image into overlapping windows, predicts trees on each window and
   483                                                 reassambles into a single array.
   484                                         
   485                                                 Args:
   486                                                     raster_path: Path to image on disk
   487                                                     image (array): Numpy image array in BGR channel order
   488                                                         following openCV convention
   489                                                     patch_size: patch size for each window.
   490                                                     patch_overlap: patch overlap among windows.
   491                                                     iou_threshold: Minimum iou overlap among predictions between
   492                                                         windows to be suppressed.
   493                                                         Lower values suppress more boxes at edges.
   494                                                     mosaic: Return a single prediction dataframe (True) or a tuple of image crops and predictions (False)
   495                                                     sigma: variance of Gaussian function used in Gaussian Soft NMS
   496                                                     thresh: the score thresh used to filter bboxes after soft-nms performed
   497                                                     cropModel: a deepforest.model.CropModel object to predict on crops
   498                                                     crop_transform: a torchvision.transforms object to apply to crops
   499                                                     crop_augment: a boolean to apply augmentations to crops
   500                                                     (deprecated) return_plot: return a plot of the image with predictions overlaid
   501                                                     (deprecated) color: color of the bounding box as a tuple of BGR color, e.g. orange annotations is (0, 165, 255)
   502                                                     (deprecated) thickness: thickness of the rectangle border line in px
   503                                         
   504                                                 Deprecated Args:
   505                                                     - return_plot: Deprecated in favor of using `visualize.plot_results` for
   506                                                       rendering predictions. Will be removed in version 2.0.
   507                                                     - color: Deprecated bounding box color for visualizations.
   508                                                     - thickness: Deprecated bounding box thickness for visualizations.
   509                                         
   510                                                 Returns:
   511                                                     - If `return_plot` is True, returns an image with predictions overlaid (deprecated).
   512                                                     - If `mosaic` is True, returns a Pandas DataFrame containing the predicted
   513                                                       bounding boxes, scores, and labels.
   514                                                     - If `mosaic` is False, returns a list of tuples where each tuple contains
   515                                                       a DataFrame of predictions and its corresponding image crop.
   516                                                     - Returns None if no predictions are made.
   517                                                 """
   518    429.7 MiB      0.1 MiB           1           self.model.eval()
   519    429.7 MiB      0.0 MiB           1           self.model.nms_thresh = self.config["nms_thresh"]
   520                                         
   521                                                 # if more than one GPU present, use only a the first available gpu
   522    429.7 MiB      0.0 MiB           1           if torch.cuda.device_count() > 1:
   523                                                     # Get available gpus and regenerate trainer
   524                                                     warnings.warn(
   525                                                         "More than one GPU detected. Using only the first GPU for predict_tile.")
   526                                                     self.config["devices"] = 1
   527                                                     self.create_trainer()
   528                                         
   529    429.7 MiB      0.0 MiB           1           if (raster_path is None) and (image is None):
   530                                                     raise ValueError(
   531                                                         "Both tile and tile_path are None. Either supply a path to a tile on disk, or read one into memory!"
   532                                                     )
   533                                         
   534    429.7 MiB      0.0 MiB           1           if raster_path is None:
   535                                                     self.image = image
   536                                                 else:
   537    431.3 MiB      1.6 MiB           1               self.image = rio.open(raster_path).read()
   538    431.4 MiB      0.1 MiB           1               self.image = np.moveaxis(self.image, 0, 2)
   539                                         
   540    431.5 MiB      0.1 MiB           2           ds = dataset.TileDataset(tile=self.image,
   541    431.4 MiB      0.0 MiB           1                                    patch_overlap=patch_overlap,
   542    431.4 MiB      0.0 MiB           1                                    patch_size=patch_size)
   543    418.9 MiB    -12.6 MiB           1           batched_results = self.trainer.predict(self, self.predict_dataloader(ds))
   544                                         
   545                                                 # Flatten list from batched prediction
   546    418.9 MiB      0.0 MiB           1           results = []
   547    418.9 MiB      0.0 MiB           2           for batch in batched_results:
   548    418.9 MiB      0.0 MiB           2               for boxes in batch:
   549    418.9 MiB      0.0 MiB           1                   results.append(boxes)
   550                                         
   551    418.9 MiB      0.0 MiB           1           if mosaic:
   552    421.0 MiB      2.1 MiB           2               results = predict.mosiac(results,
   553    418.9 MiB      0.0 MiB           1                                        ds.windows,
   554    418.9 MiB      0.0 MiB           1                                        sigma=sigma,
   555    418.9 MiB      0.0 MiB           1                                        thresh=thresh,
   556    418.9 MiB      0.0 MiB           1                                        iou_threshold=iou_threshold)
   557    421.1 MiB      0.1 MiB           2               results["label"] = results.label.apply(
   558    421.0 MiB      0.0 MiB           1                   lambda x: self.numeric_to_label_dict[x])
   559    421.1 MiB      0.0 MiB           1               if raster_path:
   560    421.1 MiB      0.0 MiB           1                   results["image_path"] = os.path.basename(raster_path)
   561    421.1 MiB      0.0 MiB           1               if return_plot:
   562                                                         # Add deprecated warning
   563                                                         warnings.warn("return_plot is deprecated and will be removed in 2.0. "
   564                                                                       "Use visualize.plot_results on the result instead.")
   565                                                         # Draw predictions on BGR
   566                                                         if raster_path:
   567                                                             tile = rio.open(raster_path).read()
   568                                                         else:
   569                                                             tile = self.image
   570                                                         drawn_plot = tile[:, :, ::-1]
   571                                                         drawn_plot = visualize.plot_predictions(tile,
   572                                                                                                 results,
   573                                                                                                 color=color,
   574                                                                                                 thickness=thickness)
   575                                                         return drawn_plot
   576                                                 else:
   577                                                     for df in results:
   578                                                         df["label"] = df.label.apply(lambda x: self.numeric_to_label_dict[x])
   579                                         
   580                                                     # TODO this is the 2nd time the crops are generated? Could be more efficient.
   581                                                     self.crops = []
   582                                                     for window in ds.windows:
   583                                                         crop = self.image[window.indices()]
   584                                                         self.crops.append(crop)
   585                                         
   586                                                     return list(zip(results, self.crops))
   587                                         
   588    421.1 MiB      0.0 MiB           1           if crop_model:
   589                                                     # If a crop model is provided, predict on each crop
   590                                                     results = predict._predict_crop_model_(crop_model=crop_model,
   591                                                                                            results=results,
   592                                                                                            raster_path=raster_path,
   593                                                                                            trainer=self.trainer,
   594                                                                                            transform=crop_transform,
   595                                                                                            augment=crop_augment)
   596    421.1 MiB      0.0 MiB           1           if results.empty:
   597    421.1 MiB      0.0 MiB           1               warnings.warn("No predictions made, returning None")
   598    421.1 MiB      0.0 MiB           1               return None
   599                                         
   600                                                 results = utilities.read_file(results)
   601                                         
   602                                                 if raster_path is None:
   603                                                     warnings.warn(
   604                                                         "An image was passed directly to predict_tile, the root_dir will be None in the output dataframe, to use visualize.plot_results, please assign results.root_dir = <directory name>"
   605                                                     )
   606                                                 else:
   607                                                     results.root_dir = os.path.dirname(raster_path)
   608                                         
   609                                                 return results
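memory_profiler reports process-level usage line by line, which needs the decorator edited into main.py. To compare candidate changes from a plain script, a minimal stdlib sketch with tracemalloc can report the peak bytes allocated while one call runs. The helper name peak_allocation_mib is hypothetical, not part of DeepForest. Assumption worth stating loudly: tracemalloc only sees allocations that go through Python's tracing hooks. NumPy array buffers are traced, but memory allocated by native libraries outside those hooks (typically PyTorch tensors and GDAL/rasterio buffers) is not, so this complements the profile above rather than replacing it. The example also shows why line 538 (np.moveaxis) costs only about 0.1 MiB: moveaxis returns a view, while a forced contiguous copy of a reversed-channel tile (as tile[:, :, ::-1] at line 570 would need before any in-place drawing) costs a full tile.

```python
import tracemalloc

import numpy as np


def peak_allocation_mib(func, *args, **kwargs):
    """Hypothetical helper: call func and return (result, peak MiB traced meanwhile)."""
    tracemalloc.start()
    try:
        result = func(*args, **kwargs)
        # get_traced_memory() returns (current, peak) in bytes since start()
        _, peak = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()  # also discards the collected traces
    return result, peak / 2**20


def copy_tile(tile):
    # Reversing the channel axis is a view; ascontiguousarray forces a full copy
    return np.ascontiguousarray(tile[:, :, ::-1])


tile = np.zeros((2000, 2000, 3), dtype=np.uint8)  # 12,000,000 bytes, about 11.4 MiB

_, peak_copy = peak_allocation_mib(copy_tile, tile)          # roughly one tile's worth
_, peak_view = peak_allocation_mib(np.moveaxis, tile, 2, 0)  # a view: almost nothing
print(f"copy: {peak_copy:.1f} MiB, view: {peak_view:.3f} MiB")
```

The tile is allocated before tracing starts, so only memory created inside each call is counted.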

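The TODO at line 580 of the profiled source notes that, when mosaic=False, the crops are generated a second time after prediction. The sketch below shows one way to produce crops lazily as NumPy views so a list of them does not duplicate pixel data. Assumptions: sliding_windows and iter_crops are hypothetical helpers, not DeepForest functions, and the window layout (a stride of patch_size * (1 - patch_overlap), with the last window shifted flush against the image edge) is an illustrative choice that may differ from DeepForest's own window computation. Indexing with a tuple of slices, which is what self.image[window.indices()] does if indices() returns slices, is basic slicing and likewise returns a view.

```python
from typing import Iterator, List, Tuple

import numpy as np

Window = Tuple[slice, slice]  # (row_slice, col_slice)


def _starts(length: int, patch_size: int, stride: int) -> List[int]:
    """Start offsets along one axis; the last patch is shifted to end at the edge."""
    if length <= patch_size:
        return [0]
    starts = list(range(0, length - patch_size + 1, stride))
    if starts[-1] + patch_size < length:
        starts.append(length - patch_size)
    return starts


def sliding_windows(height: int, width: int, patch_size: int = 400,
                    patch_overlap: float = 0.05) -> List[Window]:
    """Hypothetical helper: overlapping windows that cover the whole tile."""
    stride = max(1, int(round(patch_size * (1 - patch_overlap))))
    return [(slice(r, r + patch_size), slice(c, c + patch_size))
            for r in _starts(height, patch_size, stride)
            for c in _starts(width, patch_size, stride)]


def iter_crops(image: np.ndarray, windows: List[Window]) -> Iterator[np.ndarray]:
    """Yield each crop as a view into image; basic slicing never copies pixels."""
    for rows, cols in windows:
        yield image[rows, cols]


tile = np.zeros((1000, 1200, 3), dtype=np.uint8)
windows = sliding_windows(*tile.shape[:2], patch_size=400, patch_overlap=0.05)
crops = list(iter_crops(tile, windows))
print(len(windows), crops[0].shape, np.shares_memory(crops[0], tile))
```

Because each crop is a view, writing into a crop also modifies the tile; call .copy() on a crop only when it must diverge from, or outlive, the full image.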