med-air / 3DSAM-adapter

Holistic Adaptation of SAM from 2D to 3D for Promptable Medical Image Segmentation

Why add an avg_pool2d for position embedding? #4

Closed: 22TonyFStark closed this issue 1 year ago

22TonyFStark commented 1 year ago

Hello, thanks for releasing this inspiring work. I'm curious about how you handle the position embedding in the following line: why do you add an avg_pool2d here, and how does it work? Did you take this approach from any paper?
https://github.com/med-air/3DSAM-adapter/blob/6ba7e037047df46b6a8d99eb81c82149e43f8c53/3DSAM-adapter/modeling/Med_SAM/image_encoder.py#L150

peterant330 commented 1 year ago

Hi,

The original SAM takes inputs of size 1024x1024. To save memory, we use inputs of size 512x512 instead, so the original position embedding is larger than the patch grid of our input. We therefore use an average pooling layer to align the sizes.
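A minimal PyTorch sketch of the idea, assuming the SAM ViT-B layout from the `segment_anything` package, where the absolute position embedding has shape (1, 64, 64, 768) for a 1024x1024 input split into 16x16 patches. This is not the code at line 150 of `image_encoder.py`; it only covers the two in-plane axes (any depth handling in the 3D adapter is not shown), and `pool_pos_embed` is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F


def pool_pos_embed(pos_embed: torch.Tensor, factor: int) -> torch.Tensor:
    """Average-pool a SAM-style (1, H, W, C) position embedding by `factor`."""
    # avg_pool2d expects channels-first (N, C, H, W), so move C in front of H, W.
    x = pos_embed.permute(0, 3, 1, 2)
    # Each output position is the mean of a factor x factor block of input positions.
    x = F.avg_pool2d(x, kernel_size=factor, stride=factor)
    # Back to channels-last (1, H // factor, W // factor, C) to match the patch tokens.
    return x.permute(0, 2, 3, 1)


# Assumed sizes: 1024 / 16 = 64 patches per side in SAM; a 512 input gives 32.
embed_dim = 768
pos_embed = torch.randn(1, 64, 64, embed_dim)  # stands in for the pre-trained embedding
tokens = torch.randn(1, 32, 32, embed_dim)     # patch tokens from a 512x512 input
pooled = pool_pos_embed(pos_embed, factor=1024 // 512)
out = tokens + pooled  # shapes now match, so the embedding can be added
print(pooled.shape)    # torch.Size([1, 32, 32, 768])
```

Each pooled position is the mean of the 2x2 block of pre-trained positions it covers, so the resized embedding is built only from values SAM was trained with.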

22TonyFStark commented 1 year ago

Thanks for your explanation! Have you run experiments at 256x256? I modified the code to reduce the memory cost further, but the validation metric seems unsatisfying (maybe):

[22:34:55.529] epoch: 165/500, iter: 10/11: loss:0.3391487
[22:34:55.698] - Val metrics: 0.6299731
[22:35:03.100] - Val metrics best: 0.62923694

Are these validation metrics good or bad? I have been training for about one day and wonder whether the training is behaving normally. Could you provide a training log so I can check my reproduction?

peterant330 commented 1 year ago

Hi,

Are you training on KiTS? If so, your metric looks abnormal: the validation metric at epoch 165 should be around 0.3 to 0.4. Can you provide more information on how you modified the code to lower the memory, such as the crop patch size? Because the original SAM was trained on 1024x1024 images and we freeze most of its parameters, I would not recommend lowering the input resolution of the ViT too much.
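As a rough illustration of why each halving of the resolution matters, the arithmetic below shows how the in-plane patch grid shrinks and how large the pooling kernel on the pre-trained position embedding must become. It assumes square inputs and SAM's 16-pixel patch size, ignores depth, and says nothing about accuracy; `grid_stats` is a hypothetical helper.

```python
PATCH_SIZE = 16   # patch size of SAM's ViT image encoder
SAM_INPUT = 1024  # resolution SAM was pre-trained at


def grid_stats(img_size: int) -> tuple[int, int, int]:
    """Return (patches per side, in-plane tokens, pooling kernel side) for a square input."""
    side = img_size // PATCH_SIZE
    factor = SAM_INPUT // img_size  # avg_pool2d kernel needed to match the position embedding
    return side, side * side, factor


for size in (1024, 512, 256):
    side, tokens, factor = grid_stats(size)
    print(f"{size}px -> {side}x{side} grid ({tokens} tokens), pooling kernel {factor}x{factor}")
```

At 256x256 each token sees a 4x4 average of pre-trained positions and the frozen encoder works on a 16x16 grid, a much coarser setting than the 64x64 grid it was trained on.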

22TonyFStark commented 1 year ago

Hi, sorry about the missing information ~ I modified the code and trained a model on Colon for 200 epochs, which reached a best validation metric of 0.57876 (this value is from the log; lower is better). Does this validation metric look normal to you? I'm currently running into some model-loading problems and will report the test metric in a few days.

I changed the following line from 512 to 256 and did my best to adapt other details to this resolution, including the position embedding, the prompt positions, etc. https://github.com/med-air/3DSAM-adapter/blob/dec84a1738a7bbde80954ba079b0f006103d94fd/3DSAM-adapter/train.py#L166
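One of the details mentioned, the prompt positions, can be sketched as a per-axis rescale of point coordinates. This is a hypothetical helper, not code from the repository: it assumes prompts are continuous (x, y, z) coordinates in the resized volume, that 0 and the old size map to 0 and the new size (no pixel-center correction), and the example volume sizes are made up.

```python
def rescale_points(points, old_size, new_size):
    """Map point prompts from an old_size grid to a new_size grid, axis by axis."""
    # Axes whose size did not change get a scale of 1.0 and are left untouched.
    scale = [n / o for n, o in zip(new_size, old_size)]
    return [[c * s for c, s in zip(p, scale)] for p in points]


# Hypothetical example: in-plane size halved from 512 to 256, depth unchanged.
print(rescale_points([[100, 200, 30]], old_size=(512, 512, 64), new_size=(256, 256, 64)))
# [[50.0, 100.0, 30.0]]
```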

peterant330 commented 1 year ago

The metric is slightly worse than our reported value (~0.5), but I think that is reasonable since you lowered the resolution and trained for fewer epochs.

22TonyFStark commented 1 year ago

Thanks for your explanation ~