Project-MONAI / MONAI

AI Toolkit for Healthcare Imaging
https://monai.io/
Apache License 2.0

Unused cls_token in PatchEmbeddingBlock #3454

Closed night-gale closed 2 years ago

night-gale commented 2 years ago

Describe the bug

When I trained the ViT with torch DistributedDataParallel, torch raised an error during the backward pass and reported:

Parameters which did not receive grad for rank 0: vit.patch_embedding.cls_token

which means that the cls_token did not participate in the backward process.

I checked the implementation of ViT and PatchEmbeddingBlock and found an unused cls_token in monai.networks.blocks.patchembedding.py: PatchEmbeddingBlock.
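
The symptom can be checked without DistributedDataParallel. The following is a minimal sketch, not MONAI's code: ToyPatchEmbedding and params_without_grad are hypothetical names made up for illustration. The toy block registers a cls_token parameter that forward never touches, runs one backward pass, and lists the parameters whose .grad is still None, which is the condition the error above reports.

```python
from typing import List

import torch
import torch.nn as nn


class ToyPatchEmbedding(nn.Module):
    """Hypothetical stand-in for a patch embedding block (not MONAI's code)."""

    def __init__(self, in_dim: int = 16, hidden: int = 8) -> None:
        super().__init__()
        self.proj = nn.Linear(in_dim, hidden)
        # Registered as a parameter but never referenced in forward().
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


def params_without_grad(model: nn.Module, example: torch.Tensor) -> List[str]:
    """Run one forward/backward pass and list parameters that got no gradient."""
    model.zero_grad(set_to_none=True)
    model(example).sum().backward()
    return [
        name
        for name, p in model.named_parameters()
        if p.requires_grad and p.grad is None
    ]


unused = params_without_grad(ToyPatchEmbedding(), torch.randn(2, 4, 16))
print(unused)  # ['cls_token']
```

Running the same helper on a real model with a representative input should give a quick list of candidates before launching a distributed job.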

To Reproduce

Steps to reproduce the behavior:

  1. Set the environment variable TORCH_DISTRIBUTED_DEBUG=INFO in the shell.
  2. Train ViT with torch DistributedDataParallel.

Nic-Ma commented 2 years ago

Thanks for raising the issue. Hi @ahatamiz ,

Could you please help confirm the issue? If we really don't need the cls_token, please remove it.

Thanks in advance.

ahatamiz commented 2 years ago

Hi @night-gale

Thanks for your comments. If you use ViT for classification only, then the classification flag needs to be enabled. Doing so enables the use of cls_token, as shown here.

Originally, ViT was used as the segmentation backbone for UNETR, hence the application needs to be specified.

Lastly, cls_token plays an important role in ViT for classification, since the class prediction is read from its output. Hence, removing it would go against the original ViT design. I recommend reading the paper here: https://arxiv.org/pdf/2010.11929.pdf
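
To make the role of the token concrete, here is a rough sketch of how a class token is typically wired into a ViT-style classifier. It is an illustration under assumptions, not MONAI's implementation: ToyViT, its layer sizes, and the single encoder layer are made up for the example, and only the classification flag name is borrowed from the discussion above. The token is prepended to the patch sequence, and the classification head reads the encoder output at position 0.

```python
import torch
import torch.nn as nn


class ToyViT(nn.Module):
    """Hypothetical toy model showing where a class token enters (not MONAI's ViT)."""

    def __init__(self, hidden: int = 8, num_classes: int = 3,
                 classification: bool = True) -> None:
        super().__init__()
        self.classification = classification
        # A single encoder layer stands in for the full transformer stack.
        self.encoder = nn.TransformerEncoderLayer(
            d_model=hidden, nhead=2, dim_feedforward=16, batch_first=True
        )
        if classification:
            # The token is created only when it will actually be used.
            self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden))
            self.head = nn.Linear(hidden, num_classes)

    def forward(self, patches: torch.Tensor) -> torch.Tensor:
        # patches: (batch, num_patches, hidden)
        if self.classification:
            cls = self.cls_token.expand(patches.shape[0], -1, -1)
            patches = torch.cat((cls, patches), dim=1)  # prepend the token
        x = self.encoder(patches)
        if self.classification:
            return self.head(x[:, 0])  # classify from the token's output
        return x  # per-patch features, e.g. for a segmentation decoder


logits = ToyViT(classification=True)(torch.randn(2, 5, 8))
features = ToyViT(classification=False)(torch.randn(2, 5, 8))
print(logits.shape, features.shape)
```

In this sketch the token only exists when classification is enabled, so every registered parameter takes part in the backward pass in both modes.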

Thanks

night-gale commented 2 years ago

Hi! @ahatamiz Thanks for your reply!

I understand that the cls_token is an essential component of ViT and can be toggled off by passing classification as False.

However, the redundant cls_token I found is in the PatchEmbeddingBlock. It is not referenced in the forward method and cannot be turned off by passing an argument.

For now I have removed the cls_token in my local copy of MONAI, and everything works fine.

It would be great if you could double check the implementation of PatchEmbeddingBlock.

Thanks!

ahatamiz commented 2 years ago

Hi @night-gale

Thanks for pointing out the issue. I see that there is an unused cls_token in here. I will address this in a new PR.

Thanks

ahatamiz commented 2 years ago

Hi @Nic-Ma

Thanks for the efforts. I would appreciate it if this could be addressed in a future PR.

Thanks.

Nic-Ma commented 2 years ago

Hi @ahatamiz ,

OK, sure, I will fix it in a PR soon.

Thanks.