NVIDIA / trt-samples-for-hackathon-cn


The USE_DIFF_OF_SQUARES implementation in the tensorrt_llm layer norm plugin can produce NaN results #88

Open xiatwhu opened 1 year ago

xiatwhu commented 1 year ago

A quick check with torch: computing the variance as E[x²] − E[x]² cancels catastrophically in float32 and can come out negative, so the rsqrt returns NaN.

```python
import torch

x = torch.full((1, 10240), 256., device='cuda:0', dtype=torch.float32)
x[0, -1] = 255

# Single-pass "difference of squares" variance: E[x^2] - E[x]^2.
# Both terms are ~65536, where float32 resolution is only ~4e-3,
# while the true variance is ~1e-4, so the subtraction can round negative.
x_square_mean = (x * x).mean()
x_mean = x.mean()
x_var = x_square_mean - x_mean * x_mean
x_inv_std = torch.rsqrt(x_var + 1e-5)

print(x_var, x_inv_std)  # negative variance -> inv_std is nan
```
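
For comparison, a two-pass variance (subtract the mean before squaring) averages only non-negative terms, so it cannot go negative regardless of rounding; a minimal torch sketch:

```python
import torch

x = torch.full((1, 10240), 256., device='cuda:0', dtype=torch.float32)
x[0, -1] = 255

# Two-pass: E[(x - E[x])^2] is a mean of squares, hence always >= 0.
x_mean = x.mean()
x_var_two_pass = ((x - x_mean) ** 2).mean()
x_inv_std = torch.rsqrt(x_var_two_pass + 1e-5)

print(x_var_two_pass, x_inv_std)  # finite, no nan
```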

The same input reproduces the NaN through `tensorrt_llm::kernels::invokeGeneralLayerNorm`:

```cpp
#include <iostream>
#include <vector>

#include <cuda_runtime.h>

#include "tensorrt_llm/kernels/layernormKernels.h"

// Copy a host vector into freshly allocated device memory.
float* malloc_device(std::vector<float>& host)
{
    float* dev = nullptr;
    cudaMalloc(&dev, host.size() * sizeof(float));
    cudaMemcpy(dev, host.data(), host.size() * sizeof(float), cudaMemcpyDefault);
    return dev;
}

int main()
{
    const int size = 10240;

    // Same data as the torch check: all 256 except one element of 255.
    auto host_input = std::vector<float>(size, 256.f);
    host_input[size - 1] = 255.f;
    auto host_gamma = std::vector<float>(size, 1.f);
    auto host_beta = std::vector<float>(size, 0.f);
    auto host_output = std::vector<float>(size, 0.f);

    float* dev_input = malloc_device(host_input);
    float* dev_gamma = malloc_device(host_gamma);
    float* dev_beta = malloc_device(host_beta);
    float* dev_output = malloc_device(host_output);

    // One token, hidden dim = size, eps = 1e-5.
    tensorrt_llm::kernels::invokeGeneralLayerNorm(
        dev_output, dev_input, dev_gamma, dev_beta, 1e-5f, 1, size);

    cudaMemcpy(host_output.data(), dev_output, size * sizeof(float), cudaMemcpyDefault);

    // Prints nan when USE_DIFF_OF_SQUARES produces a negative variance.
    std::cout << host_output[0] << std::endl;

    return 0;
}
```
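
For reference, a correct layer norm on the same input yields finite values; a small torch check (torch's own kernel is numerically stable here):

```python
import torch
import torch.nn.functional as F

x = torch.full((1, 10240), 256., device='cuda:0', dtype=torch.float32)
x[0, -1] = 255

ref = F.layer_norm(x, (10240,), eps=1e-5)
print(ref[0, 0], ref[0, -1])  # finite values, no nan
```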


Fix:
- Option 1: adding a `variance = max(variance, 0.f)` guard is enough to fix the bug; the guard goes at line 118 of tensorrt_llm/kernels/layernormKernels.cu:
```cpp
    if (threadIdx.x == 0)
    {
        mean = mean / hidden_dim;
        s_mean = mean;
        if (USE_DIFF_OF_SQUARES)
        {
            variance = (variance / hidden_dim) - (mean * mean); // Var[x] = E[x²] - E[x]²
            // added guard: rounding error cannot drive the variance negative
            variance = max(variance, 0.f);
            s_variance = rsqrtf(variance + eps);
        }
    }
    __syncthreads();