google/jetstream-pytorch

PyTorch/XLA integration with JetStream (https://github.com/google/JetStream) for LLM inference
Apache License 2.0

Add a script to measure speed of basic ops #168

Closed: qihqi closed this issue 1 month ago

qihqi commented 1 month ago

Output on v5litepod-8:

Number of devices:  8
bfloat16  Matmul replicated:      438.17 ms       sizes: ('2048.0 MiB', '2048.0 MiB')
bfloat16  Matmul sharded colrow:  108.891 ms      sizes: ('2048.0 MiB', '2048.0 MiB')
bfloat16  Matmul sharded rowcol:  76.6386 ms      sizes: ('2048.0 MiB', '2048.0 MiB')
bfloat16  all_gather:             68.3381 ms      sizes: ('2048.0 MiB',)
bfloat16  all_reduce:             8.25386 ms      sizes: ('2048.0 MiB',)
bfloat16  Llama 3xffn shardmap:   0.611614 ms     sizes: ('0.0625 MiB', '86.0 MiB', '86.0 MiB', '86.0 MiB')
bfloat16  Llama 3xffn gspmd:      0.596578 ms     sizes: ('0.0625 MiB', '86.0 MiB', '86.0 MiB', '86.0 MiB')
int8      Matmul replicated:      186.436 ms      sizes: ('1024.0 MiB', '1024.0 MiB')
int8      Matmul sharded colrow:  54.9044 ms      sizes: ('1024.0 MiB', '1024.0 MiB')
int8      Matmul sharded rowcol:  38.6539 ms      sizes: ('1024.0 MiB', '1024.0 MiB')
int8      all_gather:             34.4571 ms      sizes: ('1024.0 MiB',)
int8      all_reduce:             4.34715 ms      sizes: ('1024.0 MiB',)
int8      Llama 3xffn shardmap:   0.483992 ms     sizes: ('0.03125 MiB', '43.0 MiB', '43.0 MiB', '43.0 MiB')
int8      Llama 3xffn gspmd:      0.503814 ms     sizes: ('0.03125 MiB', '43.0 MiB', '43.0 MiB', '43.0 MiB')
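
For context, below is a minimal sketch of how a microbenchmark like this can be structured, written against JAX's sharding API (jetstream-pytorch runs on JAX/XLA). The bench helper, NUM_ITERS, matrix shapes, and the sharding layouts chosen for "colrow"/"rowcol" are illustrative assumptions, not the script added in this PR; only the three matmul cases are sketched.

import time

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

NUM_ITERS = 10  # assumed; the real script may average differently


def bench(name, fn, *args):
    # Warm up once so compilation time is excluded from the measurement.
    jax.block_until_ready(fn(*args))
    start = time.perf_counter()
    for _ in range(NUM_ITERS):
        out = fn(*args)
    jax.block_until_ready(out)  # flush JAX's async dispatch queue
    elapsed_ms = (time.perf_counter() - start) / NUM_ITERS * 1000
    sizes = tuple(f"{a.size * a.dtype.itemsize / 2**20} MiB" for a in args)
    print(f"{args[0].dtype}\t{name}:\t{elapsed_ms:g} ms\tsizes: {sizes}")


devices = jax.devices()
print("Number of devices: ", len(devices))
mesh = Mesh(np.array(devices), ("x",))

# 32768 x 32768 bfloat16 operands are 2048.0 MiB each, matching the
# matmul rows in the output above.
dim = 32768
lhs = jnp.ones((dim, dim), dtype=jnp.bfloat16)
rhs = jnp.ones((dim, dim), dtype=jnp.bfloat16)

replicated = NamedSharding(mesh, P())    # full copy on every device
col = NamedSharding(mesh, P(None, "x"))  # second dim split across devices
row = NamedSharding(mesh, P("x", None))  # first dim split across devices

matmul = jax.jit(jnp.matmul)

bench("Matmul replicated", matmul,
      jax.device_put(lhs, replicated), jax.device_put(rhs, replicated))
# "colrow": lhs split on its contracting dim, rhs on its leading dim, so
# each device holds matching slices of the contraction and XLA combines
# the partial products (an assumed reading of the labels above).
bench("Matmul sharded colrow", matmul,
      jax.device_put(lhs, col), jax.device_put(rhs, row))
bench("Matmul sharded rowcol", matmul,
      jax.device_put(lhs, row), jax.device_put(rhs, col))

The two jax.block_until_ready calls are what make the timings meaningful: JAX dispatches work asynchronously, so without blocking before and after the loop the timer would mostly measure dispatch overhead rather than device execution.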