marianacelyvelasquez / deep-learning-project

Group project for the Deep Learning course in HS 2023: Arrhythmia detection via dilated CNN + SWAG

Stratify datasets #4

Closed pascal-mueller closed 9 months ago

pascal-mueller commented 9 months ago

We are given the CINC2020 dataset, consisting of around 80k samples with 24 classes. Each sample can have one or more classes. The CINC2020 data is very imbalanced (page 5 in https://iopscience.iop.org/article/10.1088/1361-6579/abc960/pdf shows the distribution of the classes in the different sources), meaning that some of the classes show up a lot more often than others. This is problematic for several reasons. One simple example to illustrate this is the following:

Imagine a binary classification case where each sample is labeled either 0 or 1. If 99% of all measurements are a 0 (e.g. an AIDS test), the model could just learn to always predict 0 and would be accurate 99% of the time, yet it's utter garbage because it never detects a single positive case.
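To make the numbers concrete, here is a tiny sketch of that example in plain Python. The data is made up (99 zeros, 1 one) and the "model" is just a constant prediction, so treat it as an illustration only:

```python
# Toy data: 99 negatives and 1 positive, mirroring the 99% / 1% example.
y_true = [0] * 99 + [1]

# A "model" that ignores its input and always predicts the majority class.
y_pred = [0] * len(y_true)

# Accuracy: share of samples where prediction and label agree.
accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

# Recall on the positive class: share of the true positives that were found.
preds_for_positives = [p for t, p in zip(y_true, y_pred) if t == 1]
recall = sum(preds_for_positives) / len(preds_for_positives)

print(f"accuracy={accuracy:.2f}, recall={recall:.2f}")
```

Accuracy looks great while recall on the rare class is zero, which is why accuracy alone says little on imbalanced data.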

Now, since we are trying to reproduce the dilated CNN paper, we have to use their approach, and they do what's called stratification. So your first goal is to read up on that:

Goal 1: Understand http://scikit.ml/stratification.html

Your second goal is to implement that in our codebase. For that you first need to understand how exactly it was done in the paper's code. You can download it here: https://physionet.org/static/published-projects/challenge-2020/1.0.2/sources/UMCUVA.zip

The code is, in my humble opinion, chaos, but basically you can see how they do stratification in the main.py script in lines 50-51:


def cross_validate(n_folds=10):
  root_dir = ARGS.input_dir
  mapping = pd.read_csv('pretraining/label_mapping_ecgnet_eq.csv', delimiter=';')
  class_mapping = mapping[['SNOMED CT Code', 'Training Code']]
  X, y, classes = cinc_utils.get_xy(
    root_dir, max_sample_length=ARGS.max_sample_length, cut_off=ARGS.cut_off_samples,
    class_mapping=class_mapping)
  # convert classes to string types
  classes = [str(c) for c in classes]
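  random_state = 821385989 # DO NOT CHANGE THIS VALUE
  # NOTE: as quoted here, this first stratifier is immediately overwritten by
  # the second one below, so random_state ends up unused.
  stratifier = IterativeStratification(
    n_splits=n_folds, order=ARGS.cv_order, random_state=random_state)
  stratifier = IterativeStratification(
    n_splits=n_folds, order=ARGS.cv_order)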

  best_epochs = []
  for k, (train_indexes, test_indexes) in enumerate(stratifier.split(X, y)):
    # get X and y for this fold
    X_train, y_train = X[train_indexes, :], y[train_indexes, :]
    X_test, y_test = X[test_indexes, :], y[test_indexes, :]
    # get datasets
    trainset, validset = get_datasets(root_dir, X_train, y_train, X_test, y_test, classes)
    # get dataloaders
    trainloader = get_dataloader(trainset)
    validloader = get_dataloader(validset)
    # run training procedure
    trainer = get_trainer(classes, fold_idx=k, pos_weights=((calc_pos_weights(trainset) - 1) * .5) + 1)
    best_epoch = trainer.train(ARGS.epochs, trainloader, validloader)
    best_epochs.append((k, best_epoch))
  print('Best epochs:', str(best_epochs).strip('[]'))

They create a stratifier, which is then used to choose the correct indices. I have no idea how the stratifier knows how to choose the indices because, in my understanding, it'd need to first analyze the data, but I didn't look into it at all.
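One observation from the snippet itself: the stratifier is constructed without any data, but `split(X, y)` receives the label matrix `y`, so that is the point where it can look at the label distribution. Below is a simplified, deterministic sketch of the greedy idea behind iterative stratification (handle the rarest label first, give each sample to the fold that still needs that label most). This is not the scikit-multilearn implementation; the function name `iterative_stratify` and the toy labels are made up for illustration:

```python
def iterative_stratify(y, n_folds):
    """Assign each row of the multi-label indicator matrix y to one fold.

    Simplified greedy version: no randomness, ties go to the lowest fold index.
    """
    n_samples, n_labels = len(y), len(y[0])
    # How many samples, overall and per label, each fold still "wants".
    want_total = [n_samples / n_folds] * n_folds
    label_totals = [sum(row[j] for row in y) for j in range(n_labels)]
    want_label = [[t / n_folds for t in label_totals] for _ in range(n_folds)]
    folds = [[] for _ in range(n_folds)]
    remaining = set(range(n_samples))

    while remaining:
        # Count unassigned samples per label and pick the rarest label.
        counts = [sum(y[i][j] for i in remaining) for j in range(n_labels)]
        present = [j for j in range(n_labels) if counts[j] > 0]
        if present:
            label = min(present, key=lambda j: counts[j])
            batch = sorted(i for i in remaining if y[i][label])
        else:
            # Only samples without any label are left: just balance fold sizes.
            label, batch = None, sorted(remaining)
        for i in batch:
            # Pick the fold with the largest remaining demand for this label,
            # then the largest remaining overall demand, then the lowest index.
            fold = max(
                range(n_folds),
                key=lambda f: (
                    want_label[f][label] if label is not None else 0.0,
                    want_total[f],
                    -f,
                ),
            )
            folds[fold].append(i)
            remaining.discard(i)
            want_total[fold] -= 1
            for j in range(n_labels):
                want_label[fold][j] -= y[i][j]
    return folds


# Toy multi-label data: label 0 appears 4 times, label 1 only twice.
y = [[1, 0], [1, 0], [1, 1], [0, 1], [0, 0], [0, 0], [1, 0], [0, 0]]
folds = iterative_stratify(y, n_folds=2)
for k, fold in enumerate(folds):
    per_label = [sum(y[i][j] for i in fold) for j in range(2)]
    print(f"fold {k}: indices={sorted(fold)}, label counts={per_label}")
```

The output of such a function is just lists of indices, which is also what `stratifier.split(X, y)` yields as `train_indexes` and `test_indexes` in the paper's loop.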

Another thing to mention is that they don't seem to use the PyTorch Dataset class properly.

As I understand it, the idea of PyTorch's Dataset and DataLoader is the following:

The dataset (which can be found under dataloders/dataset.py in our code) implements a constructor, a __len__ function returning the length of the whole dataset, and a __getitem__ function. In the constructor of this class I create a list of file paths of all the data, i.e. a list of length 80k. __getitem__(self, idx) then just returns the idx-th element.

So if I understand everything correctly, the stratifier would provide this idx, yet we never actually access it because we use PyTorch's DataLoader class, which creates an iterator abstraction over the dataset.

So the DataLoader should actually stratify the data. At least in my understanding.
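One possible alternative, as far as I can tell: instead of making the DataLoader stratify, the indices coming out of the stratifier could be used to build one dataset view per split, and each view then goes into its own DataLoader. PyTorch has `torch.utils.data.Subset(dataset, indices)` for this. The sketch below mimics the same idea in plain Python so it runs without PyTorch; `FileListDataset`, `IndexSubset`, the file names and the index lists are all made up:

```python
class FileListDataset:
    """Hypothetical stand-in for a map-style dataset: one file path per sample."""

    def __init__(self, filepaths):
        self.filepaths = list(filepaths)

    def __len__(self):
        return len(self.filepaths)

    def __getitem__(self, idx):
        # A real dataset would load and return the record stored at this path.
        return self.filepaths[idx]


class IndexSubset:
    """View of a dataset restricted to the given indices.

    Same idea as torch.utils.data.Subset, written out for illustration.
    """

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = list(indices)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        # Translate the position inside the subset to an index in the full dataset.
        return self.dataset[self.indices[idx]]


full = FileListDataset(f"record_{i}.mat" for i in range(10))

# Pretend these came out of one fold of stratifier.split(X, y).
train_indexes = [0, 1, 2, 3, 5, 6, 8, 9]
test_indexes = [4, 7]

trainset = IndexSubset(full, train_indexes)
testset = IndexSubset(full, test_indexes)
print(len(trainset), len(testset), testset[1])
```

With this layout the dataset still only holds file paths, and nothing about the DataLoader has to change.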

Now it shouldn't be a problem to add stratification here, since it should be a common problem: https://www.google.com/search?client=firefox-b-d&q=pytorch+stratify

I would approach this problem like this:

1) Understand how the dilated CNN code implements stratification and how it loads the data. Does it use Dataset? Does it use DataLoader? If so, do they use it correctly or do they do some weird stuff? E.g. people often don't use Dataset correctly: they might use it to reflect their dataset but don't do transformations in it, yet I believe data transformations like resampling should be done in there (but again, I'm no expert).

2) Once you know how the data is read, fetched and then stratified, think about how to replicate that using a more PyTorch'ian approach.

Edit: Implementation strategy

  1. Write code to read the data. The CinC challenge provides code for that. Something like read-x-y from cinc_utils; it might or might not already be in the utils folder. If not, check the paper code.

  2. Use IterativeStratification from the link above to split the data two times: once to get the train set + remainder, then split the remainder into test and validation sets. Unless there's a method for a 3-way split.

  3. Create PyTorch Datasets and pass the data to them, i.e. add the corresponding argument to the constructor. Currently the constructor reads file paths, so you can either pass file paths or data. Just passing data is easier considering the utility functions we have, but it's memory intensive :( For the start it should be okay, though.

  4. Use the datasets as we do now
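A rough sketch of the split-twice idea from step 2, mainly to get the fractions right: the second split has to be expressed relative to the remainder, not the full dataset. The splitter used here is a plain shuffle-and-cut placeholder, not a stratified one. In practice it would be swapped for something based on IterativeStratification. All function names and the 80/10/10 ratio are assumptions for illustration:

```python
import random


def split_in_two(indices, second_fraction, seed=0):
    """Placeholder splitter: shuffle and cut. NOT stratified."""
    shuffled = list(indices)
    random.Random(seed).shuffle(shuffled)
    cut = round(len(shuffled) * (1 - second_fraction))
    return shuffled[:cut], shuffled[cut:]


def three_way_split(indices, val_fraction, test_fraction, splitter=split_in_two):
    """Split indices into train / validation / test with two binary splits."""
    # First split: train vs. remainder (validation and test together).
    holdout = val_fraction + test_fraction
    train, remainder = splitter(indices, holdout)
    # Second split: the test share is relative to the remainder, not the whole.
    val, test = splitter(remainder, test_fraction / holdout)
    return train, val, test


train, val, test = three_way_split(range(100), val_fraction=0.1, test_fraction=0.1)
print(len(train), len(val), len(test))
```

The three index lists can then be handed to the datasets from step 3, either as file-path selections or to slice the loaded data.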

pascal-mueller commented 9 months ago

I turned the dialtedCnn.py file into a class, adding a bit more structure. In the constructor is a TODO for this issue now. :)

Edit: The closing and reopening below was a misclick...