pytorch / ao

PyTorch native quantization and sparsity for training and inference
BSD 3-Clause "New" or "Revised" License

How to use float8 with SM89 hardware (e.g. NVIDIA A6000 Ada)? #1057

Open vgoklani opened 1 week ago

vgoklani commented 1 week ago

I am running torchao 0.5 and torch '2.5.0a0+b465a5843b.nv24.09' on an NVIDIA A6000 Ada card (SM89), which supports FP8.

I ran the generate.py benchmark script, first as a bf16 baseline:

python generate.py --checkpoint_path $CHECKPOINT_PATH --compile --compile_prefill --write_result /root/benchmark_results__baseline.txt

Average tokens/sec: 57.01
Average Bandwidth: 855.74 GB/s
Peak Memory Usage: 16.19 GB
Model Size: 15.01 GB

20241011143042, tok/s= 57.01, mem/s= 855.74 GB/s, peak_mem=16.19 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path /models/Meta-Llama-3-8B/consolidated.00.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8

Then with float8 weight-only quantization:

python generate.py --checkpoint_path $CHECKPOINT_PATH --compile --compile_prefill --quantization float8wo --write_result /root/benchmark_results__float8wo.txt

Average tokens/sec: 57.00
Average Bandwidth: 855.62 GB/s
Peak Memory Usage: 16.19 GB
Model Size: 15.01 GB

20241011143316, tok/s= 57.00, mem/s= 855.62 GB/s, peak_mem=16.19 GB, model_size=15.01 GB quant: float8wo, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization float8wo --checkpoint_path /models/Meta-Llama-3-8B/consolidated.00.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8

The float8wo flag does not appear to do anything: tokens/sec, bandwidth, peak memory, and model size are essentially identical to the bf16 baseline. Am I missing a step? Thanks!
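A quick sanity check on the numbers above, assuming decoding is memory-bound and that the reported model size reflects the bytes actually read per token. The reported bandwidth matches model size × tokens/sec, and moving weights from bf16 (2 bytes) to float8 (1 byte) would, at the same bandwidth, roughly halve the model size and roughly double tokens/sec. Identical model sizes in both runs therefore suggest the float8 weights were either not applied or not counted. The helper names below are hypothetical, and per-channel scale overhead is ignored.

```python
# Numbers copied from the two runs above.
runs = {
    "baseline": {"tok_s": 57.01, "bandwidth_gb_s": 855.74, "model_size_gb": 15.01},
    "float8wo": {"tok_s": 57.00, "bandwidth_gb_s": 855.62, "model_size_gb": 15.01},
}

BYTES_BF16 = 2  # bytes per bf16 weight element
BYTES_FP8 = 1   # bytes per float8 weight element


def implied_bandwidth(model_size_gb: float, tok_s: float) -> float:
    """Memory-bound decode reads every weight once per token: GB/s ~= GB * tok/s."""
    return model_size_gb * tok_s


def expected_fp8_run(model_size_gb: float, bandwidth_gb_s: float) -> tuple[float, float]:
    """Model size and tok/s if bf16 weights became float8 at the same bandwidth (scales ignored)."""
    fp8_size = model_size_gb * BYTES_FP8 / BYTES_BF16
    return fp8_size, bandwidth_gb_s / fp8_size


for name, r in runs.items():
    implied = implied_bandwidth(r["model_size_gb"], r["tok_s"])
    print(f"{name}: implied {implied:.2f} GB/s, reported {r['bandwidth_gb_s']:.2f} GB/s")

size, tok_s = expected_fp8_run(runs["baseline"]["model_size_gb"], runs["baseline"]["bandwidth_gb_s"])
print(f"expected with float8 weights: ~{size:.2f} GB, ~{tok_s:.0f} tok/s")
```

The resulting tok/s is an idealized ceiling; reading scales, dequantization cost, and kernel efficiency would all pull a real run below it.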

vgoklani commented 1 week ago

I updated torch to the nightly build and got the same results as above.

msaroufim commented 1 day ago

What kind of model is this? If it's a memory-bound workload, like small-batch inference, I wouldn't expect fp8 weight-only quantization to have a dramatic impact, and it might be best to try float8 dynamic quantization.

From Twitter, it seems you're interested in exploring distributed training; instructions for running the training benchmarks are at https://github.com/pytorch/ao/tree/main/torchao/float8

EDIT: It seems @vgoklani is interested in both distributed training and inference. They observe a speedup at small batch size with TE (Transformer Engine) and also with torch._scaled_mm, so my best guess is that this is just poor autotuning during compile, or that the right Inductor flags weren't set.