jeonsworld / ViT-pytorch

Pytorch reimplementation of the Vision Transformer (An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale)
MIT License
1.92k stars 364 forks

Why does the model give the same logits for both classes? #42

Open evapachetti opened 2 years ago

evapachetti commented 2 years ago

Hi, I am using the pre-trained ViT-H_14 to perform binary classification of biomedical images. The dataset I have available is very small: about 300 images for fine-tuning and about 30 images for validation. The goal is to classify the images by the aggressiveness of the tumor shown (low grade (0) vs. high grade (1)). However, I noticed that at prediction time every image is assigned the label 0. Looking at the logits, I found that the model always produces the same pair (e.g. [[ 6.877057e-10 -6.877057e-10]]), which translates into a probability pair of about (0.49, 0.51).
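To make the logit-to-probability step concrete, here is a minimal sketch that applies a standard softmax to the logit pair quoted above. It assumes the prediction is the argmax of the softmax output, which the post does not state explicitly. Under that assumption, logits this close to zero give probabilities of essentially 0.5 each, and class 0 wins only because its logit is marginally larger.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    largest = max(logits)
    exps = [math.exp(x - largest) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Logit pair quoted in the post above.
logits = [6.877057e-10, -6.877057e-10]
probs = softmax(logits)

# Index of the largest probability (assumed prediction rule: argmax).
predicted = max(range(len(probs)), key=probs.__getitem__)

print("probabilities:", probs)
print("predicted label:", predicted)
```

Logits of this size indicate that the classification head output is almost exactly zero for every input, so the predicted label carries no information about the image.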

Searching various forums, I found many different tips: vary the learning rate (which I decreased to 1e-8), decrease the batch size (from 8 to 2), and so on. Unfortunately, none of this works. The last thing I want to try is to considerably increase the number of epochs (so far I have trained for only 100), but before doing so I wanted to ask whether someone has a more specific suggestion, or can tell me whether this architecture is simply too large for such a small dataset.
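One possible contributing factor, not confirmed anywhere in this thread, is the learning rate of 1e-8 itself. The sketch below assumes plain SGD and float32 weights, and uses a hypothetical weight of 0.02 with a hypothetical gradient of 0.01. Under those assumptions the update is smaller than the spacing between adjacent float32 values, so the stored weight does not change at all.

```python
import struct


def to_float32(x):
    """Round a Python float to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def sgd_step_float32(weight, grad, lr):
    """One plain SGD update, with the result stored as float32."""
    update = to_float32(lr * grad)
    return to_float32(to_float32(weight) - update)


# Hypothetical values, chosen only for illustration.
weight = 0.02
grad = 0.01

before = to_float32(weight)
after_tiny_lr = sgd_step_float32(weight, grad, lr=1e-8)
after_larger_lr = sgd_step_float32(weight, grad, lr=1e-3)

print("weight changed with lr=1e-8:", after_tiny_lr != before)
print("weight changed with lr=1e-3:", after_larger_lr != before)
```

Real gradients and optimizers (momentum, mixed precision) behave differently in detail, so this only shows that a learning rate this small can leave the weights effectively frozen.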

Thanks a lot in advance

Bolin-Chen1 commented 4 months ago

Hello, I ran into the same problem. Have you found out how to solve it? I have also tried many things, but none of them worked.

evapachetti commented 4 months ago

Hello, quite some time has passed, so I am not completely sure what the issue was. However, I am fairly confident that the main reason this happens is that the model is not actually learning anything. In my case, the dataset I had for fine-tuning was very small compared to the size of the model. I resolved this by drastically reducing the number of model parameters and training the model from scratch on my dataset alone. Hope this helps.
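To give a sense of the scale mismatch described above, here is a rough sketch that estimates the parameter count of a ViT from its hyperparameters. The ViT-H/14 settings (hidden size 1280, 32 layers, MLP size 5120) follow the ViT paper; the 224x224 input size, the two-class head, and the small configuration are assumptions for illustration only, and the small configuration is not the one the commenter actually used.

```python
def vit_param_count(hidden, layers, mlp_dim, patch,
                    image=224, channels=3, classes=2):
    """Approximate trainable parameter count of a standard ViT classifier."""
    tokens = (image // patch) ** 2 + 1            # patches plus class token
    patch_embed = patch * patch * channels * hidden + hidden
    pos_and_cls = tokens * hidden + hidden        # position embeddings + class token
    attention = 4 * (hidden * hidden + hidden)    # query, key, value, output projections
    mlp = hidden * mlp_dim + mlp_dim + mlp_dim * hidden + hidden
    norms = 2 * 2 * hidden                        # two LayerNorms per block
    encoder = layers * (attention + mlp + norms) + 2 * hidden  # plus final LayerNorm
    head = hidden * classes + classes
    return patch_embed + pos_and_cls + encoder + head


# ViT-H/14 hyperparameters from the paper; input size is an assumption.
huge = vit_param_count(hidden=1280, layers=32, mlp_dim=5120, patch=14)

# Hypothetical small configuration, for comparison only.
small = vit_param_count(hidden=192, layers=6, mlp_dim=768, patch=16)

train_images = 300  # fine-tuning set size reported in the thread
print(f"ViT-H/14: {huge:,} parameters, {huge // train_images:,} per training image")
print(f"small ViT: {small:,} parameters, {small // train_images:,} per training image")
```

With roughly 300 training images, ViT-H/14 has millions of parameters per image, which is consistent with the commenter's observation that the dataset was too small for the model.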