open-forest-observatory / tree-detection-framework

BSD 3-Clause "New" or "Revised" License

Implementing DeepForest with PyTorch Lightning #34

Closed: amrithasp02 closed this 3 weeks ago

amrithasp02 commented 3 weeks ago

This PR integrates DeepForest into the framework, covering both training and prediction.

Major changes to look out for:

  1. preprocessing.py has a CustomDataModule class that inherits from LightningDataModule. I could not use the DataLoaders created by preprocessing.create_dataloader directly with Lightning, because our Lightning training setup expects a LightningDataModule.
  2. I created DeepForestDetector as a subclass of LightningDetector, following our initial design, and added a new DeepForestModule class.
  3. training_step() in DeepForestModule required substantial changes so that the images and labels are in the format expected by RetinaNet (the detection model underlying DeepForest).
  4. predict() in DeepForestDetector: I used the deepforest library's own prediction method directly to avoid duplicating and rewriting code.
  5. Included three static methods (written by David) that post-process the prediction outputs into RegionDetectionsSet objects.
  6. Added a notebook, deepforest_train_and_predict.ipynb, that demonstrates training and prediction.
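For item 3, torchvision's RetinaNet in training mode takes a list of float [C, H, W] images scaled to [0, 1] plus one target dict per image with "boxes" (float, shape [N, 4], xmin/ymin/xmax/ymax) and "labels" (int64, shape [N]). The helper below is a hypothetical sketch of that conversion; the batch layout it accepts (a [B, C, H, W] tensor plus per-image box and label tensors) is an assumption, not the framework's actual format.

```python
# Sketch: reshape a batch into the (images, targets) pair torchvision's
# RetinaNet expects when called in training mode.
from typing import Dict, List, Sequence, Tuple

import torch


def to_retinanet_inputs(
    images: torch.Tensor,
    boxes_per_image: Sequence[torch.Tensor],
    labels_per_image: Sequence[torch.Tensor],
) -> Tuple[List[torch.Tensor], List[Dict[str, torch.Tensor]]]:
    """Hypothetical helper; the input layout is an assumed batch format."""
    if images.dim() != 4:
        raise ValueError(f"expected [B, C, H, W], got {tuple(images.shape)}")
    if not (len(boxes_per_image) == len(labels_per_image) == images.shape[0]):
        raise ValueError("need one boxes tensor and one labels tensor per image")

    # torchvision detection models expect float images in the [0, 1] range.
    if images.dtype == torch.uint8:
        images = images.float() / 255.0

    # Split the batch into B separate [C, H, W] tensors.
    image_list = list(images.unbind(0))

    targets = []
    for boxes, labels in zip(boxes_per_image, labels_per_image):
        targets.append(
            {
                # reshape(-1, 4) also turns an empty tensor into shape [0, 4]
                "boxes": boxes.reshape(-1, 4).to(dtype=torch.float32),
                "labels": labels.reshape(-1).to(dtype=torch.int64),
            }
        )
    return image_list, targets
```

Inside a training_step, the result would feed loss_dict = self.model(image_list, targets); torchvision's RetinaNet returns a dict with "classification" and "bbox_regression" losses, which are usually summed into a single training loss.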
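For item 5, the first step of such post-processing can be sketched as turning DeepForest's prediction DataFrame into shapely boxes. This assumes the xmin, ymin, xmax, ymax, label and score columns that DeepForest's predict methods return; predictions_to_boxes is a hypothetical helper, not one of David's static methods, and building the actual RegionDetectionsSet (including any georeferencing) is left out.

```python
# Sketch: DeepForest-style prediction DataFrame -> shapely boxes, scores, labels.
from typing import List, Tuple

import pandas as pd
from shapely.geometry import Polygon, box


def predictions_to_boxes(
    df: pd.DataFrame, score_threshold: float = 0.0
) -> Tuple[List[Polygon], List[float], List[str]]:
    """Hypothetical helper; assumes pixel-coordinate xmin/ymin/xmax/ymax columns."""
    required = {"xmin", "ymin", "xmax", "ymax", "label", "score"}
    missing = required - set(df.columns)
    if missing:
        raise KeyError(f"missing prediction columns: {sorted(missing)}")

    # Drop low-confidence detections before building geometries.
    kept = df[df["score"] >= score_threshold]

    # One axis-aligned rectangle per detection, still in pixel coordinates.
    geoms = [box(r.xmin, r.ymin, r.xmax, r.ymax) for r in kept.itertuples(index=False)]
    return geoms, kept["score"].tolist(), kept["label"].tolist()
```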