Open JiaweiZhuang opened 4 years ago
That dataset is a bit messy, and cleaning it up might take a long time. Since our project focuses on the DARTS method, not data cleaning, a quicker option is probably to use medical datasets on Kaggle.
Examples:
I will give the ChestX-ray8 dataset from NIH a try: https://nihcc.app.box.com/v/ChestXray-NIHCC/ It is a giant dataset (~50 GB), which should make NAS more worthwhile. Corresponding paper: https://arxiv.org/abs/1705.02315
The original problem is "weakly-supervised multi-label image classification and disease localization". We should identify a more tractable, standard classification problem.
Update: The 42GB dataset is actually on Kaggle (hosted via NIH's official account) https://www.kaggle.com/nih-chest-xrays/data
We should focus on the multi-label classification and set localization aside for now. Basically, the output is an 8-element vector of 0s and 1s, indicating whether each disease is present.
The metric is the area-under-curve (AUC) for each disease, as from the original paper:
These notebooks perform such multi-label classification and plot ROC curves. They use all 14 diseases instead of the 8 in the original paper.
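As a minimal sketch (not from the notebooks themselves), the per-disease AUC can be computed by treating each disease column independently with `sklearn.metrics.roc_auc_score`. The disease names and data here are placeholders:

```python
import numpy as np
from sklearn.metrics import roc_auc_score

# Placeholder disease subset and random data for illustration only.
diseases = ["Cardiomegaly", "Effusion", "Mass"]
rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=(100, len(diseases)))   # 0/1 labels per disease
y_score = rng.random((100, len(diseases)))               # model probabilities

# One ROC AUC per disease column, as in the original paper's evaluation.
for i, name in enumerate(diseases):
    auc = roc_auc_score(y_true[:, i], y_score[:, i])
    print(f"{name}: AUC = {auc:.3f}")
```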
I've finally cleaned-up the chest xray data and uploaded to:
The original data are 42 GB of PNGs and took ~2 hours to read. I resized the images from 1024x1024 to 128x128 (as in other example notebooks) and sampled 40k of the 100k images. The final dataset is a single 2.5 GB HDF5/NetCDF file that takes 4s to load :) The data processing notebook is at:
The data follows the same format as the cleaned-up galaxy data (https://github.com/capstone2019-neuralsearch/AC297r_2019_NAS/issues/18#issuecomment-549149338), and can be used almost as a drop-in replacement for the current galaxy loader for DARTS.
Two important modifications are needed for this multi-label binary classification problem:

1. Use `torch.nn.BCELoss()` (binary cross-entropy loss), NOT `CrossEntropyLoss`, because this is not a multi-class classification problem as in CIFAR. The output is a vector of 0s and 1s (NOT one-hot encoding! there can be multiple 1s here).
2. Use `torch.sigmoid()` as the final activation, NOT softmax, as again this is a multi-output binary classification.

See this notebook for a complete training and evaluation pipeline:
The notebook used a very simple 3-layer CNN and got a performance close to the original paper (just for the classification part, not attempting the localization part):
For example the AUC here for Cardiomegaly is 0.82, compared to the original paper with AUC=0.8141.
However, even this tiny model starts to overfit: after 10+ epochs, the training loss decreases to <0.1 while the validation loss increases to >0.2. Thus we are in a similar situation as with the galaxy/graphene datasets, where model capacity is not the bottleneck, and I don't expect DARTS to help here.
Really wonderful work, thanks @JiaweiZhuang. Given that it is not so promising for DARTS, I will focus my efforts on encouraging sparsity in the learned architecture to make the transfer to the final cell more robust.
Will you be able to try running DARTS for this?
Thanks
Thanks for your efforts, I think this is amazing. I agree with Dylan on what to try next, although how do you plan to get a sparse architecture?
Simply adding an L1 penalty on the weights might not be enough. PyTorch will automatically differentiate the loss almost everywhere, but not where it matters (at the non-differentiable point). Optimization schemes that enforce sparsity normally apply a soft-thresholding operation to avoid the non-differentiability; the soft-thresholding update is applied after the gradient step. We can talk about this if you were thinking of something similar.
What was your idea?
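For reference, a minimal sketch of the soft-thresholding idea (an ISTA-style proximal update, with a toy quadratic standing in for the smooth DARTS loss; `alpha` and `lam` are hypothetical names):

```python
import torch

def soft_threshold(x: torch.Tensor, thresh: float) -> torch.Tensor:
    # Proximal operator of thresh * ||.||_1: shrink toward zero,
    # zeroing out entries with magnitude below the threshold.
    return torch.sign(x) * torch.clamp(x.abs() - thresh, min=0.0)

alpha = torch.randn(8, 7, requires_grad=True)   # toy architecture parameters
opt = torch.optim.SGD([alpha], lr=0.01)
lam = 0.1                                       # L1 regularization strength

for _ in range(100):
    opt.zero_grad()
    loss = ((alpha - 1.0) ** 2).sum()   # stand-in for the smooth part of the loss
    loss.backward()
    opt.step()                          # gradient step on the smooth part only
    with torch.no_grad():               # then the prox step handles the L1 term
        alpha.copy_(soft_threshold(alpha, opt.param_groups[0]["lr"] * lam))
```

The key point matches the comment above: the gradient step only sees the smooth loss, and the non-smooth L1 term is handled exactly by the closed-form prox afterward, so autograd never touches the non-differentiable point.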
Suggested by @jzazo Paper: https://www.nature.com/articles/s41598-018-37638-9 Code: https://github.com/zhaoxuanma/Deeplearning-digital-pathology Dataset: https://www.nature.com/articles/s41598-018-37638-9#data-availability