zhuzilin / ring-flash-attention

Ring attention implementation with flash attention
MIT License

Wait time instrumentation [not intended to be merged] #9

Open andreaskoepf opened 8 months ago

andreaskoepf commented 8 months ago

I tried to measure the time spent waiting on the requests returned by batch_isend_irecv(). Interestingly, this time seems to be largely independent of sequence length (roughly 5 us per call for seq_len from 32 to 16384, rising only to about 8 to 9 us at 32768) and negligible in total. It could be that on a single node the actual waiting happens somewhere else, or that the P2P transfer is so fast that compute is the bottleneck. This probably needs further investigation.
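A minimal sketch of this kind of host-side wait timing, assuming each request exposes a blocking wait() method (as the work handles returned by torch.distributed.batch_isend_irecv do). The branch's actual instrumentation is not shown in this PR, so WaitStats and timed_wait are hypothetical names; the summary line only mimics the format of the log below, but computes a floating-point mean.

```python
import time
from dataclasses import dataclass, field


@dataclass
class WaitStats:
    """Hypothetical accumulator for per-call wait durations in microseconds."""
    name: str
    durations_us: list = field(default_factory=list)

    def record(self, duration_us: float) -> None:
        self.durations_us.append(duration_us)

    def summary(self, rank: int) -> str:
        n = len(self.durations_us)
        total = sum(self.durations_us)
        # True floating-point mean (no integer truncation).
        mean = total / n if n else 0.0
        lo = min(self.durations_us, default=0.0)
        hi = max(self.durations_us, default=0.0)
        return (f"{self.name}: rank={rank}, total={total:.0f}us, mean={mean:.2f}us, "
                f"min={lo:.2f}us, max={hi:.2f}us, num_calls={n}")


def timed_wait(reqs, stats: WaitStats) -> None:
    """Wait on every request and record the host-side time as one call."""
    start = time.perf_counter()
    for req in reqs:
        req.wait()
    stats.record((time.perf_counter() - start) * 1e6)


# Hypothetical use inside a ring step (requires torch.distributed to be initialized):
#   reqs = torch.distributed.batch_isend_irecv(p2p_ops)
#   timed_wait(reqs, stats)
```

One caveat worth checking, assuming the NCCL backend: the PyTorch documentation states that for CUDA operations wait() only guarantees the operation has been enqueued on a CUDA stream, not that it has completed. A host-side timer like this may therefore capture little more than call overhead, which would also fit a value that barely changes with sequence length. Recording torch.cuda.Event(enable_timing=True) pairs around the wait, or using torch.profiler, would be one way to see whether the real waiting happens on the GPU side.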

Measured with the measure_wait_times branch on 2x A5000:

benchmark seq_len=32:
ring_flash_attn_forward: rank=1, total=4928us, mean=4.00us, min=4.54us, max=8.00us, num_calls=1000
ring_flash_attn_forward: rank=0, total=4694us, mean=4.00us, min=4.39us, max=11.00us, num_calls=1000
RingComm.wait(): rank=1, total=5055us, mean=5.00us, min=4.74us, max=11.00us, num_calls=1000
RingComm.wait(): rank=0, total=4830us, mean=4.00us, min=4.44us, max=10.00us, num_calls=1000
zigzag_ring_flash_attn_forward: rank=1, total=5003us, mean=5.00us, min=4.70us, max=12.00us, num_calls=1000
zigzag_ring_flash_attn_forward: rank=0, total=4831us, mean=4.00us, min=4.49us, max=10.00us, num_calls=1000

benchmark seq_len=64:
ring_flash_attn_forward: rank=1, total=5165us, mean=5.00us, min=4.75us, max=13.00us, num_calls=1000
ring_flash_attn_forward: rank=0, total=4777us, mean=4.00us, min=4.41us, max=10.00us, num_calls=1000
RingComm.wait(): rank=1, total=4904us, mean=4.00us, min=4.58us, max=5.00us, num_calls=1000
RingComm.wait(): rank=0, total=4838us, mean=4.00us, min=4.41us, max=8.00us, num_calls=1000
zigzag_ring_flash_attn_forward: rank=1, total=5088us, mean=5.00us, min=4.71us, max=11.00us, num_calls=1000
zigzag_ring_flash_attn_forward: rank=0, total=4944us, mean=4.00us, min=4.56us, max=11.00us, num_calls=1000

benchmark seq_len=128:
ring_flash_attn_forward: rank=1, total=5238us, mean=5.00us, min=4.77us, max=20.00us, num_calls=1000
ring_flash_attn_forward: rank=0, total=4991us, mean=4.00us, min=4.46us, max=13.00us, num_calls=1000
RingComm.wait(): rank=1, total=5048us, mean=5.00us, min=4.59us, max=14.00us, num_calls=1000
RingComm.wait(): rank=0, total=5139us, mean=5.00us, min=4.45us, max=13.00us, num_calls=1000
zigzag_ring_flash_attn_forward: rank=1, total=5103us, mean=5.00us, min=4.63us, max=12.00us, num_calls=1000
zigzag_ring_flash_attn_forward: rank=0, total=4870us, mean=4.00us, min=4.50us, max=11.00us, num_calls=1000

benchmark seq_len=256:
ring_flash_attn_forward: rank=1, total=5178us, mean=5.00us, min=4.82us, max=19.00us, num_calls=1000
ring_flash_attn_forward: rank=0, total=4915us, mean=4.00us, min=4.45us, max=10.00us, num_calls=1000
RingComm.wait(): rank=1, total=5099us, mean=5.00us, min=4.77us, max=11.00us, num_calls=1000
RingComm.wait(): rank=0, total=5253us, mean=5.00us, min=4.54us, max=11.00us, num_calls=1000
zigzag_ring_flash_attn_forward: rank=1, total=5180us, mean=5.00us, min=4.70us, max=13.00us, num_calls=1000
zigzag_ring_flash_attn_forward: rank=0, total=4944us, mean=4.00us, min=4.55us, max=15.00us, num_calls=1000

benchmark seq_len=512:
ring_flash_attn_forward: rank=1, total=5178us, mean=5.00us, min=4.83us, max=12.00us, num_calls=1000
ring_flash_attn_forward: rank=0, total=5044us, mean=5.00us, min=4.42us, max=23.00us, num_calls=1000
RingComm.wait(): rank=1, total=5227us, mean=5.00us, min=4.71us, max=15.00us, num_calls=1000
RingComm.wait(): rank=0, total=5210us, mean=5.00us, min=4.51us, max=14.00us, num_calls=1000
zigzag_ring_flash_attn_forward: rank=1, total=5286us, mean=5.00us, min=4.78us, max=60.00us, num_calls=1000
zigzag_ring_flash_attn_forward: rank=0, total=4932us, mean=4.00us, min=4.50us, max=11.00us, num_calls=1000

benchmark seq_len=1024:
ring_flash_attn_forward: rank=1, total=4969us, mean=4.00us, min=4.62us, max=12.00us, num_calls=1000
ring_flash_attn_forward: rank=0, total=4930us, mean=4.00us, min=4.40us, max=13.00us, num_calls=1000
RingComm.wait(): rank=1, total=5090us, mean=5.00us, min=4.76us, max=12.00us, num_calls=1000
RingComm.wait(): rank=0, total=4974us, mean=4.00us, min=4.50us, max=8.00us, num_calls=1000
zigzag_ring_flash_attn_forward: rank=1, total=5777us, mean=5.00us, min=4.62us, max=29.00us, num_calls=1000
zigzag_ring_flash_attn_forward: rank=0, total=5585us, mean=5.00us, min=4.50us, max=22.00us, num_calls=1000

benchmark seq_len=2048:
ring_flash_attn_forward: rank=0, total=5092us, mean=5.00us, min=4.48us, max=12.00us, num_calls=1000
ring_flash_attn_forward: rank=1, total=5267us, mean=5.00us, min=4.80us, max=15.00us, num_calls=1000
RingComm.wait(): rank=0, total=5126us, mean=5.00us, min=4.44us, max=11.00us, num_calls=1000
RingComm.wait(): rank=1, total=5103us, mean=5.00us, min=4.72us, max=13.00us, num_calls=1000
zigzag_ring_flash_attn_forward: rank=0, total=5129us, mean=5.00us, min=4.69us, max=11.00us, num_calls=1000
zigzag_ring_flash_attn_forward: rank=1, total=5072us, mean=5.00us, min=4.70us, max=8.00us, num_calls=1000

benchmark seq_len=4096:
ring_flash_attn_forward: rank=1, total=5046us, mean=5.00us, min=4.66us, max=10.00us, num_calls=1000
ring_flash_attn_forward: rank=0, total=4750us, mean=4.00us, min=4.39us, max=10.00us, num_calls=1000
RingComm.wait(): rank=1, total=5542us, mean=5.00us, min=4.78us, max=50.00us, num_calls=1000
RingComm.wait(): rank=0, total=4942us, mean=4.00us, min=4.36us, max=9.00us, num_calls=1000
zigzag_ring_flash_attn_forward: rank=1, total=5044us, mean=5.00us, min=4.71us, max=10.00us, num_calls=1000
zigzag_ring_flash_attn_forward: rank=0, total=4884us, mean=4.00us, min=4.56us, max=6.00us, num_calls=1000

benchmark seq_len=8192:
ring_flash_attn_forward: rank=1, total=5605us, mean=5.00us, min=4.57us, max=35.00us, num_calls=1000
ring_flash_attn_forward: rank=0, total=5720us, mean=5.00us, min=4.40us, max=56.00us, num_calls=1000
RingComm.wait(): rank=1, total=5580us, mean=5.00us, min=4.63us, max=26.00us, num_calls=1000
RingComm.wait(): rank=0, total=5570us, mean=5.00us, min=4.42us, max=14.00us, num_calls=1000
zigzag_ring_flash_attn_forward: rank=1, total=5482us, mean=5.00us, min=4.74us, max=11.00us, num_calls=1000
zigzag_ring_flash_attn_forward: rank=0, total=5254us, mean=5.00us, min=4.61us, max=11.00us, num_calls=1000

benchmark seq_len=16384:
ring_flash_attn_forward: rank=1, total=5379us, mean=5.00us, min=4.71us, max=27.00us, num_calls=1000
ring_flash_attn_forward: rank=0, total=5470us, mean=5.00us, min=4.51us, max=33.00us, num_calls=1000
RingComm.wait(): rank=1, total=5730us, mean=5.00us, min=4.84us, max=11.00us, num_calls=1000
RingComm.wait(): rank=0, total=5589us, mean=5.00us, min=4.45us, max=25.00us, num_calls=1000
zigzag_ring_flash_attn_forward: rank=1, total=5868us, mean=5.00us, min=4.74us, max=35.00us, num_calls=1000
zigzag_ring_flash_attn_forward: rank=0, total=5432us, mean=5.00us, min=4.50us, max=32.00us, num_calls=1000

benchmark seq_len=32768:
ring_flash_attn_forward: rank=1, total=7331us, mean=7.00us, min=4.72us, max=38.00us, num_calls=1000
ring_flash_attn_forward: rank=0, total=5824us, mean=5.00us, min=4.49us, max=38.00us, num_calls=1000
RingComm.wait(): rank=1, total=7956us, mean=7.00us, min=4.91us, max=45.00us, num_calls=1000
RingComm.wait(): rank=0, total=9252us, mean=9.00us, min=4.72us, max=40.00us, num_calls=1000
zigzag_ring_flash_attn_forward: rank=1, total=7487us, mean=7.00us, min=4.88us, max=52.00us, num_calls=1000
zigzag_ring_flash_attn_forward: rank=0, total=5795us, mean=5.00us, min=4.74us, max=26.00us, num_calls=1000
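The printed mean is sometimes lower than the printed min (for example mean=4.00us with min=4.54us on the very first line), which suggests it was truncated to an integer before printing; total/num_calls gives 4.93us for that line. A small sketch that parses lines in this log format and recomputes the mean; parse_line is a hypothetical helper, and the regular expression assumes exactly the field layout shown above.

```python
import re

# Assumes the exact "name: rank=R, total=Tus, mean=..us, min=..us, max=..us, num_calls=N" layout.
LINE_RE = re.compile(
    r"^(?P<name>.+?): rank=(?P<rank>\d+), total=(?P<total>[\d.]+)us, "
    r"mean=(?P<mean>[\d.]+)us, min=(?P<min>[\d.]+)us, max=(?P<max>[\d.]+)us, "
    r"num_calls=(?P<n>\d+)$"
)


def parse_line(line: str) -> dict:
    """Parse one timing line and add the recomputed per-call mean in microseconds."""
    m = LINE_RE.match(line.strip())
    if m is None:
        raise ValueError(f"unrecognized line: {line!r}")
    rec = {
        "name": m["name"],
        "rank": int(m["rank"]),
        "total_us": float(m["total"]),
        "printed_mean_us": float(m["mean"]),
        "min_us": float(m["min"]),
        "max_us": float(m["max"]),
        "num_calls": int(m["n"]),
    }
    # Recompute the mean from the totals instead of trusting the printed value.
    rec["mean_us"] = rec["total_us"] / rec["num_calls"]
    return rec


rec = parse_line("RingComm.wait(): rank=0, total=9252us, mean=9.00us, "
                 "min=4.72us, max=40.00us, num_calls=1000")
print(f"{rec['name']} rank={rec['rank']}: {rec['mean_us']:.2f}us per call")
```

Applied to the RingComm.wait() lines, the recomputed per-call means are about 4.8 to 5.7 us for seq_len up to 16384 and about 8.0 and 9.3 us at seq_len=32768.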