Closed susundingkai closed 8 months ago
@susundingkai Would you mind sharing the mask output extracted from the ONNX model? Thank you in advance! My email address is sh239766@dal.ca.
def gen_mask(x_shape, window_size) -> torch.Tensor:
    """Compute the attention mask for shifted-window self-attention (SW-MSA).

    After the cyclic shift, positions that came from different spatial
    regions along the Z and H axes must not attend to each other; such
    pairs receive a large negative bias (-100.0) while same-region pairs
    receive 0.0.

    Args:
        x_shape: indexable shape of the (padded) input; indices 1..3 are
            read as (Z, H, W), e.g. (_, 8, 186, 360).
        window_size: per-axis window extent (wz, wh, ww), e.g. (2, 6, 12).

    Returns:
        Mask tensor of shape
        (1, W // ww, (Z // wz) * (H // wh), 1, wz*wh*ww, wz*wh*ww).
    """
    # 1 Z H W 1 region-id canvas, e.g. [1, 8, 186, 360, 1] for window (2, 6, 12).
    img_mask = torch.zeros((1, x_shape[1], x_shape[2], x_shape[3], 1))
    # Three bands per shifted axis: the bulk, the shifted half-window, and
    # the trailing half-window. W gets no bands — presumably it is a
    # circular (longitude-like) axis; confirm against the model.
    z_slices = (slice(0, -window_size[0]),
                slice(-window_size[0], -window_size[0] // 2),
                slice(-window_size[0] // 2, None))
    h_slices = (slice(0, -window_size[1]),
                slice(-window_size[1], -window_size[1] // 2),
                slice(-window_size[1] // 2, None))
    # Give every (z-band, h-band) combination a distinct region id.
    cnt = 0
    for z in z_slices:
        for h in h_slices:
            img_mask[:, z, h, :, :] = cnt
            cnt += 1
    # nW, wz, wh, ww, 1 -> flatten each window's ids onto one axis.
    mask_windows = window_partition(img_mask, window_size)
    mask_windows = mask_windows.contiguous().view(
        list(mask_windows.shape[:4]) + [-1])  # e.g. [1, 4, 31, 30, 144]
    # Pairwise region-id difference: non-zero means "different region".
    attn_mask = mask_windows.unsqueeze(4) - mask_windows.unsqueeze(5)  # e.g. [1, 4, 31, 30, 144, 144]
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)) \
                         .masked_fill(attn_mask == 0, float(0.0))
    # Rearrange to (1, W/ww, (Z/wz)*(H/wh), 1, tokens, tokens) for attention.
    tokens = window_size[0] * window_size[1] * window_size[2]
    attn_mask = attn_mask.contiguous().permute(0, 3, 1, 2, 4, 5).reshape(
        1,
        x_shape[3] // window_size[2],
        (x_shape[1] // window_size[0]) * (x_shape[2] // window_size[1]),
        1,
        tokens,
        tokens,
    )
    return attn_mask
You can generate it using this function.
def gen_mask(x_shape, window_size) -> torch.Tensor:
    """Build the SW-MSA attention mask for a (Z, H, W) volume.

    Windows spanning regions created by the cyclic shift along Z and H
    get a -100.0 bias between cross-region token pairs and 0.0 otherwise.
    Returns a tensor of shape
    (1, W // ww, (Z // wz) * (H // wh), 1, wz*wh*ww, wz*wh*ww).
    """
    wz, wh, ww = window_size[0], window_size[1], window_size[2]
    # Region-id canvas shaped 1 x Z x H x W x 1, e.g. [1, 8, 186, 360, 1].
    region_map = torch.zeros((1, x_shape[1], x_shape[2], x_shape[3], 1))
    # Bulk / shifted half-window / trailing half-window bands per axis.
    bands_z = (slice(0, -wz), slice(-wz, -wz // 2), slice(-wz // 2, None))
    bands_h = (slice(0, -wh), slice(-wh, -wh // 2), slice(-wh // 2, None))
    region_id = 0
    for bz in bands_z:
        for bh in bands_h:
            region_map[:, bz, bh, :, :] = region_id
            region_id += 1
    # Partition into windows, then flatten each window's region ids.
    windows = window_partition(region_map, window_size)
    windows = windows.contiguous().view(list(windows.shape[:4]) + [-1])
    # Non-zero pairwise id difference marks a cross-region token pair.
    mask = windows.unsqueeze(4) - windows.unsqueeze(5)
    mask = mask.masked_fill(mask != 0, float(-100.0))
    mask = mask.masked_fill(mask == 0, float(0.0))
    tokens = wz * wh * ww
    mask = mask.contiguous().permute(0, 3, 1, 2, 4, 5).reshape(
        1,
        x_shape[3] // ww,
        (x_shape[1] // wz) * (x_shape[2] // wh),
        1,
        tokens,
        tokens,
    )
    return mask
You can generate it using this function.
Thanks!
We used the Video Swin Transformer to generate the 3D attention mask, but the mask output we extracted from the ONNX model is not the same. Could you make the gen_mask function public? Thanks!