triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

Cumulative layer norm with Triton #2360

Open Rodolphe2005 opened 1 year ago

Rodolphe2005 commented 1 year ago

I'm trying to implement a cumulative layer norm using Triton. I'm manipulating 3D tensors of shape (batch_size, channels, seq_length), and to begin with I just want to implement the "identity" and store an intermediate variable in shared memory (see the comments inside the code):

import torch
import triton
import triton.language as tl

@triton.jit
def cln_kernel(
    input_ptr,
    output_ptr,
    input_stride,
    output_stride,
    num_channels: tl.constexpr,
    seq_length: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0) # one program id per batch element
    line_offset = pid * input_stride
    batch_element = tl.zeros((num_channels, BLOCK_SIZE), dtype=tl.float32)
    for channel_idx in range(num_channels):
        channel_offset = channel_idx * seq_length
        offsets = tl.arange(0, BLOCK_SIZE)
        mask = offsets < seq_length
        channel_row = tl.load(input_ptr + line_offset + channel_offset + offsets, mask=mask)
        # The following line generates a compilation error
        # How am I supposed to write the channel_row line into the batch_element array ?
        batch_element[channel_idx] = channel_row
        tl.store(output_ptr + line_offset + channel_offset + offsets, channel_row, mask=mask)

def triton_cln(x: torch.Tensor):
    # We need to preallocate the output.
    batch_size = x.size(0)
    num_channels = x.size(1)
    seq_length = x.size(2)
    BLOCK_SIZE = triton.next_power_of_2(seq_length)
    output = torch.empty_like(x)
    assert x.is_cuda and output.is_cuda
    cln_kernel[(batch_size,)](x, output, x.stride(0), output.stride(0), num_channels, seq_length, BLOCK_SIZE=BLOCK_SIZE)
    return output
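Values inside a Triton kernel are immutable blocks, so there is no way to write a row into batch_element through subscript assignment. One common pattern for the "identity" step is to drop batch_element and the per-channel loop, and load the whole (C, T) slice of a batch element with a single masked tl.load, using 2D offsets built by broadcasting tl.arange(0, BLOCK_C)[:, None] against tl.arange(0, BLOCK_SIZE)[None, :]. The sketch below only emulates that indexing with NumPy on the CPU so it can run without a GPU; identity_copy, block_c and the contiguous (batch, channels, seq_length) layout are assumptions for illustration, not Triton API.

```python
import numpy as np

def next_power_of_2(n):
    # Same role as triton.next_power_of_2 on the host side.
    return 1 << (n - 1).bit_length()

def identity_copy(x):
    """CPU emulation of the 'identity' kernel: one program per batch element,
    each loading and storing its whole (C, T) tile with broadcast 2D offsets."""
    batch_size, num_channels, seq_length = x.shape
    flat_in = np.ascontiguousarray(x).ravel()     # flat memory, like input_ptr
    flat_out = np.empty_like(flat_in)             # flat memory, like output_ptr
    block_c = next_power_of_2(num_channels)
    block_t = next_power_of_2(seq_length)
    input_stride = num_channels * seq_length      # x.stride(0) for a contiguous tensor
    c = np.arange(block_c)[:, None]               # plays the role of tl.arange(0, BLOCK_C)[:, None]
    t = np.arange(block_t)[None, :]               # plays the role of tl.arange(0, BLOCK_SIZE)[None, :]
    mask = (c < num_channels) & (t < seq_length)  # (block_c, block_t) validity mask
    for pid in range(batch_size):                 # the launch grid (batch_size,)
        offsets = pid * input_stride + c * seq_length + t              # (block_c, block_t)
        tile = np.where(mask, flat_in[np.where(mask, offsets, 0)], 0.0)  # masked load
        flat_out[offsets[mask]] = tile[mask]                             # masked store
    return flat_out.reshape(x.shape)
```

With this layout the tile is already the (C, T) block the kernel works on, so later steps can operate on it directly instead of filling a separate buffer.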
Rodolphe2005 commented 1 year ago

I also tried replacing the line batch_element[channel_idx] = channel_row with

        for seq_idx in range(seq_length):
            batch_element[channel_idx, seq_idx] = channel_row[seq_idx]

but it didn't help (another compilation error).

jon-chuang commented 1 year ago

Hello @Rodolphe2005, could you please include a copy of the full error log?

jon-chuang commented 1 year ago

Can you write out the full copy of your algorithm in pseudocode? Currently, batch_element serves no purpose in the code snippet.

Rodolphe2005 commented 1 year ago

Traceback (most recent call last):
  File "<string>", line 21, in cln_kernel
KeyError: ('2-.-0-.-0--d6252949da17ceb5f3a278a70250af13-3b85c7bef5f0a641282f3b73af50f599-3d2aedeb40d6d81c66a42791e268f98b-3498c340fd4b6ee7805fd54b882a04f5-e1f133f98d04093da2078dfc51c36b72-b26258bf01f839199e39d64851821f26-d7c06e3b46e708006c15224aac7a1378-f585402118c8a136948ce0a49cfe122c', (torch.float32, torch.float32, 'i32', 'i32'), (64, 301, 512), (True, True, (True, False), (True, False)))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/rodolphe/.cache/pypoetry/virtualenvs/disen-XvWdDkUP-py3.10/lib/python3.10/site-packages/triton/compiler.py", line 937, in build_triton_ir
    generator.visit(fn.parse())
  File "/home/rodolphe/.cache/pypoetry/virtualenvs/disen-XvWdDkUP-py3.10/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/home/rodolphe/.pyenv/versions/3.10.11/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/rodolphe/.cache/pypoetry/virtualenvs/disen-XvWdDkUP-py3.10/lib/python3.10/site-packages/triton/compiler.py", line 183, in visit_Module
    ast.NodeVisitor.generic_visit(self, node)
  File "/home/rodolphe/.pyenv/versions/3.10.11/lib/python3.10/ast.py", line 426, in generic_visit
    self.visit(item)
  File "/home/rodolphe/.cache/pypoetry/virtualenvs/disen-XvWdDkUP-py3.10/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/home/rodolphe/.pyenv/versions/3.10.11/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/rodolphe/.cache/pypoetry/virtualenvs/disen-XvWdDkUP-py3.10/lib/python3.10/site-packages/triton/compiler.py", line 252, in visit_FunctionDef
    has_ret = self.visit_compound_statement(node.body)
  File "/home/rodolphe/.cache/pypoetry/virtualenvs/disen-XvWdDkUP-py3.10/lib/python3.10/site-packages/triton/compiler.py", line 177, in visit_compound_statement
    self.last_ret_type = self.visit(stmt)
  File "/home/rodolphe/.cache/pypoetry/virtualenvs/disen-XvWdDkUP-py3.10/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/home/rodolphe/.pyenv/versions/3.10.11/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/rodolphe/.cache/pypoetry/virtualenvs/disen-XvWdDkUP-py3.10/lib/python3.10/site-packages/triton/compiler.py", line 678, in visit_For
    self.visit_compound_statement(node.body)
  File "/home/rodolphe/.cache/pypoetry/virtualenvs/disen-XvWdDkUP-py3.10/lib/python3.10/site-packages/triton/compiler.py", line 177, in visit_compound_statement
    self.last_ret_type = self.visit(stmt)
  File "/home/rodolphe/.cache/pypoetry/virtualenvs/disen-XvWdDkUP-py3.10/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/home/rodolphe/.pyenv/versions/3.10.11/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/rodolphe/.cache/pypoetry/virtualenvs/disen-XvWdDkUP-py3.10/lib/python3.10/site-packages/triton/compiler.py", line 298, in visit_Assign
    _names += [self.visit(target)]
  File "/home/rodolphe/.cache/pypoetry/virtualenvs/disen-XvWdDkUP-py3.10/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/home/rodolphe/.pyenv/versions/3.10.11/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/rodolphe/.cache/pypoetry/virtualenvs/disen-XvWdDkUP-py3.10/lib/python3.10/site-packages/triton/compiler.py", line 617, in visit_Subscript
    assert node.ctx.__class__.__name__ == "Load"
AssertionError

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/rodolphe/.pyenv/versions/3.10.11/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/rodolphe/.pyenv/versions/3.10.11/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/rodolphe/.vscode-server/extensions/ms-python.python-2023.14.0/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
    cli.main()
  File "/home/rodolphe/.vscode-server/extensions/ms-python.python-2023.14.0/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
    run()
  File "/home/rodolphe/.vscode-server/extensions/ms-python.python-2023.14.0/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
    runpy.run_path(target, run_name="__main__")
  File "/home/rodolphe/.vscode-server/extensions/ms-python.python-2023.14.0/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/home/rodolphe/.vscode-server/extensions/ms-python.python-2023.14.0/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/home/rodolphe/.vscode-server/extensions/ms-python.python-2023.14.0/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
    exec(code, run_globals)
  File "/home/rodolphe/disen/disen/triton_temp.py", line 60, in <module>
    output_triton = triton_cln(x)
  File "/home/rodolphe/disen/disen/triton_temp.py", line 51, in triton_cln
    cln_kernel[(batch_size,)](x, output, x.stride(0), output.stride(0), num_channels, seq_length, BLOCK_SIZE=BLOCK_SIZE)
  File "<string>", line 41, in cln_kernel
  File "/home/rodolphe/.cache/pypoetry/virtualenvs/disen-XvWdDkUP-py3.10/lib/python3.10/site-packages/triton/compiler.py", line 1621, in compile
    next_module = compile(module)
  File "/home/rodolphe/.cache/pypoetry/virtualenvs/disen-XvWdDkUP-py3.10/lib/python3.10/site-packages/triton/compiler.py", line 1550, in <lambda>
    lambda src: ast_to_ttir(src, signature, configs[0], constants)),
  File "/home/rodolphe/.cache/pypoetry/virtualenvs/disen-XvWdDkUP-py3.10/lib/python3.10/site-packages/triton/compiler.py", line 962, in ast_to_ttir
    mod, _ = build_triton_ir(fn, signature, specialization, constants)
  File "/home/rodolphe/.cache/pypoetry/virtualenvs/disen-XvWdDkUP-py3.10/lib/python3.10/site-packages/triton/compiler.py", line 942, in build_triton_ir
    raise CompilationError(fn.src, node) from e
triton.compiler.CompilationError: at 18:8:
def cln_kernel(
    input_ptr,
    output_ptr,
    input_stride,
    output_stride,
    num_channels: tl.constexpr,
    seq_length: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0) # one program id per batch element
    line_offset = pid * input_stride
    batch_element = tl.zeros((num_channels, seq_length), dtype=tl.float32)
    for channel_idx in range(0, num_channels):
        channel_offset = channel_idx * seq_length
        offsets = tl.arange(0, BLOCK_SIZE)
        mask = offsets < seq_length
        channel_row = tl.load(input_ptr + line_offset + channel_offset + offsets, mask=mask)
        batch_element[channel_idx] = channel_row
        ^
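The last frames of the traceback point at the cause: visit_Subscript in this version of triton/compiler.py asserts that the subscript's context is "Load", meaning only reads such as x[i] are handled. Python's standard ast module shows that the failing line is a subscript in a Store context (the second attempt, batch_element[channel_idx, seq_idx] = ..., parses the same way):

```python
import ast

# Triton builds the kernel by walking the Python AST of its source. Parse the
# failing statement and look at the context attached to its target.
failing = ast.parse("batch_element[channel_idx] = channel_row").body[0]
target = failing.targets[0]
print(type(failing).__name__)      # Assign
print(type(target).__name__)       # Subscript
print(type(target.ctx).__name__)   # Store

# Reading through a subscript, by contrast, produces a Load context.
reading = ast.parse("value = channel_row[seq_idx]").body[0]
print(type(reading.value.ctx).__name__)  # Load
```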
jon-chuang commented 1 year ago

Can you write out the full copy of your algorithm in pseudocode? Currently, batch_element serves no purpose in the code snippet.

If you could provide this, that would be great. If your idea is to write out the accumulated statistics (for example, for an RNN use case), they may have to go into global memory.

I would only store the per-instance (and possibly per-channel) statistics (depending on whether you are computing channel-wise or cross-channel layer norm).

I can't imagine why you would want to persistently materialize the rows per channel. Where would they be used for further computation?

Rodolphe2005 commented 1 year ago

Thank you for your interest. Here is what I want to do:

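import torch

def fast_cln(input, lengths):
    # input: (B, C, T); lengths: (T,) holding 1, 2, ..., T
    b, c, t = input.shape
    # Cumulative mean over channels and past time steps, shape (B, 1, T)
    z = temporal_cumulative_mean(input, lengths).view(b, 1, t)
    # Cumulative mean of squares, shape (B, 1, T)
    gamma = temporal_cumulative_mean(input.pow(2), lengths).view(b, 1, t)
    # Cumulative variance is gamma - z^2; the 1e-8 keeps std strictly positive
    std = (gamma - z.pow(2) + 1e-8).sqrt()
    x = (input - z) / std
    return x

def temporal_cumulative_mean(x: torch.Tensor, lengths: torch.Tensor):
    # Average over channels (dim=1), then a running sum over time divided by the step count
    return torch.cumsum(torch.mean(x, dim=1), dim=1) / lengths

B, C, T = (80, 64, 301)
lengths = torch.arange(1, T+1, device="cuda")
input = torch.randn((B, C, T), device="cuda")
fast_cln(input, lengths)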
jon-chuang commented 1 year ago

In this case, T is static, so shouldn't you be able to just run layer norm over the entire range of T? Or is the cumulative mean an exponential moving average, with T in fact dynamic?

Rodolphe2005 commented 1 year ago

This normalization comes from this article: https://arxiv.org/pdf/1809.07454.pdf (p. 12, equation (12)). It is causal, so at any point in time I can't use values from the future. That's why a standard layer norm over the channel and time dimensions won't work.
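The causality requirement can be checked numerically. The sketch below ports fast_cln to NumPy (an assumption: same formulas, with lengths = 1, ..., T, run on the CPU) and, for contrast, an ordinary layer norm over the channel and time dimensions. Replacing only the samples from time step cut onwards leaves the cumulative version unchanged before cut, while the global version changes everywhere. The names cumulative_layer_norm, global_layer_norm and cut are illustrative.

```python
import numpy as np

def cumulative_layer_norm(x, eps=1e-8):
    """NumPy port of fast_cln for x of shape (B, C, T), with lengths = 1..T."""
    t = x.shape[2]
    lengths = np.arange(1, t + 1)
    z = np.cumsum(x.mean(axis=1), axis=1) / lengths              # (B, T) cumulative mean
    gamma = np.cumsum((x ** 2).mean(axis=1), axis=1) / lengths   # (B, T) cumulative mean of squares
    std = np.sqrt(gamma - z ** 2 + eps)
    return (x - z[:, None, :]) / std[:, None, :]

def global_layer_norm(x, eps=1e-8):
    """Ordinary (non-causal) layer norm over channels and time, for comparison."""
    mean = x.mean(axis=(1, 2), keepdims=True)
    var = x.var(axis=(1, 2), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4, 10))
cut = 6
x_future_changed = x.copy()
x_future_changed[:, :, cut:] = rng.standard_normal((2, 4, 10 - cut))  # only t >= cut changes

causal_same = np.allclose(cumulative_layer_norm(x)[:, :, :cut],
                          cumulative_layer_norm(x_future_changed)[:, :, :cut])
global_same = np.allclose(global_layer_norm(x)[:, :, :cut],
                          global_layer_norm(x_future_changed)[:, :, :cut])
print(causal_same, global_same)  # True False
```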

jon-chuang commented 1 year ago

Right. Might I suggest simply materializing the statistics (mean, var) per channel instead of the entire row, and furthermore storing only the latest accumulated values (not the historical ones)? These can then be applied to a given channel at time T and reused in the computation for T+1.

You would then run a loop over the range of T inside each kernel. In every loop iteration, you will load the un-normalized rows for T, compute the new cumulative stats, apply them to obtain the normalized rows, and store the normalized rows for T.

Further, may I suggest trying to parallelize the kernel across channels via the use of the program_id?
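The loop described above can be sketched in NumPy for a single batch element. For the cross-channel normalization in fast_cln, the only state carried from one time step to the next is two scalars; this sketch keeps running sums rather than running means (dividing by $N T$ at each step yields the same $z_T$ and $\gamma_T$). The function name is hypothetical and the code runs on the CPU; it illustrates the loop structure only, not Triton code.

```python
import numpy as np

def streaming_cln_one_sequence(x, eps=1e-8):
    """cLN for one batch element x of shape (C, T), carrying only running sums.

    A single loop over time: each step reads the C values at time t, updates
    the running statistics, and writes the normalized column for t.
    """
    num_channels, seq_length = x.shape
    out = np.empty_like(x, dtype=np.float64)
    running_sum = 0.0      # sum of f[i, t'] over all channels and t' <= t
    running_sq_sum = 0.0   # sum of f[i, t']**2 over all channels and t' <= t
    for t in range(seq_length):
        column = x[:, t]                        # the C un-normalized values at time t
        running_sum += column.sum()
        running_sq_sum += (column ** 2).sum()
        count = num_channels * (t + 1)          # N * T in the thread's notation
        mean = running_sum / count              # z_T
        var = running_sq_sum / count - mean ** 2  # gamma_T - z_T**2
        out[:, t] = (column - mean) / np.sqrt(var + eps)
    return out
```

Parallelizing over the batch then amounts to running this loop independently for each batch element.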

Rodolphe2005 commented 1 year ago

I have an easy way to compute the running mean and variance, as follows:

I denote by $(f_{i,t})_{i,t}$ the input, with $i$ the channel index ($1\leq i\leq N$) and $t$ the time index.

I denote by $(z_t)_t$ the temporal cumulative mean, defined by:

$$z_T = \frac{1}{NT}\sum_{i,\, t\leq T} f_{i,t}$$

I also define $\gamma_T$:

$$\gamma_T = \frac{1}{NT}\sum_{i,\, t\leq T} f_{i,t}^2$$

I want to compute the cumulative variance, defined by:

$$V_T = \frac{1}{NT}\sum_{i,\, t\leq T} (f_{i,t} - z_T)^2$$

Expanding the square and using the definition of $z_T$ gives

$$V_T = \gamma_T - 2 z_T \cdot \frac{1}{NT}\sum_{i,\, t\leq T} f_{i,t} + z_T^2 = \gamma_T - z_T^2$$

And, using the definitions of $z_T$ and $\gamma_T$, both can be computed by recurrence:

$$z_{T+1} = \frac{T}{T+1}\, z_T + \frac{1}{N(T+1)}\sum_i f_{i,T+1}$$

$$\gamma_{T+1} = \frac{T}{T+1}\, \gamma_T + \frac{1}{N(T+1)}\sum_i f_{i,T+1}^2$$
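These identities are easy to check numerically. The NumPy sketch below compares the direct definitions of $z_T$, $\gamma_T$ and $V_T$ with the two recurrences on random data; the sizes N = 5 and T_max = 12 are arbitrary, and starting the recurrence at T = 0 works because the weight $T/(T+1)$ is zero there.

```python
import numpy as np

rng = np.random.default_rng(0)
N, T_max = 5, 12                     # arbitrary sizes for the check
f = rng.standard_normal((N, T_max))  # column t - 1 holds f_{i,t}

# Direct definitions, for T = 1, ..., T_max
def z_direct(T):
    return f[:, :T].sum() / (N * T)

def gamma_direct(T):
    return (f[:, :T] ** 2).sum() / (N * T)

def V_direct(T):
    return ((f[:, :T] - z_direct(T)) ** 2).sum() / (N * T)

# Recurrences, starting from T = 0 (no samples yet)
z_T, gamma_T = 0.0, 0.0
z_rec, gamma_rec = [], []
for T in range(T_max):
    new_col = f[:, T]   # the values f_{i,T+1}
    z_T = T / (T + 1) * z_T + new_col.sum() / (N * (T + 1))
    gamma_T = T / (T + 1) * gamma_T + (new_col ** 2).sum() / (N * (T + 1))
    z_rec.append(z_T)           # z_rec[T] holds z_{T+1}
    gamma_rec.append(gamma_T)   # gamma_rec[T] holds gamma_{T+1}

for T in range(1, T_max + 1):
    assert np.isclose(z_rec[T - 1], z_direct(T))
    assert np.isclose(gamma_rec[T - 1], gamma_direct(T))
    assert np.isclose(V_direct(T), gamma_rec[T - 1] - z_rec[T - 1] ** 2)
print("recurrences match the direct definitions")
```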

So, I agree with you that we could loop over time and compute only the current mean and variance.

What I don't know is how to do it in triton.

If I understand correctly, I first create a grid of B blocks (B being the batch size), because the computation is independent for each batch element and therefore fully parallelizable along the batch dimension. Then, inside the kernel, I'm manipulating a (C, T) tensor. That's where I'm a bit lost: how do I create and manipulate $z_T$ and $\gamma_T$? Will this kernel need a double loop (one over the channel index and one over the time index)?
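One possible structure, sketched under assumptions: a grid of B programs, and inside each program a single loop over time. The channel dimension needs no loop of its own, because the C values at time t can be fetched as one masked vector (channel index times seq_length, plus t, for a contiguous (B, C, T) input) and reduced to a scalar; $z_T$ and $\gamma_T$ are then plain scalars updated in the loop. The NumPy code below emulates that flat-memory indexing on the CPU. cln_kernel_emulated, cln_emulated and block_c are hypothetical names; in Triton, the masked load, the reduction and the store would presumably map to tl.load with a mask, tl.sum and tl.store, which this sketch does not test.

```python
import numpy as np

def next_power_of_2(n):
    return 1 << (n - 1).bit_length()

def cln_kernel_emulated(flat_in, flat_out, pid, num_channels, seq_length, eps=1e-8):
    """What one program (one batch element) could do: a single loop over time."""
    block_c = next_power_of_2(num_channels)
    channels = np.arange(block_c)              # plays the role of tl.arange(0, BLOCK_C)
    mask = channels < num_channels             # padded lanes are masked out
    base = pid * num_channels * seq_length     # start of this batch element (contiguous layout)
    running_sum = 0.0                          # equals N * T * z_T
    running_sq_sum = 0.0                       # equals N * T * gamma_T
    for t in range(seq_length):
        offsets = base + channels * seq_length + t                      # one strided column
        col = np.where(mask, flat_in[np.where(mask, offsets, 0)], 0.0)  # masked load, 0 elsewhere
        running_sum += col.sum()               # padded zeros do not change the sums
        running_sq_sum += (col ** 2).sum()
        count = num_channels * (t + 1)
        z_t = running_sum / count
        var_t = running_sq_sum / count - z_t ** 2
        normalized = (col - z_t) / np.sqrt(var_t + eps)
        flat_out[offsets[mask]] = normalized[mask]                      # masked store

def cln_emulated(x):
    batch_size, num_channels, seq_length = x.shape
    flat_in = np.ascontiguousarray(x, dtype=np.float64).ravel()
    flat_out = np.empty_like(flat_in)
    for pid in range(batch_size):              # the launch grid (batch_size,)
        cln_kernel_emulated(flat_in, flat_out, pid, num_channels, seq_length)
    return flat_out.reshape(x.shape)
```

Computing the variance as $\gamma_T - z_T^2$ can lose precision when the mean is large compared to the spread of the data; Welford's online algorithm is a common alternative if that becomes an issue.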