Closed bw4sz closed 1 week ago
This PR closes #830
I've written an out-of-memory dataset class.
class RasterDataset:
    """Dataset that serves prediction windows read on demand from a raster.

    The raster is opened with rasterio when the dataset is created and stays
    open until ``close()`` is called (or the ``with`` block exits). Each item
    is fetched with a windowed read instead of loading the whole array up
    front.

    Args:
        raster_path (str): Path to raster file
        patch_size (int): Maximum size of the windows to predict on
        patch_overlap (float): Overlap between windows as fraction (0-1)
    """

    def __init__(self, raster_path, patch_size, patch_overlap):
        self.raster = rio.open(raster_path)
        self.patch_size = patch_size
        self.patch_overlap = patch_overlap

        # rasterio reports the dataset shape as (rows, cols) == (height, width).
        raster_shape = self.raster.shape
        height = raster_shape[0]
        width = raster_shape[1]

        try:
            # Generate sliding windows over the raster. generateForSize takes
            # the width first and the height second.
            self.windows = slidingwindow.generateForSize(
                width,
                height,
                dimOrder=slidingwindow.DimOrder.ChannelHeightWidth,
                maxWindowSize=patch_size,
                overlapPercent=patch_overlap,
            )
        except Exception:
            # Do not leave the file handle open if the dataset cannot be built.
            self.raster.close()
            raise

        self.n_windows = len(self.windows)

    def __len__(self):
        """Return the number of windows generated for the raster."""
        return self.n_windows

    def __getitem__(self, idx):
        """Get a window of the raster

        Args:
            idx (int): Index of window to get

        Returns:
            crop (torch.Tensor): The preprocessed window, as returned by
                ``preprocess.preprocess_image``; expected to have shape
                (3, height, width) for a 3-band raster.
        """
        window = self.windows[idx]

        # Read only this window; rasterio returns it as (bands, rows, cols).
        window_data = self.raster.read(
            window=Window(window.x, window.y, window.w, window.h)
        )

        # Move the band axis last -> (rows, cols, bands) before preprocessing.
        window_data = np.moveaxis(window_data, 0, 2)
        crop = preprocess.preprocess_image(window_data)

        return crop

    def close(self):
        """Close the raster dataset"""
        self.raster.close()

    def __enter__(self):
        """Return the dataset itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the raster when leaving the ``with`` block."""
        self.close()
import time

from deepforest import main
from deepforest import get_data
from deepforest.dataset import RasterDataset

# Benchmark: run predict_tile on the same image with the out-of-memory
# dataset (in_memory=False) and the in-memory dataset (in_memory=True),
# then check that both return the same predictions.

# Resolve the image path once so both runs use exactly the same file.
raster_path = get_data("/Users/benweinstein/Downloads/DJI_20211215102728_0027.jpeg")

m = main.deepforest()
m.load_model("Weecology/deepforest-tree")

# predict_tile raises if workers > 0 when in_memory=False, so set workers to 0
# and rebuild the trainer before either run.
m.config["workers"] = 0
m.create_trainer()

# Start the clock after trainer setup so that setup cost is not charged to
# only one of the two runs being compared. perf_counter is monotonic.
start_time = time.perf_counter()
result = m.predict_tile(raster_path=raster_path, patch_size=2000, patch_overlap=0, in_memory=False)
elapsed_time = time.perf_counter() - start_time
print(f"Time elapsed: {elapsed_time:.2f} seconds not in memory")

start_time = time.perf_counter()
result_in_memory = m.predict_tile(raster_path=raster_path, patch_size=2000, patch_overlap=0, in_memory=True)
elapsed_time = time.perf_counter() - start_time
print(f"Time elapsed: {elapsed_time:.2f} seconds in memory")

