weecology / DeepForest

Python Package for Airborne RGB machine learning
https://deepforest.readthedocs.io/
MIT License
519 stars, 176 forks

main.deepforest.predict_file should be able to take in a dataframe, not just a csv file. #797

Open bw4sz opened 1 month ago

bw4sz commented 1 month ago

Looking at the code base.

https://deepforest.readthedocs.io/en/latest/_modules/deepforest/main.html#deepforest.predict_file

It actually looks like dataset.TreeDataset() is the problem here, because utilities.read_file is definitely flexible enough. The docstring of predict_file should be updated and the argument name should be changed (with a deprecation warning kept until 2.0), but it's the dataset class, which the user rarely sees, that needs to be updated. I think this would still be considered a patch.
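A minimal sketch of how the argument rename with a deprecation warning could look. This is not the actual DeepForest code: the new argument name `annotations` is a hypothetical placeholder, and the function below is a stand-in that only shows the argument handling, not the prediction itself.

```python
import warnings


def predict_file(annotations=None, root_dir=None, csv_file=None):
    """Stand-in for predict_file; only the argument handling is sketched.

    `annotations` is a hypothetical new name for the old `csv_file` argument.
    """
    if csv_file is not None:
        # The old keyword keeps working but warns until it is removed in 2.0.
        warnings.warn(
            "'csv_file' is deprecated and will be removed in 2.0; "
            "use 'annotations' instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        if annotations is None:
            annotations = csv_file

    if annotations is None:
        raise ValueError("Pass a CSV path or a pandas DataFrame as 'annotations'.")

    # The real method would build the dataset and run prediction here.
    return annotations
```

Existing calls such as predict_file(csv_file="trees.csv", root_dir=".") would keep working and emit a DeprecationWarning, which is what makes the rename compatible with a patch release.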

Abhishek-kumar0503 commented 2 weeks ago

After reading your comment, my understanding is that utilities.read_file reads the input CSV file and is already flexible, while dataset.TreeDataset takes the data from the CSV file and organizes it for prediction. The issue is that TreeDataset isn't flexible enough to handle different formats or structures of input data. My solution is to update TreeDataset to be more flexible, so I rewrote the class as follows:

from deepforest import utilities


class TreeDataset:
    def __init__(self, csv_file, root_dir, transforms=None, train=False, column_mapping=None):
        self.root_dir = root_dir
        self.transforms = transforms
        self.train = train
        # Default mapping is the identity, so renaming is a no-op unless overridden
        self.column_mapping = column_mapping or {
            "image_path": "image_path",
            "xmin": "xmin",
            "ymin": "ymin",
            "xmax": "xmax",
            "ymax": "ymax",
        }

        # Load the annotations with flexible handling
        self.annotations = self._load_annotations(csv_file)

    def _load_annotations(self, csv_file):
        # Read the input using the existing utility
        df = utilities.read_file(csv_file)

        # Rename columns based on column_mapping ({current name: new name}), if necessary
        df = df.rename(columns=self.column_mapping)

        return df

and update the call in predict_file: ds = dataset.TreeDataset(csv_file=csv_file, root_dir=root_dir, transforms=None, train=False)

Is this correct?
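For comparison, here is a rough sketch of the behaviour the issue title asks for: a loader that accepts either a CSV path or a pandas DataFrame. It is written against plain pandas so it can run on its own. The function name `load_annotations`, the `REQUIRED_COLUMNS` list, and the convention that `column_mapping` maps the user's column names to the standard ones are assumptions made for illustration, not DeepForest's API.

```python
import pandas as pd

# Column names assumed for illustration, taken from the mapping in the proposal above.
REQUIRED_COLUMNS = ["image_path", "xmin", "ymin", "xmax", "ymax"]


def load_annotations(source, column_mapping=None):
    """Return annotations as a DataFrame from a CSV path or an existing DataFrame."""
    if isinstance(source, pd.DataFrame):
        # Copy so the caller's DataFrame is not modified.
        df = source.copy()
    else:
        # Anything else is treated as a path (or file-like object) for read_csv.
        df = pd.read_csv(source)

    if column_mapping:
        # Keys are the user's column names, values are the standard names.
        df = df.rename(columns=column_mapping)

    missing = [col for col in REQUIRED_COLUMNS if col not in df.columns]
    if missing:
        raise ValueError(f"Missing required columns: {missing}")
    return df
```

With something like this inside the dataset class, a call such as load_annotations(df, column_mapping={"left": "xmin"}) would work the same way as passing a file path.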

henrykironde commented 2 weeks ago

@Abhishek-kumar0503, it looks like you're on the right track. Could you submit a PR?

Abhishek-kumar0503 commented 2 weeks ago

Yes, sure, @henrykironde. I have set up this project on my laptop. I have one question: if I change or update a function in main.py, where will those changes show up?