farhat-lab / wdnn

Wide and Deep Neural Net for predicting TB resistance from genotypic data

What is the alpha_matrix.csv? #7

Open nottwy opened 4 years ago

nottwy commented 4 years ago

I've recently been learning your tool, but I can't work out what the file alpha_matrix.csv contains. It's used as the target data in the Keras fit call on line 44 of evaluation.py:

wdnn.fit(X_train, alpha_matrix[train], epochs=100, verbose=False, validation_data=[X_val,alpha_matrix[val]])

mchen16 commented 4 years ago

alpha_matrix is calculated from the proportion of resistant isolates for each drug. In helpers.py, the function "masked_multi_weighted_bce()" takes this alpha_matrix as input and reconstructs from it the labels used by the other machine learning models, so ultimately the WDNN is trained on a label matrix similar to the one the other models use. However, we adjusted our loss function to upweight the sparser class for each drug according to that drug's proportion of resistant isolates, which is why we constructed the alpha_matrix this way.
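Keras passes whatever is given as the fit targets to a custom loss function as y_true, so a single matrix can carry both the labels and the per-drug weights. The plain-Python sketch below illustrates that idea. The encoding is an assumption for illustration and has not been checked against helpers.py: a positive entry marks a resistant isolate, a negative entry a susceptible one, 0 marks a missing label, and the magnitude is the drug's weight alpha, taken here as one minus that drug's fraction of resistant isolates. The names build_alpha_matrix and masked_weighted_bce are hypothetical and are not the repository's code.

```python
import math

# ASSUMED encoding (illustrative only, not confirmed against the wdnn source):
#   +alpha_d -> isolate resistant to drug d
#   -alpha_d -> isolate susceptible to drug d
#   0.0      -> label missing for drug d


def build_alpha_matrix(labels):
    """Hypothetical helper. `labels` is a list of rows, one per isolate, holding
    1 (resistant), 0 (susceptible) or None (missing) for each drug."""
    n_drugs = len(labels[0])
    alphas = []
    for d in range(n_drugs):
        observed = [row[d] for row in labels if row[d] is not None]
        if not observed:
            raise ValueError(f"drug {d} has no labelled isolates")
        frac_resistant = sum(observed) / len(observed)
        if not 0.0 < frac_resistant < 1.0:
            # alpha = 0 would make resistant entries look like missing labels,
            # and alpha = 1 would zero out the susceptible term of the loss.
            raise ValueError(f"drug {d} needs both resistant and susceptible isolates")
        # The rarer class gets the larger weight: resistant isolates are
        # weighted by alpha, susceptible ones by (1 - alpha).
        alphas.append(1.0 - frac_resistant)

    matrix = []
    for row in labels:
        encoded = []
        for d, y in enumerate(row):
            if y is None:
                encoded.append(0.0)          # missing label
            elif y == 1:
                encoded.append(alphas[d])    # resistant: positive sign
            else:
                encoded.append(-alphas[d])   # susceptible: negative sign
        matrix.append(encoded)
    return matrix


def masked_weighted_bce(alpha_row, pred_row, eps=1e-7):
    """Hypothetical per-isolate loss: recover each label from the sign of its
    entry and its weight from the magnitude, skip missing entries, and average
    the weighted binary cross-entropy over the drugs that have labels."""
    total, n_observed = 0.0, 0
    for a, p in zip(alpha_row, pred_row):
        if a == 0.0:
            continue                          # masked: no label for this drug
        y = 1.0 if a > 0 else 0.0             # reconstructed 0/1 label
        w = abs(a)                            # per-drug weight alpha
        p = min(max(p, eps), 1.0 - eps)       # clip so log() stays finite
        total += -w * y * math.log(p) - (1.0 - w) * (1.0 - y) * math.log(1.0 - p)
        n_observed += 1
    return total / n_observed if n_observed else 0.0
```

For example, with labels [[1, 0], [0, 0], [0, None], [0, 1]], drug 0 is 25% resistant, so its resistant isolate is encoded as 0.75 and its susceptible isolates as -0.75. The missing entry stays 0.0 and is skipped by the loss. A real Keras loss would do the same arithmetic with backend tensor operations over a whole batch rather than looping in Python.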