The error says it is related to CUDA addresses, and the reproduction happens only with CUDA graphs. If the mask is hard-coded to `None`, it doesn't happen.
```
# for 8x16 on t5 small
q.shape=torch.Size([8, 8, 16, 64])
k.shape=torch.Size([8, 8, 16, 64])
v.shape=torch.Size([8, 8, 16, 64])
attention_mask.shape=torch.Size([8, 8, 16, 16])
```
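For reference, a minimal snippet to allocate inputs with these shapes (my own illustration; the mask dtype is an assumption, and the kernel entry point is not shown):

```python
import torch

# Shapes reported above for the failing 8x16 T5-small case (illustrative allocation only).
bs, heads, seq, d_head = 8, 8, 16, 64
q = torch.randn(bs, heads, seq, d_head, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
# dtype assumed to match q/k/v (fp16); the report only gives the shape
attention_mask = torch.zeros(bs, heads, seq, seq, dtype=torch.float16, device="cuda")
```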
The attention kernels are not tested with CUDA Graphs for now.
Failing shapes: 8x16 and 1x32 on T5. All other shapes from the unit tests do not crash on T5. No test fails for Bert, including the (reintroduced for the occasion) 32x32 shape.
```bash
# no CUDA graphs: ok
CUDA_LAUNCH_BLOCKING=1 pytest test/test_torchdynamo.py -k "dynamo_optimized and not cuda_graphs and t5 and (8x16 or 1x32)"
# Bert: ok
CUDA_LAUNCH_BLOCKING=1 pytest test/test_torchdynamo.py -k "dynamo_optimized_cuda_graphs and bert"
# all shapes on t5 except the 2 above: ok
CUDA_LAUNCH_BLOCKING=1 pytest test/test_torchdynamo.py -k "dynamo_optimized_cuda_graphs and t5 and not 8x16 and not 1x32"
```
Somewhere in the middle of the error trace there is:
```
heads = 8, size_m = 32, size_n = 32, size_m_rounded = 128, Q = <[RuntimeError('numel: integer multiplication overflow') raised in repr()] Tensor object at 0x7f0c1615b040>
K = <[RuntimeError('numel: integer multiplication overflow') raised in repr()] Tensor object at 0x7f0c1aa512c0>
V = <[RuntimeError('numel: integer multiplication overflow') raised in repr()] Tensor object at 0x7f0c1aa572c0>, sm_scale = 1.0
attention_mask = <[RuntimeError('numel: integer multiplication overflow') raised in repr()] Tensor object at 0x7f0c162645e0>
TMP = <[RuntimeError('numel: integer multiplication overflow') raised in repr()] Tensor object at 0x7f0c162644a0>
output = <[RuntimeError('numel: integer multiplication overflow') raised in repr()] Tensor object at 0x7f0c1aa57d60>, q_batch_stride = 16384, q_head_stride = 64, q_m_stride = 512, q_k_stride = 1, k_batch_stride = 16384
k_head_stride = 64, k_n_stride = 512, k_k_stride = 1, v_batch_stride = 16384, v_head_stride = 64, v_k_stride = 512, v_n_stride = 1, o_batch_stride = 16384, o_head_stride = 64, o_m_stride = 512, o_n_stride = 1
attention_mask_batch_stride = 8, attention_mask_head_stride = 1, attention_mask_m_stride = 256, attention_mask_k_stride = 8, min_clamp_value = -65504.0, NEED_LOAD_MASK_SIZE_N = True, NEED_LOAD_MASK_SIZE_M = True
MASK_BATCH_SIZE = 1, MASK_HEAD_SIZE = 8, MASK_M_SIZE = 32, MASK_K_SIZE = 32, HAS_MASK = True, IS_CAUSAL = False, BLOCK_M = 128, BLOCK_DHEAD = 64, BLOCK_N = 128, grid = (1, 8), num_warps = 4, num_stages = 1
extern_libs = None, stream = 2050343696, warmup = False
```
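A plausible reading of those launch parameters (my own back-of-the-envelope check, not taken from the kernel code): `size_m_rounded` looks like `size_m` rounded up to a multiple of `BLOCK_M`, and the grid covers one program per row block times `batch * heads`:

```python
import triton

size_m, BLOCK_M = 32, 128
batch, heads = 1, 8

# size_m rounded up to a multiple of BLOCK_M -> 128, as in the trace
size_m_rounded = triton.cdiv(size_m, BLOCK_M) * BLOCK_M
# one program per BLOCK_M rows, one per (batch, head) pair -> (1, 8), as in the trace
grid = (triton.cdiv(size_m_rounded, BLOCK_M), batch * heads)
print(size_m_rounded, grid)  # 128 (1, 8)
```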
Printing values shows that:
```
q.shape=torch.Size([1, 8, 32, 64]), k.shape=torch.Size([1, 8, 32, 64]), v.shape=torch.Size([1, 8, 32, 64]), output.shape=torch.Size([1, 8, 32, 64])
q.stride()=(16384, 64, 512, 1), k.stride()=(16384, 64, 512, 1), v.stride()=(16384, 64, 512, 1), output.stride()=(16384, 64, 512, 1)
q.dtype=torch.float16, k.dtype=torch.float16, v.dtype=torch.float16, output.dtype=torch.float16
attention_mask.shape=torch.Size([1, 8, 32, 32])
attention_mask.stride()=(8, 1, 256, 8)
```
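As a side note, these strides correspond to a permuted, non-contiguous layout rather than a plain contiguous `[batch, heads, seq, d_head]` tensor; a quick check (my own, not from the repro):

```python
import torch

# A contiguous [batch, seq, heads, d_head] tensor permuted to [batch, heads, seq, d_head]
# reproduces the strides printed above for q/k/v/output.
q = torch.empty(1, 32, 8, 64, dtype=torch.float16).permute(0, 2, 1, 3)
print(q.shape)            # torch.Size([1, 8, 32, 64])
print(q.stride())         # (16384, 64, 512, 1)
print(q.is_contiguous())  # False
```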
If we comment out the following line, the code does not crash, but the output is wrong:
```python
# reminder: attention_mask_m_stride=attention_mask.stride(2) if HAS_MASK else 0,
offs_mask += offs_m[:, None] * attention_mask_m_stride
```
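To make the role of that line concrete, here is a minimal, self-contained Triton sketch (my own illustration, not the project's kernel) of the same offset arithmetic: when `BLOCK_M` (128) is larger than `size_m` (32), the row offsets point past the mask extent, so the load only stays in bounds thanks to the guard (presumably what `NEED_LOAD_MASK_SIZE_M` controls in the real kernel):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def load_mask_tile(mask_ptr, out_ptr,
                   mask_m_stride, mask_k_stride,
                   size_m, size_n,
                   BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # same arithmetic as the incriminated line: column offset plus row offset
    offs_mask = offs_n[None, :] * mask_k_stride
    offs_mask += offs_m[:, None] * mask_m_stride
    # rows >= size_m / cols >= size_n fall outside the mask tensor when BLOCK_M > size_m,
    # so the load has to be guarded
    in_bounds = (offs_m[:, None] < size_m) & (offs_n[None, :] < size_n)
    tile = tl.load(mask_ptr + offs_mask, mask=in_bounds, other=0.0)
    tl.store(out_ptr + offs_m[:, None] * BLOCK_N + offs_n[None, :], tile, mask=in_bounds)


mask = torch.randn(32, 32, dtype=torch.float16, device="cuda")
out = torch.zeros(128, 128, dtype=torch.float16, device="cuda")
load_mask_tile[(1,)](mask, out, mask.stride(0), mask.stride(1), 32, 32, BLOCK_M=128, BLOCK_N=128)
```

The real kernel of course also adds batch/head offsets and uses the flags shown in the trace; the sketch is only meant to show why the row-stride term matters once the block is larger than the sequence length.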
When running the tests on `main`, there is a crash. We had other issues with short seqlen and large batch on T5, not sure why...