weecology / EvergladesSpeciesModel

A deepforest model for wading bird species prediction.
MIT License

balance dataset #14

Open bw4sz opened 2 years ago

bw4sz commented 2 years ago

Try balancing without any floor or ceiling on the resampling value.
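Balancing with no floor or ceiling amounts to weighting each sample by the inverse of its class count. The sketch below is plain Python with hypothetical toy labels. The function name `inverse_frequency_weights` is illustrative and does not come from the repository. The resulting list has the form `torch.utils.data.WeightedRandomSampler` expects for its `weights` argument.

```python
from collections import Counter

def inverse_frequency_weights(labels):
    """Per-sample weight = 1 / (number of samples in that sample's class); no floor or ceiling."""
    counts = Counter(labels)
    return [1.0 / counts[label] for label in labels]

# Hypothetical toy labels: class 0 is ten times more common than class 2
labels = [0] * 100 + [1] * 30 + [2] * 10
weights = inverse_frequency_weights(labels)

# Summing weights per class: every class ends up with the same total weight (1.0),
# so a weighted sampler drawing with replacement picks each class equally often in expectation.
class_totals = Counter()
for label, w in zip(labels, weights):
    class_totals[label] += w
print({k: round(v, 6) for k, v in class_totals.items()})  # {0: 1.0, 1: 1.0, 2: 1.0}
```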

ethanwhite commented 2 years ago

Balancing approaches are implemented in #22. Both @bw4sz and I have experimented with versions of this, with and without floors/ceilings, and have not seen any improvement (see also #21). This feels weird, but it may be because Focal Loss has already addressed the class imbalance to the degree possible.
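The loss actually used by the model is not shown in this thread. The following is a minimal NumPy sketch of the standard multiclass focal loss (Lin et al., 2017), included to show why it might already absorb part of the imbalance. The `(1 - p_t) ** gamma` factor shrinks the loss on confidently classified examples, which in an imbalanced training set tend to come from the common classes. The logits are hypothetical.

```python
import numpy as np

def focal_loss(logits, targets, gamma=2.0, alpha=None):
    """Per-example multiclass focal loss: -alpha_t * (1 - p_t) ** gamma * log(p_t).

    logits: (N, C) raw scores; targets: (N,) integer class indices;
    alpha: optional (C,) per-class weights. gamma=0 with alpha=None is plain cross-entropy.
    """
    # Numerically stable log-softmax over classes
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    log_pt = log_probs[np.arange(len(targets)), targets]  # log-probability of the true class
    pt = np.exp(log_pt)
    loss = -((1.0 - pt) ** gamma) * log_pt
    if alpha is not None:
        loss = loss * np.asarray(alpha)[targets]
    return loss

# Hypothetical logits: the first example is confidently correct, the second is misclassified
logits = np.array([[4.0, 0.0, 0.0],
                   [0.0, 2.0, 0.0]])
targets = np.array([0, 0])

easy_ce, hard_ce = focal_loss(logits, targets, gamma=0.0)  # plain cross-entropy
easy_fl, hard_fl = focal_loss(logits, targets, gamma=2.0)
print(easy_fl / easy_ce)  # about 0.0012: the easy example is almost ignored
print(hard_fl / hard_ce)  # about 0.80: the hard example keeps most of its loss
```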

bw4sz commented 2 years ago

I want to keep coming back to this. It just feels too vital to let go. I've had no success outside of completely balanced data, but I think the sampling process is leading to a ton of inter-run variability, and overall it feels like we are denying the model at least some reasonable prevalence information. Especially when we use site-level metadata, the overfitting argument feels moot here: we are already providing site-specific information.
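One middle ground between raw and fully balanced sampling, not something tested in this thread, is to temper the inverse-frequency weights with an exponent: each sample gets weight `count ** -alpha`. With `alpha=0` the sampler reproduces the raw prevalence; with `alpha=1` every class gets an equal share; values in between keep some prevalence information. The function names and the 90:9:1 labels below are hypothetical.

```python
from collections import Counter

def tempered_weights(labels, alpha):
    """Per-sample weight = (class count) ** -alpha.

    alpha=0 keeps raw class prevalence; alpha=1 fully balances classes.
    """
    counts = Counter(labels)
    return [counts[label] ** -alpha for label in labels]

def expected_class_share(labels, weights):
    """Fraction of draws each class is expected to get from a weighted sampler with replacement."""
    totals = Counter()
    for label, w in zip(labels, weights):
        totals[label] += w
    grand_total = sum(totals.values())
    return {label: total / grand_total for label, total in totals.items()}

labels = [0] * 900 + [1] * 90 + [2] * 10  # hypothetical 90:9:1 imbalance
for alpha in (0.0, 0.5, 1.0):
    share = expected_class_share(labels, tempered_weights(labels, alpha))
    print(alpha, {k: round(v, 3) for k, v in share.items()})
```

On the inter-run variability point: in recent PyTorch versions `WeightedRandomSampler` accepts a `generator` argument, so passing a seeded `torch.Generator` would make the drawn indices repeatable across runs and help separate sampling noise from genuine differences between configurations.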

bw4sz commented 2 years ago

One of the challenges of this research program is that each decision seems to cascade and affect others. I certainly tested the sampler when we added it to species classification. Now, on first try, I cannot find any difference between balanced and unbalanced ('raw') sampling.

https://www.comet.ml/bw4sz/deeptreeattention/176395505730431ca567e5ddef84267d/e3b659a81e02473596572d5a815f17c5/compare?experiment-tab=chart&showOutliers=true&smoothing=0&transformY=smoothing&xAxis=step

This means either that the sampler is not working as intended, or that some other innovation in the meantime has rendered it irrelevant. I will continue to follow this.
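A quick way to tell "the sampler is broken" apart from "the sampler works but doesn't matter" is to compare the class fractions actually drawn with the fractions the weights imply. The sketch below simulates weighted sampling with replacement using `random.Random.choices`, which draws in proportion to the weights like `WeightedRandomSampler(replacement=True)`. The helper name and toy labels are hypothetical; against the real pipeline, the observed side would come from the labels in the DataLoader's batches.

```python
import random
from collections import Counter

def sampled_vs_expected(labels, weights, num_samples, seed=0):
    """Return {class: (observed fraction, expected fraction)} for weighted draws with replacement."""
    rng = random.Random(seed)
    idx = rng.choices(range(len(labels)), weights=weights, k=num_samples)
    observed = Counter(labels[i] for i in idx)
    # Expected fraction per class = that class's summed weight / total weight
    class_weight = Counter()
    for label, w in zip(labels, weights):
        class_weight[label] += w
    total = sum(class_weight.values())
    return {c: (observed[c] / num_samples, class_weight[c] / total) for c in sorted(class_weight)}

labels = [0] * 500 + [1] * 50 + [2] * 5  # hypothetical imbalanced labels
counts = Counter(labels)
balanced = [1.0 / counts[lab] for lab in labels]  # full balancing, no ceiling
for c, (obs, exp) in sampled_vs_expected(labels, balanced, 30000).items():
    print(c, round(obs, 3), round(exp, 3))
```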

bw4sz commented 2 years ago

Only by the smallest amount does balancing with a ceiling win now. I find the code confusing, though; it needs more thought.

```python
import pandas as pd
import torch

    def train_dataloader(self):
        """Load a training file. The default location is saved during self.setup(); to override this location, set self.train_file before training."""

        # Count training samples per class (these are counts, not weights)
        train = pd.read_csv(self.train_file)
        class_counts = train.label.value_counts().to_dict()

        data_weights = []
        # Weight each sample by 1 / (its class count), with the count capped at 100.
        # Classes with more than 100 samples all get 1/100 per sample, so their
        # total weight still grows with class size.
        for idx in range(len(self.train_ds)):
            path, image, targets = self.train_ds[idx]  # loads every image just to read its label
            label = int(targets.numpy())
            class_freq = min(class_counts[label], 100)
            data_weights.append(1 / class_freq)

        # Draw len(train_ds) indices in proportion to data_weights (replacement=True is the default)
        sampler = torch.utils.data.sampler.WeightedRandomSampler(
            weights=data_weights, num_samples=len(self.train_ds), replacement=True
        )
        data_loader = torch.utils.data.DataLoader(
            self.train_ds,
            batch_size=self.config["batch_size"],
            num_workers=self.config["workers"],
            sampler=sampler,
        )

        return data_loader
```
```
(DeepTreeAttention) [b.weinstein@login3 DeepTreeAttention]$ python
Python 3.8.11 (default, Aug  3 2021, 15:09:35)
[GCC 7.5.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from src.data import *
>>> import pandas as pd
>>> config = read_config("config.yml")
>>> data_module = TreeData(csv_file="data/raw/neon_vst_data_2021.csv", regenerate=False, client=None, metadata=True, comet_logger=None)
>>> data_module.setup()
>>> dl = data_module.train_dataloader()
>>> labels = []
>>> for batch in dl:
...     paths, inputs, batch_labels = batch
...     labels.append(batch_labels.numpy())
...
...
>>>
>>> labels = np.concatenate(labels)
>>> g = pd.Series(labels).value_counts().reset_index(name="taxonID")
>>>
>>> g
    index  taxonID
0      13     1979
1      19      687
2       1      676
3       8      290
4       7      120
5      14      105
6      15      102
7      23       95
8      20       93
9      21       90
10      2       89
11      9       87
12     17       83
13     11       82
14     10       81
15      0       81
16      4       79
17     22       79
18     16       78
19     18       76
20      3       74
21      5       72
22      6       65
23     12       64
```

Basically, by undersampling the top class, we oversample the bottom ones, which is strange because replacement=True is the default. https://www.comet.ml/bw4sz/deeptreeattention/eaf46472a6ab4e15b16e5d9286556901/4e121fa6f04e426bad08183ad17f5b94/compare?experiment-tab=chart&showOutliers=true&smoothing=0&transformY=smoothing&xAxis=epoch
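The shape of the table is roughly what the capped weights in `train_dataloader` predict. With a ceiling of 100, each sample in a class with n_c > 100 gets weight 1/100, so that class's total weight is n_c / 100. Every class with at most 100 samples has a total weight of exactly 1. All rare classes therefore get the same expected share, which matches the long tail of similar counts. The largest class keeps a share proportional to n_c / 100, so it is only mildly undersampled. The sketch below computes these expectations. The class counts are hypothetical, not the NEON data, and the helper name is illustrative.

```python
def expected_counts_with_ceiling(class_counts, num_samples, ceiling=100):
    """Expected draws per class when each sample gets weight 1 / min(class_count, ceiling)
    and num_samples indices are drawn with replacement."""
    # A class's total weight is count / min(count, ceiling) = max(count / ceiling, 1)
    class_weight = {c: n / min(n, ceiling) for c, n in class_counts.items()}
    total = sum(class_weight.values())
    return {c: num_samples * w / total for c, w in class_weight.items()}

# Hypothetical counts: one dominant class, two mid-sized classes, three rare ones
class_counts = {"A": 2000, "B": 500, "C": 150, "D": 60, "E": 30, "F": 10}
n = sum(class_counts.values())
for c, expected in expected_counts_with_ceiling(class_counts, n).items():
    print(c, class_counts[c], round(expected))  # e.g. A: 2000 -> 1864, F: 10 -> 93
```

Under this reading, the heavily repeated rare classes are what sampling with replacement produces, and lowering the ceiling (or removing it) is what would pull the top class down further.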