```python
import torch

def fused_add_rms_norm(x, residual, weight, eps):
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    x = x + residual.to(torch.float32)
    residual = x.to(orig_dtype)  # updated residual is returned in the original dtype
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    x = x.to(orig_dtype) * weight  # casts down *before* applying weight
    return x, residual
```
If the function is modified as follows, the output of `fused_add_rms_norm` will be almost the same as that of `FusedAddRMSNormKernel`; the precision can reach 1e-20.
```python
def fused_add_rms_norm(x, residual, weight, eps):
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    x = x + residual.to(torch.float32)
    residual = x.to(orig_dtype)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps) * weight.to(orig_dtype)  # applies weight while x is still fp32
    return x.to(orig_dtype), residual  # casts down only at the very end
```
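For what it's worth, one quick way to check this claim is to run both variants side by side and look at the maximum elementwise difference. A minimal sketch, assuming the two definitions above are renamed `fused_add_rms_norm_old` and `fused_add_rms_norm_new` (the renames and the bf16 shapes are mine, for illustration only):

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 4096, dtype=torch.bfloat16)
residual = torch.randn_like(x)
weight = torch.randn(4096, dtype=torch.bfloat16)

out_old, _ = fused_add_rms_norm_old(x.clone(), residual.clone(), weight, eps=1e-6)
out_new, _ = fused_add_rms_norm_new(x.clone(), residual.clone(), weight, eps=1e-6)

# Maximum elementwise deviation between the two variants.
print((out_old.float() - out_new.float()).abs().max().item())
```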
It would be better to add benchmark results for the new one @Abatom
@zhyncs we haven't set up a standard benchmark for normalization kernels, so I think we can leave it for future work.
One interesting feature to have in flashinfer would be a benchmarking class that returns bandwidth and FLOP utilization, like Proton. Ideally we could port nvbench to Python, but I don't have a concrete idea of the amount of work involved.
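Not a full nvbench port, but a minimal sketch of the kind of helper I have in mind, timing with CUDA events; here `bytes_moved` and `flops` are caller-supplied estimates (hypothetical parameters, not something measured by the helper):

```python
import torch

def bench_kernel(fn, bytes_moved, flops, warmup=10, iters=100):
    """Sketch: time `fn` with CUDA events and report achieved
    bandwidth (GB/s) and throughput (GFLOP/s). `bytes_moved` and
    `flops` are per-call estimates supplied by the caller."""
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    ms = start.elapsed_time(end) / iters  # average milliseconds per call
    return {
        "ms": ms,
        "GB/s": bytes_moved / ms / 1e6,
        "GFLOP/s": flops / ms / 1e6,
    }
```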
@yzh119 Shared memory is now used in place of global memory, and one global memory read has also been eliminated.
When `sizeof(T) == 2`, the sum of the loaded `input` and `residual` (the float `x`) is split into two halves, the high and low 16 bits, which are saved to `input` and `residual` respectively. Later, `input` and `residual` are read back and recombined into `x`, with the aim of improving the precision of the subsequent `x * rms_rcp` operation. This improves the precision from 1e-2 to 1e-3.
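If I understand the trick correctly, this is the standard lossless split of a float32 into two 16-bit halves, with the two existing 16-bit buffers reused as storage. A minimal PyTorch sketch of the bit manipulation (the function names are mine; the actual kernel presumably does this in CUDA with the 16-bit `input`/`residual` buffers):

```python
import torch

def split_f32(x):
    # Reinterpret the float32 bits as int32 so the halves can be sliced out.
    bits = x.view(torch.int32)
    hi = bits >> 16      # high 16 bits: sign, exponent, top of mantissa
    lo = bits & 0xFFFF   # low 16 bits: remainder of the mantissa
    return hi, lo

def combine_f32(hi, lo):
    # Reassemble the original float32 exactly: the round trip is lossless.
    return ((hi << 16) | lo).view(torch.float32)

x = torch.randn(8, dtype=torch.float32)
hi, lo = split_f32(x)
assert torch.equal(combine_f32(hi, lo), x)
```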