# assert that the results are the same
assert result.equals(result_in_memory)
~4. Confirm out-of-memory takes less memory!~
For a proper .tif with block sizes, we see a memory leak within trainer.predict!
Filename: /Users/benweinstein/Documents/DeepForest/deepforest/main.py
Line # Mem usage Increment Occurrences Line Contents
=============================================================
466 643.5 MiB 643.5 MiB 1 @memory_profiler.profile
467 def predict_tile(self,
468 raster_path=None,
469 image=None,
470 patch_size=400,
471 patch_overlap=0.05,
472 iou_threshold=0.15,
473 in_memory=True,
474 return_plot=False,
475 mosaic=True,
476 sigma=0.5,
477 thresh=0.001,
478 color=None,
479 thickness=1,
480 crop_model=None,
481 crop_transform=None,
482 crop_augment=False):
483 """For images too large to input into the model, predict_tile cuts the
484 image into overlapping windows, predicts trees on each window and
485 reassambles into a single array.
486
487 Args:
488 raster_path: Path to image on disk
489 image (array): Numpy image array in BGR channel order
490 following openCV convention
491 patch_size: patch size for each window.
492 patch_overlap: patch overlap among windows.
493 iou_threshold: Minimum iou overlap among predictions between
494 windows to be suppressed.
495 Lower values suppress more boxes at edges.
496 in_memory: If true, the entire dataset is loaded into memory. This is useful for small datasets, but not recommended for large datasets since both the tile and the crops are stored in memory.
497 mosaic: Return a single prediction dataframe (True) or a tuple of image crops and predictions (False)
498 sigma: variance of Gaussian function used in Gaussian Soft NMS
499 thresh: the score thresh used to filter bboxes after soft-nms performed
500 cropModel: a deepforest.model.CropModel object to predict on crops
501 crop_transform: a torchvision.transforms object to apply to crops
502 crop_augment: a boolean to apply augmentations to crops
503 (deprecated) return_plot: return a plot of the image with predictions overlaid
504 (deprecated) color: color of the bounding box as a tuple of BGR color, e.g. orange annotations is (0, 165, 255)
505 (deprecated) thickness: thickness of the rectangle border line in px
506
507 Deprecated Args:
508 - return_plot: Deprecated in favor of using `visualize.plot_results` for
509 rendering predictions. Will be removed in version 2.0.
510 - color: Deprecated bounding box color for visualizations.
511 - thickness: Deprecated bounding box thickness for visualizations.
512
513 Returns:
514 - If `return_plot` is True, returns an image with predictions overlaid (deprecated).
515 - If `mosaic` is True, returns a Pandas DataFrame containing the predicted
516 bounding boxes, scores, and labels.
517 - If `mosaic` is False, returns a list of tuples where each tuple contains
518 a DataFrame of predictions and its corresponding image crop.
519 - Returns None if no predictions are made.
520 """
521 643.5 MiB 0.0 MiB 1 self.model.eval()
522 643.5 MiB 0.0 MiB 1 self.model.nms_thresh = self.config["nms_thresh"]
523
524 # if more than one GPU present, use only a the first available gpu
525 643.5 MiB 0.0 MiB 1 if torch.cuda.device_count() > 1:
526 # Get available gpus and regenerate trainer
527 warnings.warn(
528 "More than one GPU detected. Using only the first GPU for predict_tile.")
529 self.config["devices"] = 1
530 self.create_trainer()
531
532 643.5 MiB 0.0 MiB 1 if (raster_path is None) and (image is None):
533 raise ValueError(
534 "Both tile and tile_path are None. Either supply a path to a tile on disk, or read one into memory!"
535 )
536
537
538 643.5 MiB 0.0 MiB 1 if in_memory:
539 if raster_path is None:
540 image = image
541 else:
542 image = rio.open(raster_path).read()
543 image = np.moveaxis(image, 0, 2)
544
545 ds = dataset.TileDataset(tile=image,
546 patch_overlap=patch_overlap,
547 patch_size=patch_size)
548 else:
549 643.5 MiB 0.0 MiB 1 if raster_path is None:
550 raise ValueError("raster_path is required if in_memory is False")
551
552 # Check for workers config when using out of memory dataset
553 643.5 MiB 0.0 MiB 1 if self.config["workers"] > 0:
554 raise ValueError("workers must be 0 when using out-of-memory dataset (in_memory=False). Set config['workers']=0 and recreate trainer self.create_trainer().")
555
556 650.3 MiB 6.8 MiB 2 ds = dataset.RasterDataset(raster_path=raster_path,
557 643.5 MiB 0.0 MiB 1 patch_overlap=patch_overlap,
558 643.5 MiB 0.0 MiB 1 patch_size=patch_size)
559
560 1529.3 MiB 879.0 MiB 1 batched_results = self.trainer.predict(self, self.predict_dataloader(ds))
561
562 # Flatten list from batched prediction
563 1529.3 MiB 0.0 MiB 1 results = []
564 1529.3 MiB 0.0 MiB 144 for batch in batched_results:
565 1529.3 MiB 0.0 MiB 286 for boxes in batch:
566 1529.3 MiB 0.0 MiB 143 results.append(boxes)
567
568 1529.3 MiB 0.0 MiB 1 if mosaic:
569 1530.4 MiB 1.0 MiB 2 results = predict.mosiac(results,
570 1529.3 MiB 0.0 MiB 1 ds.windows,
571 1529.3 MiB 0.0 MiB 1 sigma=sigma,
572 1529.3 MiB 0.0 MiB 1 thresh=thresh,
573 1529.3 MiB 0.0 MiB 1 iou_threshold=iou_threshold)
574 1530.4 MiB 0.0 MiB 2 results["label"] = results.label.apply(
575 1530.4 MiB 0.0 MiB 2207 lambda x: self.numeric_to_label_dict[x])
576 1530.4 MiB 0.0 MiB 1 if raster_path:
577 1530.4 MiB 0.0 MiB 1 results["image_path"] = os.path.basename(raster_path)
578 1530.4 MiB 0.0 MiB 1 if return_plot:
579 # Add deprecated warning
580 warnings.warn("return_plot is deprecated and will be removed in 2.0. "
581 "Use visualize.plot_results on the result instead.")
582 # Draw predictions on BGR
583 if raster_path:
584 tile = rio.open(raster_path).read()
585 else:
586 tile = image
587 drawn_plot = tile[:, :, ::-1]
588 drawn_plot = visualize.plot_predictions(tile,
589 results,
590 color=color,
591 thickness=thickness)
592 return drawn_plot
593 else:
594 for df in results:
595 df["label"] = df.label.apply(lambda x: self.numeric_to_label_dict[x])
596
597 # TODO this is the 2nd time the crops are generated? Could be more efficient, but memory intensive
598 self.crops = []
599 if raster_path is None:
600 image = image
601 else:
602 image = rio.open(raster_path).read()
603 image = np.moveaxis(image, 0, 2)
604
605 for window in ds.windows:
606 crop = image[window.indices()]
607 self.crops.append(crop)
608
609 return list(zip(results, self.crops))
610
611 1530.4 MiB 0.0 MiB 1 if crop_model:
612 # If a crop model is provided, predict on each crop
613 results = predict._predict_crop_model_(crop_model=crop_model,
614 results=results,
615 raster_path=raster_path,
616 trainer=self.trainer,
617 transform=crop_transform,
618 augment=crop_augment)
619 1530.4 MiB 0.0 MiB 1 if results.empty:
620 warnings.warn("No predictions made, returning None")
621 return None
622
623 1528.7 MiB -1.7 MiB 1 results = utilities.read_file(results)
624
625 1528.7 MiB 0.0 MiB 1 if raster_path is None:
626 warnings.warn(
627 "An image was passed directly to predict_tile, the root_dir will be None in the output dataframe, to use visualize.plot_results, please assign results.root_dir = <directory name>"
628 )
629 else:
630 1528.7 MiB 0.0 MiB 1 results.root_dir = os.path.dirname(raster_path)
631
632 1528.7 MiB 0.0 MiB 1 return results
Time elapsed: 236.62 seconds not in memory
Predicting DataLoader 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 143/143 [03:54<00:00, 0.61it/s]
1240 predictions in overlapping windows, applying non-max supression
1103 predictions kept after non-max suppression
Filename: /Users/benweinstein/Documents/DeepForest/deepforest/main.py
Line # Mem usage Increment Occurrences Line Contents
=============================================================
466 1528.7 MiB 1528.7 MiB 1 @memory_profiler.profile
467 def predict_tile(self,
468 raster_path=None,
469 image=None,
470 patch_size=400,
471 patch_overlap=0.05,
472 iou_threshold=0.15,
473 in_memory=True,
474 return_plot=False,
475 mosaic=True,
476 sigma=0.5,
477 thresh=0.001,
478 color=None,
479 thickness=1,
480 crop_model=None,
481 crop_transform=None,
482 crop_augment=False):
483 """For images too large to input into the model, predict_tile cuts the
484 image into overlapping windows, predicts trees on each window and
485 reassambles into a single array.
486
487 Args:
488 raster_path: Path to image on disk
489 image (array): Numpy image array in BGR channel order
490 following openCV convention
491 patch_size: patch size for each window.
492 patch_overlap: patch overlap among windows.
493 iou_threshold: Minimum iou overlap among predictions between
494 windows to be suppressed.
495 Lower values suppress more boxes at edges.
496 in_memory: If true, the entire dataset is loaded into memory. This is useful for small datasets, but not recommended for large datasets since both the tile and the crops are stored in memory.
497 mosaic: Return a single prediction dataframe (True) or a tuple of image crops and predictions (False)
498 sigma: variance of Gaussian function used in Gaussian Soft NMS
499 thresh: the score thresh used to filter bboxes after soft-nms performed
500 cropModel: a deepforest.model.CropModel object to predict on crops
501 crop_transform: a torchvision.transforms object to apply to crops
502 crop_augment: a boolean to apply augmentations to crops
503 (deprecated) return_plot: return a plot of the image with predictions overlaid
504 (deprecated) color: color of the bounding box as a tuple of BGR color, e.g. orange annotations is (0, 165, 255)
505 (deprecated) thickness: thickness of the rectangle border line in px
506
507 Deprecated Args:
508 - return_plot: Deprecated in favor of using `visualize.plot_results` for
509 rendering predictions. Will be removed in version 2.0.
510 - color: Deprecated bounding box color for visualizations.
511 - thickness: Deprecated bounding box thickness for visualizations.
512
513 Returns:
514 - If `return_plot` is True, returns an image with predictions overlaid (deprecated).
515 - If `mosaic` is True, returns a Pandas DataFrame containing the predicted
516 bounding boxes, scores, and labels.
517 - If `mosaic` is False, returns a list of tuples where each tuple contains
518 a DataFrame of predictions and its corresponding image crop.
519 - Returns None if no predictions are made.
520 """
521 1528.7 MiB 0.0 MiB 1 self.model.eval()
522 1528.7 MiB 0.0 MiB 1 self.model.nms_thresh = self.config["nms_thresh"]
523
524 # if more than one GPU present, use only a the first available gpu
525 1528.7 MiB 0.0 MiB 1 if torch.cuda.device_count() > 1:
526 # Get available gpus and regenerate trainer
527 warnings.warn(
528 "More than one GPU detected. Using only the first GPU for predict_tile.")
529 self.config["devices"] = 1
530 self.create_trainer()
531
532 1528.7 MiB 0.0 MiB 1 if (raster_path is None) and (image is None):
533 raise ValueError(
534 "Both tile and tile_path are None. Either supply a path to a tile on disk, or read one into memory!"
535 )
536
537
538 1528.7 MiB 0.0 MiB 1 if in_memory:
539 1528.7 MiB 0.0 MiB 1 if raster_path is None:
540 image = image
541 else:
542 2068.1 MiB 539.4 MiB 1 image = rio.open(raster_path).read()
543 2068.1 MiB 0.0 MiB 1 image = np.moveaxis(image, 0, 2)
544
545 2068.1 MiB 0.0 MiB 2 ds = dataset.TileDataset(tile=image,
546 2068.1 MiB 0.0 MiB 1 patch_overlap=patch_overlap,
547 2068.1 MiB 0.0 MiB 1 patch_size=patch_size)
548 else:
549 if raster_path is None:
550 raise ValueError("raster_path is required if in_memory is False")
551
552 # Check for workers config when using out of memory dataset
553 if self.config["workers"] > 0:
554 raise ValueError("workers must be 0 when using out-of-memory dataset (in_memory=False). Set config['workers']=0 and recreate trainer self.create_trainer().")
555
556 ds = dataset.RasterDataset(raster_path=raster_path,
557 patch_overlap=patch_overlap,
558 patch_size=patch_size)
559
560 1765.1 MiB -303.0 MiB 1 batched_results = self.trainer.predict(self, self.predict_dataloader(ds))
561
562 # Flatten list from batched prediction
563 1765.1 MiB 0.0 MiB 1 results = []
564 1765.1 MiB 0.0 MiB 144 for batch in batched_results:
565 1765.1 MiB 0.0 MiB 286 for boxes in batch:
566 1765.1 MiB 0.0 MiB 143 results.append(boxes)
567
568 1765.1 MiB 0.0 MiB 1 if mosaic:
569 1765.4 MiB 0.3 MiB 2 results = predict.mosiac(results,
570 1765.1 MiB 0.0 MiB 1 ds.windows,
571 1765.1 MiB 0.0 MiB 1 sigma=sigma,
572 1765.1 MiB 0.0 MiB 1 thresh=thresh,
573 1765.1 MiB 0.0 MiB 1 iou_threshold=iou_threshold)
574 1765.4 MiB 0.0 MiB 2 results["label"] = results.label.apply(
575 1765.4 MiB 0.0 MiB 2207 lambda x: self.numeric_to_label_dict[x])
576 1765.4 MiB 0.0 MiB 1 if raster_path:
577 1765.4 MiB 0.0 MiB 1 results["image_path"] = os.path.basename(raster_path)
578 1765.4 MiB 0.0 MiB 1 if return_plot:
579 # Add deprecated warning
580 warnings.warn("return_plot is deprecated and will be removed in 2.0. "
581 "Use visualize.plot_results on the result instead.")
582 # Draw predictions on BGR
583 if raster_path:
584 tile = rio.open(raster_path).read()
585 else:
586 tile = image
587 drawn_plot = tile[:, :, ::-1]
588 drawn_plot = visualize.plot_predictions(tile,
589 results,
590 color=color,
591 thickness=thickness)
592 return drawn_plot
593 else:
594 for df in results:
595 df["label"] = df.label.apply(lambda x: self.numeric_to_label_dict[x])
596
597 # TODO this is the 2nd time the crops are generated? Could be more efficient, but memory intensive
598 self.crops = []
599 if raster_path is None:
600 image = image
601 else:
602 image = rio.open(raster_path).read()
603 image = np.moveaxis(image, 0, 2)
604
605 for window in ds.windows:
606 crop = image[window.indices()]
607 self.crops.append(crop)
608
609 return list(zip(results, self.crops))
610
611 1765.4 MiB 0.0 MiB 1 if crop_model:
612 # If a crop model is provided, predict on each crop
613 results = predict._predict_crop_model_(crop_model=crop_model,
614 results=results,
615 raster_path=raster_path,
616 trainer=self.trainer,
617 transform=crop_transform,
618 augment=crop_augment)
619 1765.4 MiB 0.0 MiB 1 if results.empty:
620 warnings.warn("No predictions made, returning None")
621 return None
622
623 1763.6 MiB -1.9 MiB 1 results = utilities.read_file(results)
624
625 1763.6 MiB 0.0 MiB 1 if raster_path is None:
626 warnings.warn(
627 "An image was passed directly to predict_tile, the root_dir will be None in the output dataframe, to use visualize.plot_results, please assign results.root_dir = <directory name>"
628 )
629 else:
630 1763.6 MiB 0.0 MiB 1 results.root_dir = os.path.dirname(raster_path)
631
632 1763.6 MiB 0.0 MiB 1 return results
Time elapsed: 251.71 seconds in memory
Backend MacOSX is interactive backend. Turning interactive mode on.
These results show that the in-memory dataset releases memory during trainer.predict, whereas the out-of-memory dataset gains an enormous amount of memory.
Success! Closing the rasterio dataset on each step helps and only slows it down a bit.

| in-memory | total-memory | time |
|---|---|---|
| Yes | 1.39 GB | 205.27 |
| No | 118 MB | 213.80 |
The reasons not to make out-of-memory the only option are: 1) it doesn't allow parallelization, because rasterio is not thread-safe; 2) it only works well on tiled rasters. For untiled .JPG files it's slower, since it loads the entire image every time.
In response to #830 #829, I think it was time to make an out-of-memory prediction tool for when rasters are too large to read. This is for when the raster itself is too large to read into memory and
would yield a memory error. It doesn't (yet) address any memory load within the existing TileDataset class. I imagine that reading blocks from rasterio will be slower than in memory, and rasterio is not thread-safe, so I am setting workers to 0. Not that we have in general seen much difference between workers = 0 and workers = 1 with multi-processing.