from deepforest import main
from deepforest import get_data
# Driver snippet: construct a default DeepForest model, resolve the bundled
# "OSBS_029.png" sample through get_data, and run tiled prediction on it.
# The return value of predict_tile is not kept; this snippet appears to exist
# only to exercise predict_tile under @memory_profiler.profile.
m = main.deepforest()
sample_tile = get_data("OSBS_029.png")
m.predict_tile(raster_path=sample_tile)
Filename: /Users/benweinstein/Documents/DeepForest/deepforest/main.py
Line # Mem usage Increment Occurrences Line Contents
=============================================================
465 429.6 MiB 429.6 MiB 1 @memory_profiler.profile
466 def predict_tile(self,
467 raster_path=None,
468 image=None,
469 patch_size=400,
470 patch_overlap=0.05,
471 iou_threshold=0.15,
472 return_plot=False,
473 mosaic=True,
474 sigma=0.5,
475 thresh=0.001,
476 color=None,
477 thickness=1,
478 crop_model=None,
479 crop_transform=None,
480 crop_augment=False):
481 """For images too large to input into the model, predict_tile cuts the
482 image into overlapping windows, predicts trees on each window and
483 reassambles into a single array.
484
485 Args:
486 raster_path: Path to image on disk
487 image (array): Numpy image array in BGR channel order
488 following openCV convention
489 patch_size: patch size for each window.
490 patch_overlap: patch overlap among windows.
491 iou_threshold: Minimum iou overlap among predictions between
492 windows to be suppressed.
493 Lower values suppress more boxes at edges.
494 mosaic: Return a single prediction dataframe (True) or a tuple of image crops and predictions (False)
495 sigma: variance of Gaussian function used in Gaussian Soft NMS
496 thresh: the score thresh used to filter bboxes after soft-nms performed
497 cropModel: a deepforest.model.CropModel object to predict on crops
498 crop_transform: a torchvision.transforms object to apply to crops
499 crop_augment: a boolean to apply augmentations to crops
500 (deprecated) return_plot: return a plot of the image with predictions overlaid
501 (deprecated) color: color of the bounding box as a tuple of BGR color, e.g. orange annotations is (0, 165, 255)
502 (deprecated) thickness: thickness of the rectangle border line in px
503
504 Deprecated Args:
505 - return_plot: Deprecated in favor of using `visualize.plot_results` for
506 rendering predictions. Will be removed in version 2.0.
507 - color: Deprecated bounding box color for visualizations.
508 - thickness: Deprecated bounding box thickness for visualizations.
509
510 Returns:
511 - If `return_plot` is True, returns an image with predictions overlaid (deprecated).
512 - If `mosaic` is True, returns a Pandas DataFrame containing the predicted
513 bounding boxes, scores, and labels.
514 - If `mosaic` is False, returns a list of tuples where each tuple contains
515 a DataFrame of predictions and its corresponding image crop.
516 - Returns None if no predictions are made.
517 """
518 429.7 MiB 0.1 MiB 1 self.model.eval()
519 429.7 MiB 0.0 MiB 1 self.model.nms_thresh = self.config["nms_thresh"]
520
521 # if more than one GPU present, use only a the first available gpu
522 429.7 MiB 0.0 MiB 1 if torch.cuda.device_count() > 1:
523 # Get available gpus and regenerate trainer
524 warnings.warn(
525 "More than one GPU detected. Using only the first GPU for predict_tile.")
526 self.config["devices"] = 1
527 self.create_trainer()
528
529 429.7 MiB 0.0 MiB 1 if (raster_path is None) and (image is None):
530 raise ValueError(
531 "Both tile and tile_path are None. Either supply a path to a tile on disk, or read one into memory!"
532 )
533
534 429.7 MiB 0.0 MiB 1 if raster_path is None:
535 self.image = image
536 else:
537 431.3 MiB 1.6 MiB 1 self.image = rio.open(raster_path).read()
538 431.4 MiB 0.1 MiB 1 self.image = np.moveaxis(self.image, 0, 2)
539
540 431.5 MiB 0.1 MiB 2 ds = dataset.TileDataset(tile=self.image,
541 431.4 MiB 0.0 MiB 1 patch_overlap=patch_overlap,
542 431.4 MiB 0.0 MiB 1 patch_size=patch_size)
543 418.9 MiB -12.6 MiB 1 batched_results = self.trainer.predict(self, self.predict_dataloader(ds))
544
545 # Flatten list from batched prediction
546 418.9 MiB 0.0 MiB 1 results = []
547 418.9 MiB 0.0 MiB 2 for batch in batched_results:
548 418.9 MiB 0.0 MiB 2 for boxes in batch:
549 418.9 MiB 0.0 MiB 1 results.append(boxes)
550
551 418.9 MiB 0.0 MiB 1 if mosaic:
552 421.0 MiB 2.1 MiB 2 results = predict.mosiac(results,
553 418.9 MiB 0.0 MiB 1 ds.windows,
554 418.9 MiB 0.0 MiB 1 sigma=sigma,
555 418.9 MiB 0.0 MiB 1 thresh=thresh,
556 418.9 MiB 0.0 MiB 1 iou_threshold=iou_threshold)
557 421.1 MiB 0.1 MiB 2 results["label"] = results.label.apply(
558 421.0 MiB 0.0 MiB 1 lambda x: self.numeric_to_label_dict[x])
559 421.1 MiB 0.0 MiB 1 if raster_path:
560 421.1 MiB 0.0 MiB 1 results["image_path"] = os.path.basename(raster_path)
561 421.1 MiB 0.0 MiB 1 if return_plot:
562 # Add deprecated warning
563 warnings.warn("return_plot is deprecated and will be removed in 2.0. "
564 "Use visualize.plot_results on the result instead.")
565 # Draw predictions on BGR
566 if raster_path:
567 tile = rio.open(raster_path).read()
568 else:
569 tile = self.image
570 drawn_plot = tile[:, :, ::-1]
571 drawn_plot = visualize.plot_predictions(tile,
572 results,
573 color=color,
574 thickness=thickness)
575 return drawn_plot
576 else:
577 for df in results:
578 df["label"] = df.label.apply(lambda x: self.numeric_to_label_dict[x])
579
580 # TODO this is the 2nd time the crops are generated? Could be more efficient.
581 self.crops = []
582 for window in ds.windows:
583 crop = self.image[window.indices()]
584 self.crops.append(crop)
585
586 return list(zip(results, self.crops))
587
588 421.1 MiB 0.0 MiB 1 if crop_model:
589 # If a crop model is provided, predict on each crop
590 results = predict._predict_crop_model_(crop_model=crop_model,
591 results=results,
592 raster_path=raster_path,
593 trainer=self.trainer,
594 transform=crop_transform,
595 augment=crop_augment)
596 421.1 MiB 0.0 MiB 1 if results.empty:
597 421.1 MiB 0.0 MiB 1 warnings.warn("No predictions made, returning None")
598 421.1 MiB 0.0 MiB 1 return None
599
600 results = utilities.read_file(results)
601
602 if raster_path is None:
603 warnings.warn(
604 "An image was passed directly to predict_tile, the root_dir will be None in the output dataframe, to use visualize.plot_results, please assign results.root_dir = <directory name>"
605 )
606 else:
607 results.root_dir = os.path.dirname(raster_path)
608
609 return results
(DeepForest) (base) benweinstein@Bens-MacBook-Pro DeepForest %
Current performance: