flashinfer-ai / flashinfer

FlashInfer: Kernel Library for LLM Serving
https://flashinfer.ai
Apache License 2.0

Improve the precision of the FusedAddRMSNormKernel function #587

Closed (Abatom closed 3 weeks ago)

Abatom commented 3 weeks ago

When sizeof(T) == 2, the float32 sum of the input and residual (x) is split into its high and low 16 bits, which are stored back into input and residual respectively. Later, input and residual are read out and recombined into x, so that the subsequent x * rms_rcp multiplication operates on the full-precision value.

This improves the precision from 1e-2 to 1e-3.
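The bit-splitting idea can be sketched in pure Python. Assuming T is bfloat16, the high 16 bits of a float32 are exactly a bfloat16 value, so storing the high half in input and the low half in residual lets the kernel reconstruct the float32 sum losslessly. A minimal sketch (the helper names are illustrative, not from the kernel):

```python
import struct

def split_f32(x):
    # Reinterpret the float32 value as raw bits and split it into
    # high and low 16-bit halves. The high half equals the bfloat16
    # truncation of x, so it fits in a 16-bit tensor element.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return (bits >> 16) & 0xFFFF, bits & 0xFFFF

def combine_f32(hi, lo):
    # Reassemble the exact float32 value from the two 16-bit halves.
    return struct.unpack("<f", struct.pack("<I", (hi << 16) | lo))[0]

x = 3.14159265  # hypothetical float32 sum of input and residual
hi, lo = split_f32(x)
# Round-tripping through the two halves is lossless, whereas keeping
# only the high half (a plain bfloat16) would drop the low 16 bits.
assert combine_f32(hi, lo) == struct.unpack("<f", struct.pack("<f", x))[0]
```

Keeping both halves is what lets the later x * rms_rcp multiply see the full float32 value rather than a truncated one.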

Abatom commented 3 weeks ago
import torch

def fused_add_rms_norm(x, residual, weight, eps):
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    x = x + residual.to(torch.float32)
    residual = x.to(orig_dtype)

    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    x = x.to(orig_dtype) * weight
    return x, residual

If the function is modified as follows, the output of fused_add_rms_norm will be almost identical to that of FusedAddRMSNormKernel; the precision can reach 1e-20.

def fused_add_rms_norm(x, residual, weight, eps):
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    x = x + residual.to(torch.float32)
    residual = x.to(orig_dtype)

    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps) * weight.to(orig_dtype)
    return x.to(orig_dtype), residual
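The difference between the two variants is where the downcast to orig_dtype happens: the first rounds x to the low-precision dtype before multiplying by weight (two roundings), while the second keeps the product in high precision and rounds once at the end. A small numpy sketch with float16 standing in for orig_dtype (the inputs are hypothetical, chosen to expose the double rounding):

```python
import numpy as np

# Hypothetical scalar intermediate (x * rms_rcp) and a weight that is
# exactly representable in the low-precision storage dtype (float16).
a = 1 + 1/2048 + 1e-6
w = np.float16(1.5)

# Variant 1 (original): downcast the intermediate, then multiply.
# Two roundings: once at the downcast, once after the multiply.
v1 = np.float16(np.float16(a) * w)

# Variant 2 (proposed): multiply in high precision, downcast once.
v2 = np.float16(np.float64(a) * np.float64(w))

exact = a * 1.5
# Rounding once lands strictly closer to the exact product here.
assert abs(float(v2) - exact) < abs(float(v1) - exact)
```

The single-rounding variant is always within half a ulp of the exact product, while the double-rounding variant can drift further, which is where the extra precision comes from.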

zhyncs commented 3 weeks ago

It would be better to add benchmark results for the new version @Abatom

yzh119 commented 3 weeks ago

@zhyncs we haven't set up a standard benchmark for normalization kernels, so I think we can leave it for future work.

One interesting feature to have in flashinfer is a benchmarking class that reports bandwidth and FLOP utilization, like proton. Ideally we could port nvbench to Python, but I don't have a concrete idea of the amount of work involved.

Abatom commented 3 weeks ago

@yzh119 Shared memory is now used in place of global memory, and one global memory read has also been eliminated.