198808xc / Pangu-Weather

An official implementation of Pangu-Weather

About gen_mask function #35

Closed: susundingkai closed this issue 8 months ago

susundingkai commented 9 months ago

We used the Video Swin Transformer code to generate the 3D attention mask, but it does not match the mask output we extracted from the ONNX model. Could you make the gen_mask function public? Thanks!
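The idea behind a shifted-window (SW-MSA) attention mask is easiest to see in one dimension. The sketch below is a hypothetical toy helper (`shifted_window_mask_1d` is not part of Pangu-Weather or Video Swin Transformer). It assumes a cyclic shift of `window // 2`, labels the three regions that the shift creates, and builds one additive mask per window: 0 for token pairs from the same region, -100 otherwise. The 3D case does the same thing with region ids assigned jointly over two axes.

```python
from typing import List


def shifted_window_mask_1d(length: int, window: int) -> List[List[List[float]]]:
    """Toy 1D shifted-window mask (hypothetical helper, for illustration only).

    After a cyclic shift of window // 2, the last window holds tokens that were
    not neighbours before the shift. Each token gets a region id, and token
    pairs with different ids get -100 in that window's (window x window) mask.
    """
    shift = window // 2
    region = [0] * length
    # three regions: [0, -window), [-window, -shift), [-shift, end)
    slices = (slice(0, -window), slice(-window, -shift), slice(-shift, None))
    for rid, s in enumerate(slices):
        for i in range(length)[s]:
            region[i] = rid
    masks = []
    for start in range(0, length, window):
        ids = region[start:start + window]
        # 0.0 where both tokens share a region, -100.0 where they do not
        masks.append([[0.0 if a == b else -100.0 for b in ids] for a in ids])
    return masks


masks = shifted_window_mask_1d(8, 4)
print(masks[1][0])  # [0.0, 0.0, -100.0, -100.0]
```

Only the last window straddles a region boundary, so only its mask contains -100 entries; the first window's mask is all zeros.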

zhaoshan2 commented 8 months ago

@susundingkai Would you mind sharing the mask output you extracted from the ONNX model? Thank you in advance! My email address is sh239766@dal.ca.

susundingkai commented 8 months ago
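import torch


def window_partition(x: torch.Tensor, window_size) -> torch.Tensor:
    # Not shown in this thread; a minimal stand-in (assumption) that splits
    # (B, Z, H, W, C) into (B, nZ, nH, nW, wz, wh, ww, C), the layout the shape comments below imply
    B, Z, H, W, C = x.shape
    wz, wh, ww = window_size
    x = x.view(B, Z // wz, wz, H // wh, wh, W // ww, ww, C)
    return x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous()


def gen_mask(x_shape, window_size) -> torch.Tensor:
    # calculate attention mask for SW-MSA
    # x_shape is (B, Z, H, W, C), e.g. [1, 8, 186, 360, C] with window_size (2, 6, 12)
    wz, wh, ww = window_size
    img_mask = torch.zeros((1, x_shape[1], x_shape[2], x_shape[3], 1))  # 1 Z H W 1
    # three regions per axis: [0, -w), [-w, -shift), [-shift, end), with shift = w // 2
    z_slices = (slice(0, -wz), slice(-wz, -(wz // 2)), slice(-(wz // 2), None))
    h_slices = (slice(0, -wh), slice(-wh, -(wh // 2)), slice(-(wh // 2), None))
    # give every (z, h) region its own id; W (longitude) is not split into regions
    cnt = 0
    for z in z_slices:
        for h in h_slices:
            img_mask[:, z, h, :, :] = cnt
            cnt += 1
    mask_windows = window_partition(img_mask, window_size)  # [1, nZ, nH, nW, wz, wh, ww, 1]
    mask_windows = mask_windows.contiguous().view(list(mask_windows.shape[:4]) + [-1])  # [1, 4, 31, 30, 144]
    attn_mask = mask_windows.unsqueeze(4) - mask_windows.unsqueeze(5)  # [1, 4, 31, 30, 144, 144]
    # token pairs from different regions get -100 (near zero after softmax), same region gets 0
    attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)
    n_tokens = wz * wh * ww
    n_lon = x_shape[3] // ww
    n_types = (x_shape[1] // wz) * (x_shape[2] // wh)
    # reorder to [1, nW, nZ * nH, 1, tokens, tokens], e.g. [1, 30, 124, 1, 144, 144]
    attn_mask = attn_mask.permute(0, 3, 1, 2, 4, 5).reshape(1, n_lon, n_types, 1, n_tokens, n_tokens)
    return attn_mask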

You can generate it using this function.
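The thread does not show how the mask is consumed. Assuming the usual Swin Transformer convention, the mask is added to the attention logits before the softmax (the trailing singleton dimension of the [1, nW, nZ*nH, 1, N, N] output presumably lets it broadcast over attention heads). The hypothetical `masked_softmax_row` below shows the effect on one row of a 4-token window whose last two tokens come from a different region: because -100 is not -inf, those weights become tiny (exp(-100) is about 3.7e-44) rather than exactly zero.

```python
import math
from typing import List


def masked_softmax_row(logits: List[float], mask_row: List[float]) -> List[float]:
    """Softmax over one attention row after adding an additive mask (0 or -100)."""
    scores = [logit + m for logit, m in zip(logits, mask_row)]
    peak = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# Without the mask, token 2 (logit 2.5) would get the largest weight;
# with it, tokens 2 and 3 are effectively ignored.
weights = masked_softmax_row([0.3, 1.2, 2.5, 0.9], [0.0, 0.0, -100.0, -100.0])
print(weights)
```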

zhaoshan2 commented 8 months ago

Thanks!