triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

[BUG] error load fp32 value from 2D tensor #4351

Closed horrorChen closed 3 months ago

horrorChen commented 3 months ago

Hello, I am training a model with a Triton kernel, but it produces NaN in the backward pass. I exported the intermediate data and found that tl.load does not preserve the values seen outside the kernel. The original values range from ~1e-9 to ~1e-12, while the loaded numbers differ by as much as 1e35, and some of them simply become 0.

I have tried several approaches to locate the bug, but I cannot reproduce the problem with torch.rand inputs. It seems the problem is related to both the shape and the values.

Some insights:

I have tested on the nightly release 3.0.0.post20240716052845. To reproduce:

import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr,
               y_ptr,
               output_ptr,
               n_elements,
               BLOCK_SIZE: tl.constexpr,
               ):
    # Each program handles one contiguous block of BLOCK_SIZE elements.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Mask out-of-range lanes so the last block does not read or write past the end.
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    tl.store(output_ptr + offsets, output, mask=mask)

def add(x: torch.Tensor, y: torch.Tensor):
    output = torch.empty_like(x)
    assert x.is_cuda and y.is_cuda and output.is_cuda
    n_elements = output.numel()
    # 1D launch grid: one program per BLOCK_SIZE chunk of the flattened tensor.
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=8)
    return output

size = [2048, 8]
# size = 2048 * 8

loaded_tensors = torch.load('tensor_20240716-231606.pth', map_location=torch.device('cuda:0'))
x = loaded_tensors['scores_grad'].reshape(size)

y = torch.rand(size, dtype=torch.float32, device='cuda')
output_torch = x + y
output_triton = add(x, y)
print(output_torch)
print(output_triton)
print(f'The maximum difference between torch and triton is '
      f'{torch.max(torch.abs(output_torch - output_triton))}')

The shape and BLOCK_SIZE can also be changed to other values.

The data is here: tensor_20240716-231606.pth.zip
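
Since the issue only shows up with the saved tensor and not with torch.rand inputs, one thing worth checking is whether the loaded tensor has the dense layout that the kernel's flat offsets assume. A minimal diagnostic sketch, reusing x and y from the repro above:

# add_kernel indexes raw memory with flat offsets, which assumes a dense,
# contiguous buffer; print the layout of the loaded tensor to verify that.
print(x.shape, x.stride(), x.is_contiguous())
# y is freshly allocated by torch.rand, so it is contiguous by construction.
print(y.stride(), y.is_contiguous())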

horrorChen commented 3 months ago

Also, after I load the data with shape [2048, 8], I can either call `tl.sum(x)` or the tensor multiply `x * y`. But the answer goes wrong when I want to get `tl.sum(x * y)`.

horrorChen commented 3 months ago

The main post has been solved. The problem comes from the backward gradient tensor not being contiguous. Calling contiguous() on the input scores_grad lets it be loaded correctly.
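
For reference, a minimal sketch of that fix applied to the repro above (same names as before):

# Make the gradient contiguous before reshaping and passing it to the kernel;
# the flat offsets in add_kernel assume a dense, contiguous buffer.
x = loaded_tensors['scores_grad'].contiguous().reshape(size)
output_triton = add(x, y)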

Also, after I load the data with shape [2048, 8], I can either call `tl.sum(x)` or the tensor multiply `x * y`. But the answer goes wrong when I want to get `tl.sum(x * y)`.

However, this problem still exists. Both `x * y` and `tl.sum(x)` show no precision error compared to torch, but `tl.sum(x * y)` has a difference on the order of 1e-16. I may open a new issue if I cannot solve it.
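
If it helps to narrow this down, below is a minimal sketch that computes a blockwise `tl.sum(x * y)` and compares the total against torch; the sum_prod_kernel / sum_prod names are made up here, not from the repro. A discrepancy on the order of 1e-16 for float32 data may simply come from a different accumulation order in the reduction rather than from a wrong load.

@triton.jit
def sum_prod_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # other=0.0 makes masked-out lanes contribute nothing to the sum
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    y = tl.load(y_ptr + offsets, mask=mask, other=0.0)
    # one partial sum of the elementwise product per program
    tl.store(out_ptr + pid, tl.sum(x * y))

def sum_prod(x, y, BLOCK_SIZE=8):
    x, y = x.contiguous(), y.contiguous()
    n_elements = x.numel()
    n_blocks = triton.cdiv(n_elements, BLOCK_SIZE)
    partials = torch.empty(n_blocks, dtype=torch.float32, device=x.device)
    sum_prod_kernel[(n_blocks,)](x, y, partials, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    return partials.sum()

print(sum_prod(x, y) - torch.sum(x * y))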

Thx.