NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in Pytorch
BSD 3-Clause "New" or "Revised" License

about mask check 16 #1433

Open lw921014 opened 2 years ago

lw921014 commented 2 years ago

Describe the Bug

For the mask op here https://github.com/NVIDIA/apex/blob/master/apex/contrib/csrc/fmha/src/fmha/mask.h#L54: if we use the sm80 m16n8k16 tensor core, should this be changed as follows?

col = warp_n * 32 + tid;

Minimal Steps/Code to Reproduce the Bug

Expected Behavior

Environment

yjk21 commented 2 years ago

Hello, every warp computes a 16x16 tile, so this column offset should be ok.
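To make the layout concrete, here is a small sketch (not apex code) of the accumulator (C/D) fragment mapping for a single sm80 `mma.sync.m16n8k16` instruction, following the PTX ISA documentation. Each lane holds four fp32 accumulator values covering two rows and two adjacent columns of the 16x8 output tile; a warp issues two such MMAs side by side to cover a 16x16 tile, which is why the column offset in mask.h does not need a factor of 32.

```python
# Sketch: which (row, col) of the 16x8 C fragment each lane owns
# in an sm80 mma.sync m16n8k16, per the PTX ISA fragment layout.
# (Function name and structure are illustrative, not from apex.)

def accum_coords(lane):
    """Return the (row, col) pairs of the 4 accumulator registers c0..c3."""
    group = lane // 4          # 8 groups of 4 lanes, one group per row pair
    base_col = (lane % 4) * 2  # each group of 4 lanes spans all 8 columns
    return [
        (group,     base_col),      # c0
        (group,     base_col + 1),  # c1
        (group + 8, base_col),      # c2
        (group + 8, base_col + 1),  # c3
    ]

# The 32 lanes together tile the 16x8 output exactly once.
cells = {c for lane in range(32) for c in accum_coords(lane)}
assert cells == {(r, c) for r in range(16) for c in range(8)}
```

A warp's 16x16 tile is then two of these 16x8 fragments at column offsets 0 and 8, so per-thread column indices stay within `warp_n * 16 + ...` rather than `warp_n * 32 + tid`.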

yjk21 commented 2 years ago

@lw921014 can we close if that answered your question, or is there something else regarding this we can help with?

lw921014 commented 2 years ago

> Hello, every warp computes a 16x16 tile, so this column offset should be ok.

I got it. Thanks a lot.

I have another question. According to here, does the current implementation only support head size = 64? What about head size = 32 or 16?