Open bw4sz opened 2 years ago
Balancing approaches are implemented in #22. Both @bw4sz and I have experimented with versions of this both with and without floors/ceilings and not seen any improvement (see also #21). This feels weird, but may be related to Focal Loss having already addressed the class imbalance to the degree possible.
I want to keep coming back to this. It just feels too vital to let go. I've had no success outside of completely balanced data, but I think the sampling process is introducing a lot of inter-run variability, and overall it feels like we are denying the model at least some reasonable prevalence information. Especially when we use site-level metadata, the overfitting argument seems moot: we are already providing site-specific info.
One of the challenges of this research program is that each decision seems to cascade and affect others. I certainly tested the sampler when we added it to species classification. Now, on first try, I cannot find any difference between balanced and unbalanced ('raw') sampling.
This either means that the sampler is not working as intended, or that some other innovation in the meantime has rendered it irrelevant. I will continue to follow this.
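A minimal way to check the first possibility, outside the full pipeline: feed a toy imbalanced label set through WeightedRandomSampler with inverse-frequency weights and see whether the draws equalize. A sketch, not the project's actual test:

```python
import torch
from collections import Counter

torch.manual_seed(0)
labels = [0] * 900 + [1] * 100                   # 9:1 imbalance
weights = [1.0 / 900] * 900 + [1.0 / 100] * 100  # inverse class frequency

sampler = torch.utils.data.WeightedRandomSampler(
    weights=weights, num_samples=10_000, replacement=True
)
drawn = Counter(labels[i] for i in sampler)
# Roughly 5,000 draws per class. If the real pipeline still shows the
# raw 9:1 skew under these weights, the sampler is not being applied.
```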
Balancing with a ceiling now wins, but only by the smallest amount. I find the code confusing though; it needs more thought.
def train_dataloader(self):
    """Load the training file. The default location is saved during self.setup();
    to override it, set self.train_file before training."""
    # Get per-class counts from the training CSV
    train = pd.read_csv(self.train_file)
    class_weights = train.label.value_counts().to_dict()

    # Weight each sample by the inverse of its class frequency,
    # capping the frequency at 100 so all large classes share a ceiling
    data_weights = []
    for idx in range(len(self.train_ds)):
        path, image, targets = self.train_ds[idx]
        label = int(targets.numpy())
        class_freq = min(class_weights[label], 100)
        data_weights.append(1 / class_freq)

    sampler = torch.utils.data.WeightedRandomSampler(
        weights=data_weights, num_samples=len(self.train_ds))
    data_loader = torch.utils.data.DataLoader(
        self.train_ds,
        batch_size=self.config["batch_size"],
        num_workers=self.config["workers"],
        sampler=sampler,
    )
    return data_loader
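The per-sample loop above touches every dataset item just to recover its label; the same weights can be computed directly from the label column, assuming the CSV row order matches train_ds (an assumption worth verifying). A sketch:

```python
import pandas as pd

def capped_inverse_weights(labels, ceiling=100):
    """Per-sample weight = 1 / min(class_count, ceiling)."""
    counts = pd.Series(labels).value_counts().clip(upper=ceiling)
    return [1.0 / counts.loc[label] for label in labels]

# Toy labels: class 0 is far above the ceiling, class 2 far below it
labels = [0] * 300 + [1] * 50 + [2] * 10
weights = capped_inverse_weights(labels)
# weights -> 0.01 for class 0 (count capped at 100), 0.02 for class 1,
#            0.1 for class 2
```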
(DeepTreeAttention) [b.weinstein@login3 DeepTreeAttention]$ python
Python 3.8.11 (default, Aug 3 2021, 15:09:35)
[GCC 7.5.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from src.data import *
>>> import pandas as pd
>>> config = read_config("config.yml")
>>> data_module = TreeData(csv_file="data/raw/neon_vst_data_2021.csv", regenerate=False, client=None, metadata=True, comet_logger=None)
>>> data_module.setup()
>>> dl = data_module.train_dataloader()
>>> labels = []
>>> for batch in dl:
... paths, inputs, batch_labels = batch
... labels.append(batch_labels.numpy())
...
...
>>>
>>> labels = np.concatenate(labels)
>>> g = pd.Series(labels).value_counts().reset_index(name="taxonID")
>>>
>>> g
index taxonID
0 13 1979
1 19 687
2 1 676
3 8 290
4 7 120
5 14 105
6 15 102
7 23 95
8 20 93
9 21 90
10 2 89
11 9 87
12 17 83
13 11 82
14 10 81
15 0 81
16 4 79
17 22 79
18 16 78
19 18 76
20 3 74
21 5 72
22 6 65
23 12 64
Basically, by undersampling the top class we oversample the bottom, which is strange because replacement = True is the default. https://www.comet.ml/bw4sz/deeptreeattention/eaf46472a6ab4e15b16e5d9286556901/4e121fa6f04e426bad08183ad17f5b94/compare?experiment-tab=chart&showOutliers=true&smoothing=0&transformY=smoothing&xAxis=epoch
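One way to check whether that distribution is expected rather than a bug: with replacement, a class's expected share of draws is proportional to count × weight, which under the ceiling is count/100 for classes above 100 samples and 1 for classes at or below. A sketch with hypothetical counts chosen to mimic the skew above:

```python
# Expected draw share under inverse-frequency weights with a ceiling.
# The counts here are made up, loosely matching the table above.
counts = {"A": 1979, "B": 687, "C": 95, "D": 64}
ceiling = 100

# Each sample's weight is 1 / min(count, ceiling), so a class's total
# probability mass is count / min(count, ceiling).
mass = {k: c / min(c, ceiling) for k, c in counts.items()}
total = sum(mass.values())
share = {k: m / total for k, m in mass.items()}
# Classes under the ceiling ("C", "D") get equal expected shares,
# while "A" is still expected ~19.8x as often as either of them.
```

So the pattern in the table, rare classes equalized but the top class still dominant, is consistent with the capped weights rather than with a sampler bug.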
Try balancing without any floor or ceiling resampling value.