med-air / 3DSAM-adapter

Holistic Adaptation of SAM from 2D to 3D for Promptable Medical Image Segmentation

What is the mask in Attention_3d in image_encoder.py? #26

Closed SnaKey0u0 closed 1 year ago

SnaKey0u0 commented 1 year ago

Dear author, thank you for your excellent work. I would like to know the intention of the mask parameter in Attention_3d, which seems to be created in the `__init__` function of Block_3d and differs from the original SAM code.

        if self.shift_size > 0:
            H, W, D = 32, 32, 32
            img_mask = torch.zeros((1, H, W, D, 1))
            h_slices = (slice(0, -window_size),
                        slice(-window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -window_size),
                        slice(-window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            d_slices = (slice(0, -window_size),
                        slice(-window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    for d in d_slices:
                        img_mask[:, h, w, d, :] = cnt
                        cnt += 1
            mask_windows = window_partition(img_mask, window_size)[0]
            mask_windows = mask_windows.view(-1, window_size * window_size * window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

I speculate that these operations are similar to positional encoding: they generate a [64, 512] mask_windows tensor whose values range from 0 to 26 due to the accumulation of cnt (because of the 3x3x3 slices).

However, I don't understand the following line: `attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))`

And why is attn_mask added to the attention scores in the subsequent operations? What is the purpose of this?

Thank you!

peterant330 commented 1 year ago


Hi, this is adapted from Swin Transformer (https://github.com/microsoft/Swin-Transformer), so that attention is applied inside a local window instead of over the whole image to reduce memory usage (a major obstacle when applying SAM to 3D images of larger size). With shifted windows, tokens from different image regions end up in the same window after the cyclic shift, so the mask assigns a large negative value (-100) to attention logits between tokens from different regions; after softmax, those entries become essentially zero, and each token only attends within its own region.
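To illustrate the mechanism the question asks about, here is a minimal toy sketch (not the repository's actual code): each token gets a region id, the pairwise id difference is turned into a 0 / -100 mask exactly as in Block_3d, and adding that mask before softmax suppresses cross-region attention.

```python
import torch

# Toy example: 4 tokens in one window, the first two from region 0,
# the last two from region 1 (as happens after the cyclic shift).
region_id = torch.tensor([0.0, 0.0, 1.0, 1.0])

# Same construction as in Block_3d: pairwise difference of region ids.
diff = region_id.unsqueeze(0) - region_id.unsqueeze(1)  # [4, 4]
attn_mask = diff.masked_fill(diff != 0, float(-100.0)).masked_fill(diff == 0, float(0.0))

# Pretend the raw attention logits are all equal; adding the mask drives
# cross-region entries toward exp(-100) ~ 0 after softmax.
scores = torch.zeros(4, 4)
attn = torch.softmax(scores + attn_mask, dim=-1)

# Each token now attends (almost) exclusively to tokens sharing its
# region id: rows sum to 1, with ~0.5 on same-region entries and ~0
# on cross-region entries.
print(attn)
```

So the -100 is not positional encoding; it is an additive bias that effectively removes the masked key positions from the softmax, which is why the mask is simply added to the attention scores.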