Project-MONAI / tutorials

MONAI Tutorials
https://monai.io/started.html
Apache License 2.0

ViT for 3D brain images #464

Closed: Meddebma closed this issue 2 years ago

Meddebma commented 2 years ago

Dear all,

I tried a classification task on 3D brain images, where I got a Dice score > 0.8 using DenseNet121, and I wanted to test whether I could improve on it using ViT. When I tried the ViT implementation, I got this error: cross_entropy_loss(): argument 'input' (position 1) must be Tensor, not tuple

Could you please take a look at this notebook: https://github.com/Meddebma/pyradiomics/blob/master/Classification_HBI_ViT.ipynb

Thank you very much

Nic-Ma commented 2 years ago

Hi @Meddebma ,

Thanks for your experiments and feedback here. I think the root cause is that ViT returns a tuple of two items (the classification output and the hidden states) instead of a single tensor: https://github.com/Project-MONAI/MONAI/blob/dev/monai/networks/nets/vit.py#L112

Hi @finalelement @yiheng-wang-nv @wyli, I plan to slightly enhance the forward() logic of ViT to make returning hidden_states_out optional, similar to our DynUNet implementation: https://github.com/Project-MONAI/MONAI/blob/dev/monai/networks/nets/dynunet.py#L243 I remember there may have been some TorchScript and DistributedDataParallel issues in DynUNet caused by its variable return type. Do you have any concerns?

Thanks.
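Until forward() makes the hidden states optional, one way to give callers (loss functions, metrics) a fixed, single-tensor return type is to wrap the model. This is a minimal sketch in plain PyTorch, not MONAI's planned change: `FirstOutputOnly` and `ToyTupleNet` are hypothetical names, and `ToyTupleNet` only mimics the (output, hidden_states) return shape of ViT. The wrapper does not by itself address the TorchScript or DistributedDataParallel concerns mentioned above.

```python
import torch
import torch.nn as nn


class FirstOutputOnly(nn.Module):
    """Hypothetical wrapper: forward() returns only element 0 of a (output, hidden_states) tuple."""

    def __init__(self, net: nn.Module) -> None:
        super().__init__()
        self.net = net

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.net(x)
        # Pass plain tensors through unchanged; unwrap tuples/lists to their first element
        return out[0] if isinstance(out, (tuple, list)) else out


class ToyTupleNet(nn.Module):
    """Hypothetical stand-in for a ViT-like model that returns (logits, hidden_states)."""

    def __init__(self, in_features: int = 8, num_classes: int = 2) -> None:
        super().__init__()
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x: torch.Tensor):
        return self.fc(x), [x]


model = FirstOutputOnly(ToyTupleNet())
logits = model(torch.randn(3, 8))
print(logits.shape)  # torch.Size([3, 2])
```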

finalelement commented 2 years ago

@Meddebma I went through the notebook. This should also work if you pass outputs[0] to the loss instead of outputs, since the ViT returns both the classification output and the hidden states, i.e. change loss = loss_function(outputs, labels) to loss = loss_function(outputs[0], labels).

In short, I agree with @Nic-Ma that we could make the hidden states optional for easier use.
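The suggested fix amounts to indexing the tuple before calling the loss. Below is a minimal training-step sketch; `ToyTupleNet` is a hypothetical stand-in that returns (logits, hidden_states) the way ViT does here, and the batch and labels are made up. With a real MONAI ViT you would keep the same `outputs[0]` indexing.

```python
import torch
import torch.nn as nn


class ToyTupleNet(nn.Module):
    """Hypothetical stand-in for a ViT-like model whose forward() returns (logits, hidden_states)."""

    def __init__(self, in_features: int = 8, num_classes: int = 2) -> None:
        super().__init__()
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x: torch.Tensor):
        return self.fc(x), [x]


torch.manual_seed(0)
model = ToyTupleNet()
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

inputs = torch.randn(4, 8)           # hypothetical batch
labels = torch.tensor([0, 1, 1, 0])  # integer class indices

optimizer.zero_grad()
outputs = model(inputs)                   # a tuple, not a Tensor
loss = loss_function(outputs[0], labels)  # pass only the classification logits
loss.backward()
optimizer.step()
```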

Meddebma commented 2 years ago

Hi @finalelement, thanks for your help. I implemented your solution, but I got the same problem in validation after the second epoch; also, the loss looked weird since it did not change. Here is the notebook:

https://github.com/Meddebma/pyradiomics/blob/master/Classification_HBI_ViT.ipynb

thanks

Nic-Ma commented 2 years ago

@finalelement , let me try to fix the TorchScript and distributed data-parallel issues and add support to optionally return the hidden_states_out ASAP.

Thanks.

Nic-Ma commented 2 years ago

Submitted draft PR for this ticket: https://github.com/Project-MONAI/MONAI/pull/3428.

Thanks.

finalelement commented 2 years ago

@Meddebma you ran into that issue because the model's prediction in validation also needs to be indexed, i.e. model(val_images)[0]:

y_pred = torch.cat([y_pred, model(val_images)[0]], dim=0)
y = torch.cat([y, val_labels], dim=0)

I would suggest fixing the validation: it does not extract the prediction the same way the training loop does. This should hopefully resolve the validation issue in the second epoch. :)
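A validation loop that extracts predictions exactly as in training might look like the sketch below. `ToyTupleNet` and the list-based `val_loader` are hypothetical stand-ins (the real notebook uses a MONAI ViT and a DataLoader); collecting batches in lists and concatenating once at the end is a design choice that avoids repeatedly growing a tensor.

```python
import torch
import torch.nn as nn


class ToyTupleNet(nn.Module):
    """Hypothetical stand-in for a ViT-like model whose forward() returns (logits, hidden_states)."""

    def __init__(self, in_features: int = 8, num_classes: int = 2) -> None:
        super().__init__()
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x: torch.Tensor):
        return self.fc(x), [x]


torch.manual_seed(0)
model = ToyTupleNet()
# Hypothetical validation "loader": three batches of (images, labels)
val_loader = [(torch.randn(2, 8), torch.randint(0, 2, (2,))) for _ in range(3)]

model.eval()
pred_batches, label_batches = [], []
with torch.no_grad():
    for val_images, val_labels in val_loader:
        outputs = model(val_images)
        # Same extraction as in training: element 0 of the tuple holds the logits
        logits = outputs[0] if isinstance(outputs, tuple) else outputs
        pred_batches.append(logits)
        label_batches.append(val_labels)

y_pred = torch.cat(pred_batches, dim=0)  # shape (num_samples, num_classes)
y = torch.cat(label_batches, dim=0)      # shape (num_samples,)
acc = (y_pred.argmax(dim=1) == y).float().mean().item()
```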

Meddebma commented 2 years ago

Thank you very much @finalelement, this time it worked.

Unfortunately, something is not working in the classification: I get an accuracy of 0.5 that does not improve over the epochs. Is it a problem specific to my 3D dataset?

https://github.com/Meddebma/pyradiomics/blob/master/Classification_HBI_ViT.ipynb

finalelement commented 2 years ago

@Meddebma Have you tried this exact same data loading and post-processing with a smaller network architecture, or any other network, to make sure all the other components are working before concluding that something is wrong with ViT?

Meddebma commented 2 years ago

Hi @finalelement, I have tried exactly the same pipeline with DenseNet121 and got an accuracy of 0.83 and an AUC of 0.97; I only changed the model. Here is the notebook: https://github.com/Meddebma/pyradiomics/blob/master/Classification_HBI.ipynb

LouiseBloch commented 2 years ago

I have the same problem with my 3D classification task and the ViT, and no solution so far. However, if I train the pipeline using DenseNet121 or EfficientNet-B0, it works.

Nic-Ma commented 2 years ago

Hi @finalelement ,

Could you please share more details to help figure out the issue?

Thanks in advance.

finalelement commented 2 years ago

@LouiseBloch @Meddebma are you specifically facing this issue with a 3D classification task, or is it reproducible in 2D as well?

I just verified the test cases; they look okay to me. Perhaps we could test this network on CIFAR first to narrow down the issue before proceeding with this task?

ahatamiz commented 2 years ago

@Nic-Ma This seems to be a very straightforward issue to deal with. The output of ViT is a tuple, and its first element needs to be used for classification. This was done by design, since we make use of the intermediate feature maps in UNETR: https://github.com/Project-MONAI/MONAI/blob/dev/monai/networks/nets/unetr.py#L191

Furthermore, a proper classification loss needs to be used.

Maybe I should write a tutorial on ViT classification to further demonstrate the concept.
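On the point about a proper classification loss: PyTorch's nn.CrossEntropyLoss expects raw logits of shape (N, C) plus integer class indices, and applies log-softmax internally. The pure-Python sketch below computes the same per-sample quantity (it is an illustration, not PyTorch's code) and shows one common pitfall: feeding already-softmaxed probabilities gives a larger, flatter loss that, for two classes, can never fall below log(1 + e^-1).

```python
import math
from typing import List, Sequence


def cross_entropy_from_logits(logits: Sequence[float], target: int) -> float:
    """Per-sample cross-entropy from raw logits: logsumexp(logits) - logits[target]."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, -1.0]
print(cross_entropy_from_logits(logits, 0))           # correct: raw logits
print(cross_entropy_from_logits(softmax(logits), 0))  # pitfall: softmax applied twice
```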

Nic-Ma commented 2 years ago

Hi @ahatamiz ,

Thanks for your update here. BTW, I am working on a draft PR to enhance the ViT return data; we need to make sure TorchScript and DistributedDataParallel work well with it.

Thanks.

ahatamiz commented 2 years ago

Hi @Nic-Ma

Thanks. Currently, I am looking into support for TorchScript and einops. I don't anticipate any issues with those.

Also, please let me know how you plan to change the ViT outputs. UNETR would also need some changes if that happens.

Meddebma commented 2 years ago

@Nic-Ma @ahatamiz, thanks a lot for your help. We would be very glad if you could provide a 3D ViT tutorial, maybe also with a Grad-CAM implementation to explore visual attention and get explainable classification from images. Thank you very much.

LouiseBloch commented 2 years ago

Thank you very much. I think I have now solved the problem by decreasing the learning rate. Sorry for not checking this earlier.

Meddebma commented 2 years ago

@LouiseBloch, can you share your code changes with me? I would be very glad! Here is my email address: aymenmeddeb@gmail.com Thanks!

bax24 commented 2 years ago

Hello, I am getting the same issue with a binary classification problem. Did anyone figure out how to solve it?

My loss remains more or less the same throughout the epochs.

Here is my GitHub: https://github.com/bax24/Alzheimer-Classification

And here is the latest graph I got from TensorBoard by running ViT on a batch of 30 MRI scans:

[Image: TensorBoard plot]

Nic-Ma commented 2 years ago

Hi @finalelement ,

Could you please share some advice on this question?

Thanks in advance.

bax24 commented 2 years ago

I forgot to mention that the pipeline works with the 3D CNN networks in my GitHub repository; basically, only the model is changed.

bax24 commented 2 years ago

I just trained the ViT on a single batch to see whether the model learns anything, and its loss does not change at all. So I think the problem might be very simple. However, I am passing the parameters to the Adam optimizer and calling loss.backward() as well as optimizer.step() (and again, the same pipeline works fine with a CNN), but the loss remains constant.

[Image: vitbug]

I would be very grateful for your help; thanks in advance.
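The single-batch test described above is a standard sanity check: a model that cannot drive the loss down on one fixed batch usually has a wiring problem rather than a data problem. Below is a minimal sketch with a small MLP and a made-up batch (both hypothetical); it also records the gradient norm after the first backward pass, since a zero norm would point at a broken graph. To check a ViT instead, swap in the model and use `model(x)[0]` as the logits.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy batch standing in for one batch of images: 8 samples, 10 features
x = torch.randn(8, 10)
y = torch.randint(0, 2, (8,))

model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_function = nn.CrossEntropyLoss()

losses = []
for step in range(200):
    optimizer.zero_grad()
    loss = loss_function(model(x), y)
    loss.backward()
    if step == 0:
        # Sum of per-parameter gradient norms; zero here means no gradient reaches the weights
        grad_norm = sum(p.grad.norm().item() for p in model.parameters() if p.grad is not None)
    optimizer.step()
    losses.append(loss.item())

print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```

If the loss stays flat here with a real model, the usual suspects are parameters missing from the optimizer, outputs detached from the graph, or an output layer that squashes the logits.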

finalelement commented 2 years ago

@bax24 thanks for sharing the results. Can you try your code with a minor change in ViT: remove the tanh in this line of your local MONAI installation (just edit the file directly and remove the tanh)? We found that in some instances this resolves the issue.

https://github.com/Project-MONAI/MONAI/blob/953c7d4e30308cb6225911760772ec76ff71a063/monai/networks/nets/vit.py#L100
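One plausible reason why removing the tanh helps: if tanh is applied to the classification head's output, every logit is confined to [-1, 1], so softmax can never assign a probability near 1 and cross-entropy has a hard floor (about 0.127 for two classes); tanh's gradient also shrinks as its input grows. The pure-Python calculation below computes that floor; it is an illustration of the bound, not a statement about MONAI's code beyond the tanh described above.

```python
import math


def min_cross_entropy_with_tanh_head(num_classes: int) -> float:
    """Smallest achievable cross-entropy when every logit is squashed into [-1, 1] by tanh.

    The best case puts +1 on the true class and -1 on every other class.
    """
    true_logit, other_logit = 1.0, -1.0
    # Softmax probability of the true class in that best case
    p_true = math.exp(true_logit) / (math.exp(true_logit) + (num_classes - 1) * math.exp(other_logit))
    return -math.log(p_true)


for k in (2, 3, 5):
    print(k, round(min_cross_entropy_with_tanh_head(k), 4))
```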

bax24 commented 2 years ago

Thanks a lot for your answer. Yes, it does indeed solve the problem: the loss is now changing and decreasing!

bax24 commented 2 years ago

Hello,

As mentioned in my previous message, the ViT is learning very quickly and quite well on the training batches.

[Images: training loss and accuracy per batch (screenshots, 2022-07-31)]

Above are the training loss and accuracy per batch.

However, the ViT seems to have trouble generalizing to new input data (below is the testing loss):

[Image: testing loss (screenshot, 2022-07-31)]

I use the same methods and hyperparameters as both the UNet paper and the Vision Transformer paper, which classified 2D images just as well as a CNN.

Do you have any suggestions concerning this issue? (My GitHub: https://github.com/bax24/Alzheimer-Classification)

Thank you for your consideration

finalelement commented 2 years ago

Hi @bax24, thanks for reaching out. I noticed the learning rate is quite high here, either 1e-2 or 1e-3: https://github.com/bax24/Alzheimer-Classification/blob/9e8c62588831208305119466a0fca13e8d3b5c89/main.py#L158

I would suggest experimenting with it and trying smaller values, in the range of 1e-4 to 1e-5, for smoother learning. If possible, run a grid search for the best hyperparameters.

You could also consider studying the trade-off between learning rate and batch size. Every problem/task/dataset is unique, and hence the hyperparameters need to be adjusted to the task at hand.
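The grid search suggested above can be as simple as trying every (learning rate, batch size) pair and keeping the best validation score. In this minimal sketch, `fake_train_and_evaluate` is a hypothetical placeholder with a made-up scoring rule; in practice it would run a full training job and return a validation metric such as AUC or accuracy.

```python
import itertools
from typing import Callable, Dict, Iterable, Tuple


def grid_search(
    train_and_evaluate: Callable[[float, int], float],
    learning_rates: Iterable[float],
    batch_sizes: Iterable[int],
) -> Tuple[Tuple[float, int], Dict[Tuple[float, int], float]]:
    """Try every (learning rate, batch size) pair; return the best pair and all scores."""
    scores = {}
    for lr, bs in itertools.product(learning_rates, batch_sizes):
        scores[(lr, bs)] = train_and_evaluate(lr, bs)  # higher is better, e.g. validation AUC
    best = max(scores, key=scores.get)
    return best, scores


def fake_train_and_evaluate(lr: float, batch_size: int) -> float:
    # Hypothetical placeholder: replace with a real training run returning a validation metric
    return -abs(lr - 1e-4) * 1e3 - abs(batch_size - 8) * 1e-3


best, scores = grid_search(fake_train_and_evaluate, [1e-3, 1e-4, 1e-5], [4, 8, 16])
print(best)
```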

bax24 commented 2 years ago

I tried several learning rates and batch sizes, and the same patterns arise each time:

[Image: test ViT (testing-loss plots for several runs)]

So the testing loss goes down and stays constant as long as the training loss is constant and the model is not learning much (just like the first three graphs in the image). Then, if we keep going, at some point the training loss starts decreasing. As soon as that happens, the testing loss jumps and keeps rising without stopping (cf. the last two plots, or the first one on the second row).
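A testing loss that rises while the training loss falls is the classic signature of overfitting. One common mitigation (alongside augmentation and regularization, which are not shown) is early stopping: monitor the validation loss and stop once it has not improved for a set number of epochs, keeping the best checkpoint. This is a minimal pure-Python sketch; the loss values are made up for illustration, not taken from the plots above.

```python
class EarlyStopping:
    """Stop training when the validation loss has not improved for `patience` epochs."""

    def __init__(self, patience: int = 5, min_delta: float = 0.0) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.best_epoch = -1
        self.bad_epochs = 0

    def step(self, epoch: int, val_loss: float) -> bool:
        """Record one epoch; return True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.best_epoch = epoch  # in a real loop, save the model weights here
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


val_losses = [0.70, 0.66, 0.65, 0.67, 0.72, 0.80, 0.91, 1.05]  # made-up curve that starts rising
stopper = EarlyStopping(patience=3)
for epoch, loss in enumerate(val_losses):
    if stopper.step(epoch, loss):
        break
print(f"stopped at epoch {epoch}, best epoch {stopper.best_epoch}")
```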

Also, the training loss converges much faster with a smaller learning rate (1e-4 or 1e-5) than with a high one; I don't know why. Here are the training plots for the same runs:

[Image: trainViT (training-loss plots for the same runs)]

The first column is with a low lr, and most of the rest is with a higher lr. (I guess the algorithm is escaping a local minimum with a higher lr, which means it takes more time to converge, but I'm not sure; it still converges if I keep the epochs running.)

I have tried a lot of lr/batch-size combinations, and the same patterns appear each time. Do you think I should try modifying the ViT parameters? And if so, which ones: mlp_heads, num_layers, mlp_dim, hidden size, ...?

Thanks again for your help, I appreciate it!
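Two quick calculations are worth doing before changing ViT hyperparameters. First, the number of patch tokens is the product of img_size // patch_size over the three axes, and self-attention cost grows with the square of that count. Second, multi-head attention splits hidden_size evenly across heads, so hidden_size must be divisible by num_heads. MONAI's ViT defaults are, as far as I know, hidden_size=768, mlp_dim=3072, num_layers=12 and num_heads=12 (check your installed version); patch_size has no default. The helper names below are hypothetical, and requiring exact divisibility of the image by the patch is a simplification for clarity.

```python
import math
from typing import Sequence


def vit_token_count(img_size: Sequence[int], patch_size: Sequence[int]) -> int:
    """Number of non-overlapping patch tokens (stride == patch size)."""
    for m, p in zip(img_size, patch_size):
        if m % p != 0:
            raise ValueError(f"image dim {m} is not divisible by patch dim {p}")
    return math.prod(m // p for m, p in zip(img_size, patch_size))


def head_dim(hidden_size: int, num_heads: int) -> int:
    """Per-head embedding size; hidden_size must split evenly across the heads."""
    if hidden_size % num_heads != 0:
        raise ValueError("hidden_size must be divisible by num_heads")
    return hidden_size // num_heads


print(vit_token_count((128, 128, 128), (16, 16, 16)))  # 8 * 8 * 8 = 512 tokens
print(head_dim(768, 12))                               # 64 dimensions per head
```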

lincong8722 commented 2 years ago

Hello, I tried to use the ViT model to train on my 3D data, but whether I removed the last Tanh() or switched the learning rate to between 1e-4 and 1e-5, the validation accuracy has remained at 0.5, and the loss does not change for roughly the first 10 epochs. What is the cause of this?

bibhabasumohapatra commented 1 year ago

Same issue... Is there any solution, and what is the reason for this?

Cherished-l commented 1 year ago

I have the same issue. Did you solve it?

Nic-Ma commented 1 year ago

Hi @finalelement ,

Could you please share some suggestions for this issue?

Thanks in advance.

Cherished-l commented 1 year ago

I would like to share more information to help: my task is to use 3D brain T1 images for Alzheimer's disease classification, brain age prediction and other problems. The data come from the UKB database, with a total of 30,000 samples. UKB has already processed the data thoroughly, so data problems can be ruled out. Using MONAI's data loading, each T1 volume is cropped from (182, 218, 182) to (182, 182, 182) and then resized to (128, 128, 128).

Using the default MONAI ViT parameters, the loss and accuracy/MAE do not decrease. Note that the number of input channels is changed to 1.

finalelement commented 1 year ago

@Cherished-l it is hard to say without looking at how you have implemented the classification problem. Something simple might be missing there. Typically, removing the tanh or adjusting the learning rate are the only changes we have found to tackle this problem.

SholsNig commented 11 months ago

@finalelement, @bax24 I want to ask whether you were able to resolve the issue of the loss not declining as expected. Mine is just flat after 2 epochs. Please let me know if you were able to fix it and what you did.

KumoLiu commented 10 months ago

Hi @SholsNig, did you try to remove the tanh or decrease the learning rate?

SholsNig commented 10 months ago

Hello @KumoLiu, thank you for your advice. I tried both; please see my observations below.

  1. Removing the Tanh: no change, accuracy still around 50%.
  2. Changing the learning rate: the average loss value changed, but accuracy is still at 50%.

I would be grateful for any other suggestions I can try out. Also, on my end, I want to see whether regenerating the 3D images to have the same size will help. The original shape was 173 (x) by 211 (y) by 155 (z); I resized it to 208 in all directions.

KumoLiu commented 10 months ago

Hi @SholsNig, from my experience, I would suggest using ResizeWithPadOrCrop to bring images to a reasonable size, e.g., 160, which avoids distorting their proportions. BTW, what learning rate did you use?

Hi @finalelement, do you have other suggestions for adjusting the hyperparameters to help with this issue? Thanks in advance!
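The idea behind pad-or-crop resizing is that each axis is center-cropped if too large and zero-padded if too small, so the voxel spacing is unchanged and the anatomy is not stretched, unlike interpolating a 173 x 211 x 155 volume to a 208^3 cube. MONAI provides this as the ResizeWithPadOrCrop transform (its spatial_size argument sets the target); the NumPy function below is only a sketch of the idea, not MONAI's implementation, and its exact center offsets may differ from MONAI's by a voxel.

```python
from typing import Sequence

import numpy as np


def center_pad_or_crop(volume: np.ndarray, target: Sequence[int]) -> np.ndarray:
    """Center-crop axes larger than the target and zero-pad axes smaller than it."""
    # Crop: keep a centered window of at most `t` voxels on each axis
    slices = []
    for size, t in zip(volume.shape, target):
        start = max((size - t) // 2, 0)
        slices.append(slice(start, start + min(size, t)))
    cropped = volume[tuple(slices)]
    # Pad: split any shortfall between the two sides of each axis
    pad = []
    for size, t in zip(cropped.shape, target):
        total = max(t - size, 0)
        pad.append((total // 2, total - total // 2))
    return np.pad(cropped, pad, mode="constant", constant_values=0)


volume = np.random.rand(173, 211, 155).astype(np.float32)  # shape reported in this thread
resized = center_pad_or_crop(volume, (160, 160, 160))
print(resized.shape)  # (160, 160, 160)
```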

SholsNig commented 10 months ago

Thank you very much, @KumoLiu. I used ResizeWithPadOrCrop in the transforms, but it did not help. The epoch loss trend looks reasonable, but the accuracy just hangs at 50.34%.

On the learning rate, I tried 1e-5, 1e-6, 1e-7.

[Images: unnamed, unnamed (1)]

finalelement commented 10 months ago

@SholsNig have you verified that the classification outputs are valid probabilities on a 0 to 1 scale before estimating accuracy?

If yes, what activation function are you using at the end of the ViT backbone to normalize the outputs into probabilities?
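A quick diagnostic along these lines: convert logits to probabilities (softmax over the class logits; for a single output logit, sigmoid would play this role), check they lie in [0, 1] and sum to 1, and count how often each class is predicted. An accuracy pinned near 0.5 on a roughly balanced binary set is consistent with the model predicting the same class for every input. The sketch below is pure Python with made-up "collapsed" logits; note that argmax gives the same class whether or not softmax is applied, so accuracy itself does not depend on the final activation.

```python
import math
from collections import Counter
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Turn raw logits into probabilities that lie in [0, 1] and sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values: Sequence[float]) -> int:
    return max(range(len(values)), key=lambda i: values[i])


def accuracy(batch_logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Fraction of rows whose argmax matches the label."""
    preds = [argmax(row) for row in batch_logits]
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)


labels = [0, 1, 0, 1, 0, 1]             # balanced binary labels
collapsed = [[0.3, 0.1]] * len(labels)  # made-up logits: same output for every input
probs = [softmax(row) for row in collapsed]
pred_counts = Counter(argmax(row) for row in collapsed)
print(pred_counts)                  # only class 0 is ever predicted
print(accuracy(collapsed, labels))  # 0.5
```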

SholsNig commented 10 months ago

@finalelement I will check that and get back to you. Thank you.