Closed SnaKey0u0 closed 1 year ago
Dear author, thank you for your excellent work. I would like to know the intention of the mask parameter in Attention_3d. It seems to be created in the __init__ function of Block_3d, and it differs from the original SAM code:
if self.shift_size > 0:
    H, W, D = 32, 32, 32
    img_mask = torch.zeros((1, H, W, D, 1))
    h_slices = (slice(0, -window_size),
                slice(-window_size, -self.shift_size),
                slice(-self.shift_size, None))
    w_slices = (slice(0, -window_size),
                slice(-window_size, -self.shift_size),
                slice(-self.shift_size, None))
    d_slices = (slice(0, -window_size),
                slice(-window_size, -self.shift_size),
                slice(-self.shift_size, None))
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            for d in d_slices:
                img_mask[:, h, w, d, :] = cnt
                cnt += 1
    mask_windows = window_partition(img_mask, window_size)[0]
    mask_windows = mask_windows.view(-1, window_size * window_size * window_size)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
I speculate that these operations act somewhat like a positional encoding: they produce a [64, 512] mask_windows tensor whose values range from 0 to 26 as cnt is incremented (because of the 3x3x3 slices).
However, I don't understand the following code.
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
And why is attn_mask added to the attention in the subsequent operations? What is its purpose?
Thank you!
Hi, this is adapted from Swin Transformer (https://github.com/microsoft/Swin-Transformer) so that attention is applied inside a local window instead of over the whole image, which reduces memory usage (a major obstacle when applying SAM to larger 3D images).
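For what it's worth, here is a minimal, self-contained sketch (not the repository's code; the names scores, region, and the tiny 4-token window are made up for illustration) of why the 0/-100 mask is added to the attention logits: adding -100 before the softmax pushes the corresponding attention weights to roughly zero, so tokens that only ended up in the same window because of the cyclic shift do not attend to each other.

    import torch

    # Illustrative example only: one tiny window with 4 tokens.
    num_tokens = 4
    scores = torch.randn(num_tokens, num_tokens)  # raw q @ k^T attention logits

    # Pretend the first two tokens belong to region 0 and the last two to region 1,
    # analogous to the cnt labels written into img_mask in Block_3d's __init__.
    region = torch.tensor([0, 0, 1, 1]).float()
    attn_mask = region.unsqueeze(0) - region.unsqueeze(1)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

    # Adding -100 to a logit drives its softmax weight to ~0, so cross-region
    # attention is effectively disabled within the window.
    attn = torch.softmax(scores + attn_mask, dim=-1)
    print(attn)  # cross-region entries are numerically zero

As far as I understand, the same idea applies to each 512-token window in the 3D case: the cnt labels 0-26 just mark which shifted region each voxel came from, and the pairwise difference is nonzero exactly where two tokens come from different regions.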