weecology / DeepForest

Python Package for Airborne RGB machine learning
https://deepforest.readthedocs.io/
MIT License

Weighted Random Sampler for Multi-class detection #792

Open bw4sz opened 1 month ago

bw4sz commented 1 month ago

Most ecological datasets are heavily class-imbalanced, so DeepForest should offer some default weighted random sampler options. A rough sketch:

import torch
from torch.utils.data import WeightedRandomSampler
from deepforest import main

m = main.deepforest()

# [setup and load data]

dataset = m.train_ds

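# Assuming 'dataset' is a PyTorch Dataset that yields (image, label) pairs
# with a single integer label per sample
labels = [label for _, label in dataset]  # iterate the dataset only once
class_counts = [0] * len(dataset.classes)
for label in labels:
    class_counts[label] += 1

# Inverse-frequency class weights; clamp so an absent class does not divide by zero
class_weights = 1.0 / torch.tensor(class_counts, dtype=torch.float).clamp(min=1)
sample_weights = [class_weights[label].item() for label in labels]
sampler = WeightedRandomSampler(sample_weights, num_samples=len(dataset), replacement=True)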
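For detection, each image usually carries several box labels rather than one, so the per-sample loop above needs a rule for collapsing them into a single weight per image. Below is a minimal sketch of one possible rule: weight each image by the inverse frequency of its rarest class. The function name `image_sampling_weights`, the `labels_per_image` input, and the choice to give unannotated images a weight of zero are all assumptions for illustration, not part of DeepForest.

```python
from collections import Counter


def image_sampling_weights(labels_per_image):
    """Return one sampling weight per image from its list of box labels.

    Hypothetical helper: an image is weighted by the inverse frequency of
    the rarest class it contains, counted over all boxes in the dataset.
    """
    # Count how many boxes of each class appear across the whole dataset.
    class_counts = Counter(
        label for labels in labels_per_image for label in labels
    )

    weights = []
    for labels in labels_per_image:
        if not labels:
            # Assumption: images without annotations are never drawn.
            weights.append(0.0)
            continue
        # The rarest class has the smallest count, hence the largest 1/count.
        weights.append(max(1.0 / class_counts[label] for label in labels))
    return weights


# Toy example: three "Tree" boxes and one "Bird" box over three images.
labels_per_image = [["Tree", "Tree"], ["Tree"], ["Bird"]]
weights = image_sampling_weights(labels_per_image)
print(weights)
```

The resulting list can be passed as the `weights` argument of `WeightedRandomSampler`, with `num_samples=len(weights)` and `replacement=True`. Other rules (mean inverse frequency, or summing over boxes) are equally plausible and would change how strongly rare classes are oversampled.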

This would be hooked up through load_dataset:

https://github.com/weecology/DeepForest/blob/e14bc6dccbaa2a276cb69a3b16fd7fb8d3301b61/deepforest/main.py#L289

and passed into the DataLoader object:

https://github.com/weecology/DeepForest/blob/e14bc6dccbaa2a276cb69a3b16fd7fb8d3301b61/deepforest/main.py#L270
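As a rough illustration of the wiring described above, the sketch below passes a `WeightedRandomSampler` to a plain PyTorch `DataLoader`. `ToyDetectionDataset`, `collate_fn`, and `make_loader` are hypothetical stand-ins written for this example and do not reflect how DeepForest's `load_dataset` is actually structured. The one firm constraint is that `DataLoader` does not accept `shuffle=True` together with a `sampler`.

```python
import torch
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler


class ToyDetectionDataset(Dataset):
    """Hypothetical stand-in for a detection dataset: one image, many box labels."""

    def __init__(self, labels_per_image):
        self.labels_per_image = labels_per_image

    def __len__(self):
        return len(self.labels_per_image)

    def __getitem__(self, idx):
        image = torch.zeros(3, 8, 8)  # placeholder image tensor
        labels = torch.tensor(self.labels_per_image[idx], dtype=torch.int64)
        return image, {"labels": labels}


def collate_fn(batch):
    # Targets differ in length per image, so keep tuples instead of stacking.
    return tuple(zip(*batch))


def make_loader(dataset, sample_weights, batch_size=2):
    # Draw len(dataset) indices per epoch, with replacement, according to
    # the per-image weights.
    sampler = WeightedRandomSampler(
        sample_weights, num_samples=len(dataset), replacement=True
    )
    # shuffle is left at its default (False): it is mutually exclusive
    # with passing a sampler.
    return DataLoader(
        dataset, batch_size=batch_size, sampler=sampler, collate_fn=collate_fn
    )


dataset = ToyDetectionDataset([[0, 0], [0], [1]])
loader = make_loader(dataset, sample_weights=[1 / 3, 1 / 3, 1.0])
for images, targets in loader:
    print([t["labels"].tolist() for t in targets])
```

Because sampling is with replacement, rare-class images can appear several times in an epoch while some common-class images are skipped, so a sampler like this probably makes sense only for the training loader, not validation.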