Hi,
I use LayerNorm and RMSNorm in my training pipeline on an A100, and the PyTorch profiler showed that these functions were quite slow.
For example, I measured just the RMSNorm forward pass via time.time():
- average time using `from flash_attn.ops.triton.layer_norm import RMSNorm`: 0.00023508071899414062 s
- average time using a naive RMSNorm implementation in plain PyTorch: 6.4849853515625e-05 s
The profiler also indicated that work was being done on the CPU, which was somewhat confusing to me.
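For reference, the measurement was done roughly like this (a minimal sketch: the tensor shapes and the naive RMSNorm module are assumptions, not the exact code from my pipeline):

```python
import time
import torch

class NaiveRMSNorm(torch.nn.Module):
    # Hypothetical naive implementation; the exact module I used may differ.
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Scale by the reciprocal root-mean-square over the last dimension.
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return x * rms * self.weight

device = "cuda" if torch.cuda.is_available() else "cpu"
norm = NaiveRMSNorm(1024).to(device)
x = torch.randn(8, 512, 1024, device=device)

# Warm up, then average wall-clock time over a few iterations.
for _ in range(3):
    norm(x)
iters = 10
start = time.time()
for _ in range(iters):
    y = norm(x)
avg = (time.time() - start) / iters
print(f"average time: {avg}")
```

Note that this times the host side only; I did not call torch.cuda.synchronize() before reading the clock, which may be related to the confusing numbers.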
Do you know what the issue could be?