MinxuanQin opened 2 months ago
The flash attention part is only used for inference, not for distillation. Feel free to use flash attention for your inference.
Thank you for your reply! So you have distilled a lightweight image encoder with only 6 layers, where the first two layers do not contain attention layers. For inference, there are no checkpoints with flash attention available; I can distill an image encoder with flash attention and then use it for inference. Do I understand that correctly?
You are correct except for one point: you can use our checkpoint for inference, and it supports flash attention.
Thank you for your quick reply! I am not familiar with flash attention, so this may be a silly question: based on your answer and the code, I think flash attention is an efficient attention-computation mechanism that does not change the network architecture. So the checkpoint supports inference both with and without flash attention, right?
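To make sure I have the right mental model, here is a minimal sketch (toy tensor shapes, not the repository's encoder) showing that flash attention only changes how softmax(QKᵀ/√d)·V is evaluated, not the weights or the result:

```python
import torch
import torch.nn.functional as F

# Toy query/key/value tensors: (batch, heads, tokens, head_dim)
q, k, v = (torch.randn(1, 8, 128, 64) for _ in range(3))

# Naive attention, materializing the full attention matrix
attn = torch.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1)
out_naive = attn @ v

# Fused kernel; on CUDA with fp16/bf16 inputs this dispatches to flash attention
out_fused = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(out_naive, out_fused, atol=1e-5))  # same values either way
```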
I have another question regarding the distillation process. From `utils/prepare_nnunet.py`, the images and labels from one dataset should be stored under `label_name/dataset_name/imagesTr` and `label_name/dataset_name/labelsTr`, but `preparelabel.py` and `validation.py` only get a directory name like `label_name`, because they use `all_dataset_paths = glob(join(args.test_data_path))` rather than `all_dataset_paths = glob(join(args.test_data_path, '*', '*'))`.
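To make the difference concrete, here is a small sketch (the `test_data_path` value is a placeholder; the directory layout follows `prepare_nnunet.py` as described above) of what the two `glob` calls return:

```python
from glob import glob
from os.path import join

# Placeholder root; prepare_nnunet.py writes label_name/dataset_name/imagesTr under it.
test_data_path = "./test_data"

# Pattern without wildcards, as in preparelabel.py / validation.py:
# it matches only the root directory itself.
top_level = glob(join(test_data_path))              # -> ['./test_data']

# Pattern with two wildcards: it matches each label_name/dataset_name folder.
per_dataset = glob(join(test_data_path, "*", "*"))  # -> ['./test_data/<label>/<dataset>', ...]

print(top_level)
print(per_dataset)
```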
For the distillation process, could I use each image only once instead of the data generated by `utils/prepare_uunet.py`? That script generates duplicated images for each class of a single subject, and `preparelabel.py` does not need labels.

Thank you very much for your help!
You are right.
You need to use prepare_uunet.py. The model needs to learn from these preprocessed images (the cropping, registration, etc. are necessary).
Got it. Thank you very much!
I have a question regarding the distillation loss. From the paper, the objective of the layer-wise progressive distillation process is described as

$$E_x \left( \frac{1}{k} \sum_{i=1}^{k} \left\Vert f_{\text{teacher}}^{(2i)}(x) - f_{\text{student}}^{(i)}(x) \right\Vert \right),$$

where $k$ varies from 1 to 6 based on the current and total training iterations. From the code in `distillation.py`, I think the variable `curlayer` from the class `BaseTrainer` plays the role of $k$, but the loss in this case is defined as `loss = self.seg_loss(output[self.curlayer], label[self.curlayer])`, where, from my point of view, only the L2 norm of the current layer is computed, not the sum from $i=1$ to $i=k$.
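To illustrate what I mean, here is a toy sketch (random stand-in tensors, not the repository's code) contrasting the two readings of the objective:

```python
import torch

# Stand-in features: teacher_feats[2*i] plays the role of f_teacher^{(2i)}(x)
# and student_feats[i] of f_student^{(i)}(x). Shapes are made up.
teacher_feats = [torch.randn(2, 64) for _ in range(13)]  # indices 0..12
student_feats = [torch.randn(2, 64) for _ in range(7)]   # indices 0..6
k = 3  # current stage of the progressive distillation

# Reading 1: the paper's objective, averaging the L2 norm over layers i = 1..k.
loss_paper = sum(
    torch.norm(teacher_feats[2 * i] - student_feats[i]) for i in range(1, k + 1)
) / k

# Reading 2: what the quoted line from distillation.py seems to compute,
# i.e. only the L2 norm of the current layer k.
loss_current_only = torch.norm(teacher_feats[2 * k] - student_feats[k])

print(loss_paper.item(), loss_current_only.item())
```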
In addition, I have read in the paper that the number of iterations is set to 36 for the first layer-wise distillation process. I would like to know how many iterations were used for the logit-level distillation process. Thank you!
Thank you for sharing the excellent code and checkpoints! I have run the code described in `Readme.md` and would like to determine whether I understood it correctly.

The current version of `distillation.py` and `validate_student.py` uses an ImageEncoder with so-called "woatt" attention (window attention), not 3D sparse flash attention. The `validate_student.py` file loads the tiny image encoder (the first checkpoint uploaded on GitHub) as the image encoder; the remaining parts use the fine-tuned teacher model (the second uploaded checkpoint, "Finetuned-SAMMED3D"). Does the third checkpoint, "FASTSAM3D", combine the tiny encoder with those remaining parts?

I think those checkpoints do not use `build_sam3D_flash.py`, `build_sam3D_dilatedattention.py`, or `build_3D_decoder.py`. Is that right? Does this checkpoint perform best among all encoder and decoder structure variants? Thank you!
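For context on how such a correspondence could be checked, here is a generic sketch (the stand-in module and checkpoint filename are placeholders, not the repository's actual classes): loading a checkpoint with `strict=False` and inspecting the mismatched keys shows whether a given builder matches its architecture.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for whichever encoder/decoder builder is being tested.
class StandInEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(64, 64)  # placeholder layer

model = StandInEncoder()
ckpt = torch.load("fastsam3d_checkpoint.pth", map_location="cpu")  # placeholder path
state_dict = ckpt.get("model_state_dict", ckpt)  # some checkpoints nest the weights

# strict=False reports which parameter names do not line up.
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print("missing keys:", len(missing))
print("unexpected keys:", len(unexpected))  # many mismatches -> this builder does not match the checkpoint
```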