kuleshov-group / caduceus

Bi-Directional Equivariant Long-Range DNA Sequence Modeling
Apache License 2.0

Strange results using mcc on genomic benchmark? #38

Closed leannmlindsey closed 1 month ago

leannmlindsey commented 1 month ago

Thank you again for providing such a complete model with so many options in the code base.

I wanted to calculate the MCC values for the genomic benchmark dataset (even though they are only reported in the paper with accuracy).

To do this, I changed the code in

configs > task > multiclass_classification.yaml

# _target_: tasks.tasks.MultiClass
_name_: multiclass
loss: cross_entropy
metrics:
  - accuracy
  - mcc                          # this is the line that I added
torchmetrics: null

I just added `- mcc` to the metrics list.

This change results in very strange MCC values that do not correspond to the accuracy values. The train/mcc and val/mcc appear correct, but the test/mcc is very low (approximately zero, which is equivalent to random guessing on a binary dataset).

Using the demo_human_or_worm dataset, where the accuracy is around 96%, the run reports a test MCC of 0.009.

Since the test dataset is 50% human and 50% worm (i.e., not an imbalanced dataset), it is mathematically impossible to obtain an accuracy of ~96% and an MCC of 0.009.
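To illustrate why (a worked example with hypothetical error counts, not the actual confusion matrix from this run): on a balanced binary test set of 1,000 examples per class, 96% accuracy with errors split evenly across classes forces the MCC up to about 0.92.

```python
import math

def mcc(tp, tn, fp, fn):
    """Matthews correlation coefficient from a binary confusion matrix."""
    num = tp * tn - fp * fn
    den = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return num / den if den else 0.0

# Hypothetical balanced test set: 1000 human + 1000 worm examples,
# 96% accuracy with the 80 errors split evenly between the two classes.
print(mcc(tp=960, tn=960, fp=40, fn=40))  # 0.92
```

Even in the worst case, where all errors fall on one class, a balanced set at 96% accuracy cannot produce an MCC anywhere near zero.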

Perhaps I have misinterpreted how to calculate the MCC using your code. Can you give any advice?

Note: I was able to reproduce this error with a clean downloaded github repo with only the one line change described above.

Logging Output

Epoch 7: 100%|██████████| 293/293 [00:18<00:00, 15.53it/s, loss=0.0648, v_num=ed-1, val/accuracy=0.967, val/mcc=0.934, val/loss=0.0928, train/accuracy=0.975, train/mcc=0.949, train/loss=0.067]
Epoch 7, global step 2104: 'val/accuracy' reached 0.96720 (best 0.96720), saving model to 'checkpoints/val/accuracy.ckpt' as top 1
Epoch 8: 100%|██████████| 293/293 [00:18<00:00, 15.61it/s, loss=0.0497, v_num=ed-1, val/accuracy=0.967, val/mcc=0.934, val/loss=0.0965, train/accuracy=0.979, train/mcc=0.958, train/loss=0.0565]
Epoch 8, global step 2367: 'val/accuracy' was not in top 1
Epoch 9: 100%|██████████| 293/293 [00:18<00:00, 15.64it/s, loss=0.0438, v_num=ed-1, val/accuracy=0.968, val/mcc=0.936, val/loss=0.0915, train/accuracy=0.983, train/mcc=0.965, train/loss=0.0479]
Epoch 9, global step 2630: 'val/accuracy' reached 0.96787 (best 0.96787), saving model to 'checkpoints/val/accuracy.ckpt' as top 1
Trainer.fit stopped: max_epochs=10 reached.
Epoch 9: 100%|██████████| 293/293 [00:18<00:00, 15.54it/s, loss=0.0438, v_num=ed-1, val/accuracy=0.968, val/mcc=0.936, val/loss=0.0915, train/accuracy=0.983, train/mcc=0.965, train/loss=0.0479]
[2024-05-17 07:34:12,783][main][INFO] - Loaded best validation checkpoint from epoch 9
Restoring states from the checkpoint path at checkpoints/val/accuracy.ckpt
Using Char-level tokenizer
already downloaded train-demo_human_or_worm
already downloaded test-demo_human_or_worm
[2024-05-17 07:34:26,509][main][INFO] - Custom load_state_dict function is running.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at checkpoints/val/accuracy.ckpt
Validation DataLoader 0: 100%|██████████| 98/98 [00:02<00:00, 43.37it/s]

Validate metric    DataLoader 0
test/accuracy      0.9664400219917297
test/loss          0.0914219319820404
test/mcc           0.00956359408424299

leannmlindsey commented 1 month ago

I have debugged the issue by inspecting the y and y_hat values being passed into the accuracy and MCC functions, to verify that both metrics receive the same inputs, and they do. The issue arises only on the Genomic Benchmark (GB) because its file format differs from that of the NT benchmark. The GB files are stored like this:

(base) [u1323098@notchpeak1:demo_human_or_worm]$ tree -L 2
.
├── test
│   ├── human
│   └── worm
└── train
    ├── human
    └── worm

When the dataloader reads in the data, the datasets are not shuffled; they are simply concatenated in order, so all of the "human" (0) labels come first, followed by all of the "worm" (1) labels.

This turns out not to affect the MCC for train and val, because somewhere in a PyTorch library or your code the examples are shuffled as each batch is selected. However, when the test dataset is fed in, every batch except one (the turnover batch spanning the boundary between the 0s and the 1s) contains only a single label, either all 0 or all 1. For each of these single-label batches, the MCC is computed as 0. The one batch containing both labels receives an accurate MCC (around 0.88 in the cases I tested). The per-batch MCC values are then averaged to produce the final test/mcc score; while per-batch averaging may be acceptable for accuracy, it does not work for MCC (and likely not for F1 either), because MCC degenerates to 0 on single-class batches.
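The effect is easy to reproduce outside the repo. Below is a minimal sketch with toy labels and a hand-rolled binary MCC (not the project's actual metric code); the batch boundaries here align exactly with the class boundary, so the per-batch average is exactly 0 rather than the ~0.009 seen above, where one turnover batch mixed both labels.

```python
import math

def mcc_binary(y_true, y_pred):
    """Binary MCC from label lists; returns 0.0 on degenerate (single-class) input."""
    tp = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))
    tn = sum(t == 0 and p == 0 for t, p in zip(y_true, y_pred))
    fp = sum(t == 0 and p == 1 for t, p in zip(y_true, y_pred))
    fn = sum(t == 1 and p == 0 for t, p in zip(y_true, y_pred))
    den = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return (tp * tn - fp * fn) / den if den else 0.0

# Toy test set stored like the GB files: all label-0 examples, then all label-1.
y_true = [0] * 512 + [1] * 512
y_pred = list(y_true)
y_pred[:10] = [1] * 10    # a near-perfect classifier: 10 errors per class
y_pred[-10:] = [0] * 10

batch = 128
per_batch = [mcc_binary(y_true[i:i + batch], y_pred[i:i + batch])
             for i in range(0, len(y_true), batch)]

print(sum(per_batch) / len(per_batch))  # 0.0: every batch holds a single class
print(mcc_binary(y_true, y_pred))       # 0.9609375 on the full test set
```

The classifier is nearly perfect, yet the batch-averaged MCC reports it as no better than random guessing.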

In general, I think it is not advisable to average test scores per batch; the scores should instead be calculated over the entire test set. This could also affect the MCC results reported in the paper, though since the NT dataset is mostly shuffled and balanced, the change may be minimal. I have not yet checked that.

For my purposes, I am going to just do all of my calculations for test metrics on the raw results of the test set.
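Concretely, one way to do that is the accumulate-then-compute pattern: collect every batch's labels and predictions, then compute the metrics once over the pooled arrays. (A sketch with hypothetical names, not the repo's code; stateful torchmetrics objects used via update()/compute() achieve the same cross-batch aggregation.)

```python
import math

class TestMetricAccumulator:
    """Hypothetical helper: pool labels/predictions across test batches,
    then compute metrics once over the full test set."""
    def __init__(self):
        self.y_true, self.y_pred = [], []

    def update(self, y, y_hat):
        # Called once per test batch with that batch's labels/predictions.
        self.y_true.extend(y)
        self.y_pred.extend(y_hat)

    def accuracy(self):
        hits = sum(t == p for t, p in zip(self.y_true, self.y_pred))
        return hits / len(self.y_true)

    def mcc(self):
        tp = sum(t == 1 and p == 1 for t, p in zip(self.y_true, self.y_pred))
        tn = sum(t == 0 and p == 0 for t, p in zip(self.y_true, self.y_pred))
        fp = sum(t == 0 and p == 1 for t, p in zip(self.y_true, self.y_pred))
        fn = sum(t == 1 and p == 0 for t, p in zip(self.y_true, self.y_pred))
        den = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
        return (tp * tn - fp * fn) / den if den else 0.0

# Feed in batches as they come out of the test loop, compute at the end.
acc = TestMetricAccumulator()
acc.update([0, 0, 1, 1], [0, 1, 1, 1])
acc.update([0, 0, 1, 1], [0, 0, 1, 1])
print(acc.accuracy())  # 0.875
print(acc.mcc())
```

Because the confusion-matrix counts are pooled before the MCC is computed, single-class batches no longer drag the score toward zero.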

Sorry this was so long. Let me know if you have any questions about how I came to this conclusion or if you think this may be incorrect.

yair-schiff commented 1 month ago

Thank you for the detailed analysis. Glad you were able to debug this issue.