arcadelab / FastSAM3D

Code for "FastSAM3D: An Efficient Segment Anything Model for 3D Volumetric Medical Images"
https://arxiv.org/abs/2403.09827
Apache License 2.0

Difference between checkpoint FASTSAM3D and Finetuned-SAMMED3D #6

Open MinxuanQin opened 2 months ago

MinxuanQin commented 2 months ago

Thank you for sharing the excellent code and checkpoints! I have run the code described in Readme.md and would like to check whether I have understood it correctly.

The current versions of distillation.py and validate_student.py use an ImageEncoder with the so-called "woatt" attention (window attention), not the 3D sparse flash attention. validate_student.py loads the tiny image encoder (the first checkpoint uploaded on GitHub) as the image encoder, while the remaining parts come from the fine-tuned teacher model (the second uploaded checkpoint, "Finetuned-SAMMED3D"). Does the third checkpoint, "FASTSAM3D," combine the tiny encoder and the remaining parts into a single model?

I also think those checkpoints do not use build_sam3D_flash.py, build_sam3D_dilatedattention.py, or build_3D_decoder.py. Is that right? Does this checkpoint perform best among all the encoder and decoder variants? Thank you!
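For reference, this is roughly how I would expect the two checkpoints to be combined (a minimal sketch; the file names, the "model_state_dict" key, and the "image_encoder." prefix are my guesses, not taken from the repo):

```python
import torch

# Hypothetical sketch of combining the two uploaded checkpoints;
# the file names and key prefixes below are assumptions, not the repo's actual names.
teacher_ckpt = torch.load("finetuned_sammed3d.pth", map_location="cpu")   # Finetuned-SAMMED3D (assumed name)
student_ckpt = torch.load("fastsam3d_tiny_encoder.pth", map_location="cpu")  # tiny image encoder (assumed name)

state_dict = teacher_ckpt.get("model_state_dict", teacher_ckpt)

# Drop the teacher's image-encoder weights and splice in the distilled tiny encoder.
state_dict = {k: v for k, v in state_dict.items() if not k.startswith("image_encoder.")}
state_dict.update({f"image_encoder.{k}": v for k, v in student_ckpt.items()})

# model.load_state_dict(state_dict, strict=False)  # model built elsewhere from the repo's builder
```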

skill-diver commented 2 months ago

The flash attention part is only used for inference, not for distillation. Feel free to use flash attention for your inference.

MinxuanQin commented 2 months ago

Thank you for your reply! So you have distilled a lightweight image encoder with only 6 layers, where the first two layers do not contain attention. For inference there is no checkpoint with flash attention available, so I would distill an image encoder with flash attention myself and then use it for inference. Do I understand that correctly?
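Just to make sure I follow the architecture, here is a tiny sketch of the layout I have in mind (hypothetical module names and sizes, not the repo's actual classes): 6 blocks, with the first two containing no self-attention.

```python
import torch.nn as nn

class Block(nn.Module):
    """Transformer-style block; attention is optional to mimic the first two attention-free layers."""
    def __init__(self, dim, use_attention):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads=8, batch_first=True) if use_attention else None
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        self.norm1, self.norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)

    def forward(self, x):
        if self.attn is not None:
            h = self.norm1(x)
            x = x + self.attn(h, h, h, need_weights=False)[0]
        return x + self.mlp(self.norm2(x))

# 6 layers, the first two MLP-only (dim=384 is a placeholder).
student_encoder = nn.Sequential(*[Block(384, use_attention=(i >= 2)) for i in range(6)])
```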

skill-diver commented 2 months ago

You are correct except for one point: you can use our checkpoint for inference, and it supports flash attention.

MinxuanQin commented 2 months ago

Thank you for your quick reply! I am not familiar with flash attention, so this may be a silly question: based on your answer and the code, I understand flash attention to be an efficient way of computing attention that does not change the network architecture. So the checkpoint works both with and without flash attention, right?
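If that is right, a check like the following should hold (standalone PyTorch, not the repo's attention code): the fused kernel computes the same softmax(QKᵀ/√d)V as the naive version, so the same weights can be used either way.

```python
import torch
import torch.nn.functional as F

# Sanity check of my understanding: flash attention is a faster way to compute
# the same attention output, so it should match the naive computation numerically.
q = torch.randn(1, 8, 64, 32)  # (batch, heads, tokens, head_dim)
k = torch.randn(1, 8, 64, 32)
v = torch.randn(1, 8, 64, 32)

naive = torch.softmax(q @ k.transpose(-2, -1) / 32 ** 0.5, dim=-1) @ v
fused = F.scaled_dot_product_attention(q, k, v)  # dispatches to a flash/fused kernel when supported

print(torch.allclose(naive, fused, atol=1e-5))  # True
```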

MinxuanQin commented 2 months ago

I have another question regarding the distillation process: according to utils/prepare_nnunet.py, the images and labels of one dataset should be stored under label_name/dataset_name/imagesTr and label_name/dataset_name/labelsTr, but preparelabel.py and validation.py only pick up the top-level directory (label_name), because they use all_dataset_paths = glob(join(args.test_data_path)) instead of all_dataset_paths = glob(join(args.test_data_path,'*','*')).
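To make the mismatch concrete, here is a small standalone check of the two glob patterns against the layout described above (the directory names are just placeholders):

```python
import os
from glob import glob
from os.path import join

# Recreate the layout prepare_nnunet.py appears to produce (placeholder names).
root = "data_validation"
os.makedirs(join(root, "label_name", "dataset_name", "imagesTr"), exist_ok=True)
os.makedirs(join(root, "label_name", "dataset_name", "labelsTr"), exist_ok=True)

print(glob(join(root)))            # ['data_validation'] -> only the root directory itself
print(glob(join(root, "*", "*")))  # ['data_validation/label_name/dataset_name'] -> the dataset folders
```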

For the distillation process, could I use each image only once instead of the data generated by utils/prepare_nnunet.py, since that script generates duplicated images for each class of a single subject and preparelabel.py does not need labels?

Thank you very much for your help!

skill-diver commented 2 months ago

> So the checkpoint works both with and without flash attention, right?

You are right.

skill-diver commented 2 months ago

> For the distillation process, could I use each image only once instead of the data generated by utils/prepare_nnunet.py, since that script generates duplicated images for each class of a single subject and preparelabel.py does not need labels?

You need to use prepare_nnunet.py. The model needs to learn from these preprocessed images (cropping, registration, etc. are necessary).

MinxuanQin commented 2 months ago

Got it. Thank you very much!

MinxuanQin commented 2 months ago

I have a question regarding the distillation loss. In the paper, the objective of the layer-wise progressive distillation is described as

$$\mathbb{E}_x \left[ \frac{1}{k} \sum_{i=1}^{k} \left\Vert f_{\mathrm{teacher}}^{(2i)}(x) - f_{\mathrm{student}}^{(i)}(x) \right\Vert \right]$$

where $k$ varies from 1 to 6 depending on the current and total training iterations. In distillation.py, the variable curlayer of the class BaseTrainer seems to play the role of $k$, but the loss is defined as loss = self.seg_loss(output[self.curlayer], label[self.curlayer]), so as far as I can tell only the norm of the current layer is computed, not the sum from $i=1$ to $k$.
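For clarity, this is the accumulation I expected from the formula (a standalone sketch with dummy feature lists; the helper name and the pairing of layers are my own, not the repo's code):

```python
import torch

def progressive_distill_loss(student_feats, teacher_feats, k):
    """L2 distillation loss averaged over layers 1..k, as I read the paper's formula.

    student_feats[i] is the student's i-th layer output and teacher_feats[i] the
    teacher's matching (2i)-th layer output, already paired up. Hypothetical names.
    """
    losses = [torch.norm(student_feats[i] - teacher_feats[i]) for i in range(k)]
    return torch.stack(losses).mean()

# versus what I currently see in distillation.py, which (as I read it) only uses
# the current layer: loss = self.seg_loss(output[self.curlayer], label[self.curlayer])
student_feats = [torch.randn(2, 32, 8, 8, 8) for _ in range(6)]
teacher_feats = [torch.randn(2, 32, 8, 8, 8) for _ in range(6)]
print(progressive_distill_loss(student_feats, teacher_feats, k=3))
```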

In addition, I read in the paper that the number of iterations is set to 36 for the first, layer-wise distillation stage. I would like to know how many iterations were used for the logit-level distillation stage. Thank you!