vgoklani opened this issue 1 week ago
Updated torch to the nightly, same results as above...
What kind of model is this? If it's a memory-bound workload, like small-batch-size inference, I wouldn't expect fp8 weight-only (fp8-wo) to have a dramatic impact; it might be best to try float8 dynamic quant instead.
From Twitter it seems like you're interested in exploring distributed training, in which case you can find how to run the training benchmarks at https://github.com/pytorch/ao/tree/main/torchao/float8
EDIT: It seems like @vgoklani is interested in both distributed training and inference. They're observing a speedup at small batch size with TE (Transformer Engine), and they also see a speedup with `torch._scaled_mm`, in which case my best guess is that this is just poor autotuning under compile, or that the right inductor flags weren't set.
I am running torchao 0.5 and torch '2.5.0a0+b465a5843b.nv24.09' on an NVIDIA A6000 Ada card (sm89), which supports FP8.
I ran the generate.py code from the benchmark:
The `float8wo` flag does not appear to be doing anything. Am I missing a step? Thanks!