Rodolphe2005 opened this issue 1 year ago (status: Open)
I also tried to replace the line `batch_element[channel_idx] = channel_row` with

```python
for seq_idx in range(seq_length):
    batch_element[channel_idx, seq_idx] = channel_row[seq_idx]
```

but it didn't help (same compilation error again).
Hello @Rodolphe2005, could you please include a copy of the full error log?
Can you write out your full algorithm in pseudocode? Currently, `batch_element` serves no purpose in the code snippet.
```
Traceback (most recent call last):
  File "<string>", line 21, in cln_kernel
KeyError: ('2-.-0-.-0--d6252949da17ceb5f3a278a70250af13-3b85c7bef5f0a641282f3b73af50f599-3d2aedeb40d6d81c66a42791e268f98b-3498c340fd4b6ee7805fd54b882a04f5-e1f133f98d04093da2078dfc51c36b72-b26258bf01f839199e39d64851821f26-d7c06e3b46e708006c15224aac7a1378-f585402118c8a136948ce0a49cfe122c', (torch.float32, torch.float32, 'i32', 'i32'), (64, 301, 512), (True, True, (True, False), (True, False)))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/rodolphe/.cache/pypoetry/virtualenvs/disen-XvWdDkUP-py3.10/lib/python3.10/site-packages/triton/compiler.py", line 937, in build_triton_ir
    generator.visit(fn.parse())
  File "/home/rodolphe/.cache/pypoetry/virtualenvs/disen-XvWdDkUP-py3.10/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/home/rodolphe/.pyenv/versions/3.10.11/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/rodolphe/.cache/pypoetry/virtualenvs/disen-XvWdDkUP-py3.10/lib/python3.10/site-packages/triton/compiler.py", line 183, in visit_Module
    ast.NodeVisitor.generic_visit(self, node)
  File "/home/rodolphe/.pyenv/versions/3.10.11/lib/python3.10/ast.py", line 426, in generic_visit
    self.visit(item)
  File "/home/rodolphe/.cache/pypoetry/virtualenvs/disen-XvWdDkUP-py3.10/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/home/rodolphe/.pyenv/versions/3.10.11/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/rodolphe/.cache/pypoetry/virtualenvs/disen-XvWdDkUP-py3.10/lib/python3.10/site-packages/triton/compiler.py", line 252, in visit_FunctionDef
    has_ret = self.visit_compound_statement(node.body)
  File "/home/rodolphe/.cache/pypoetry/virtualenvs/disen-XvWdDkUP-py3.10/lib/python3.10/site-packages/triton/compiler.py", line 177, in visit_compound_statement
    self.last_ret_type = self.visit(stmt)
  File "/home/rodolphe/.cache/pypoetry/virtualenvs/disen-XvWdDkUP-py3.10/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/home/rodolphe/.pyenv/versions/3.10.11/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/rodolphe/.cache/pypoetry/virtualenvs/disen-XvWdDkUP-py3.10/lib/python3.10/site-packages/triton/compiler.py", line 678, in visit_For
    self.visit_compound_statement(node.body)
  File "/home/rodolphe/.cache/pypoetry/virtualenvs/disen-XvWdDkUP-py3.10/lib/python3.10/site-packages/triton/compiler.py", line 177, in visit_compound_statement
    self.last_ret_type = self.visit(stmt)
  File "/home/rodolphe/.cache/pypoetry/virtualenvs/disen-XvWdDkUP-py3.10/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/home/rodolphe/.pyenv/versions/3.10.11/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/rodolphe/.cache/pypoetry/virtualenvs/disen-XvWdDkUP-py3.10/lib/python3.10/site-packages/triton/compiler.py", line 298, in visit_Assign
    _names += [self.visit(target)]
  File "/home/rodolphe/.cache/pypoetry/virtualenvs/disen-XvWdDkUP-py3.10/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/home/rodolphe/.pyenv/versions/3.10.11/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/rodolphe/.cache/pypoetry/virtualenvs/disen-XvWdDkUP-py3.10/lib/python3.10/site-packages/triton/compiler.py", line 617, in visit_Subscript
    assert node.ctx.__class__.__name__ == "Load"
AssertionError

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/rodolphe/.pyenv/versions/3.10.11/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/rodolphe/.pyenv/versions/3.10.11/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/rodolphe/.vscode-server/extensions/ms-python.python-2023.14.0/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
    cli.main()
  File "/home/rodolphe/.vscode-server/extensions/ms-python.python-2023.14.0/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
    run()
  File "/home/rodolphe/.vscode-server/extensions/ms-python.python-2023.14.0/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
    runpy.run_path(target, run_name="__main__")
  File "/home/rodolphe/.vscode-server/extensions/ms-python.python-2023.14.0/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/home/rodolphe/.vscode-server/extensions/ms-python.python-2023.14.0/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/home/rodolphe/.vscode-server/extensions/ms-python.python-2023.14.0/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
    exec(code, run_globals)
  File "/home/rodolphe/disen/disen/triton_temp.py", line 60, in <module>
    output_triton = triton_cln(x)
  File "/home/rodolphe/disen/disen/triton_temp.py", line 51, in triton_cln
    cln_kernel[(batch_size,)](x, output, x.stride(0), output.stride(0), num_channels, seq_length, BLOCK_SIZE=BLOCK_SIZE)
  File "<string>", line 41, in cln_kernel
  File "/home/rodolphe/.cache/pypoetry/virtualenvs/disen-XvWdDkUP-py3.10/lib/python3.10/site-packages/triton/compiler.py", line 1621, in compile
    next_module = compile(module)
  File "/home/rodolphe/.cache/pypoetry/virtualenvs/disen-XvWdDkUP-py3.10/lib/python3.10/site-packages/triton/compiler.py", line 1550, in <lambda>
    lambda src: ast_to_ttir(src, signature, configs[0], constants)),
  File "/home/rodolphe/.cache/pypoetry/virtualenvs/disen-XvWdDkUP-py3.10/lib/python3.10/site-packages/triton/compiler.py", line 962, in ast_to_ttir
    mod, _ = build_triton_ir(fn, signature, specialization, constants)
  File "/home/rodolphe/.cache/pypoetry/virtualenvs/disen-XvWdDkUP-py3.10/lib/python3.10/site-packages/triton/compiler.py", line 942, in build_triton_ir
    raise CompilationError(fn.src, node) from e
triton.compiler.CompilationError: at 18:8:
def cln_kernel(
    input_ptr,
    output_ptr,
    input_stride,
    output_stride,
    num_channels: tl.constexpr,
    seq_length: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)  # one pid per batch element
    line_offset = pid * input_stride
    batch_element = tl.zeros((num_channels, seq_length), dtype=tl.float32)
    for channel_idx in range(0, num_channels):
        channel_offset = channel_idx * seq_length
        offsets = tl.arange(0, BLOCK_SIZE)
        mask = offsets < seq_length
        channel_row = tl.load(input_ptr + line_offset + channel_offset + offsets, mask=mask)
        batch_element[channel_idx] = channel_row
        ^
```
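(The `assert node.ctx.__class__.__name__ == "Load"` in `visit_Subscript` points at the proximate cause: in this Triton version, subscripting a tensor appears to be supported only for reads, so an assignment like `batch_element[channel_idx] = channel_row` cannot be compiled.)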
> Can you write out your full algorithm in pseudocode? Currently, `batch_element` serves no purpose in the code snippet.
If you could provide this, that would be great. If your idea is to write out the accumulated statistics, e.g. for an RNN use case, they may have to go into global memory.
I would only store the per-instance (and possibly per-channel) statistics, depending on whether you are computing channel-wise or cross-channel layer norm.
I can't imagine why you would want to persistently materialize the rows per channel. Where would they be used for further computation?
Thank you for your interest. Here is what I want to do:

```python
import torch

def fast_cln(input, lengths):
    b, c, t = input.shape
    z = temporal_cumulative_mean(input, lengths).view(b, 1, t)
    gamma = temporal_cumulative_mean(input.pow(2), lengths).view(b, 1, t)
    std = (gamma - z.pow(2) + 1e-8).sqrt()
    x = (input - z) / std
    return x

def temporal_cumulative_mean(x: torch.Tensor, lengths: torch.Tensor):
    return torch.cumsum(torch.mean(x, dim=1), dim=1) / lengths

B, C, T = (80, 64, 301)
lengths = torch.arange(1, T + 1, device="cuda")
input = torch.randn((B, C, T), device="cuda")
fast_cln(input, lengths)
```
In this case, T is static. So you should just be able to run layer norm on the entire range of T? Or is the cumulative mean exponential averaging, while T is in fact dynamic?
This normalization comes from this article: https://arxiv.org/pdf/1809.07454.pdf, p. 12, equation (12). The normalization is causal, so at any point in time I can't use values from the future. That's why a standard layer norm over the channel and time dimensions won't work.
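For reference, the cumulative layer normalization (cLN) from that paper reads, up to notation, roughly as follows (my paraphrase, with $f_k$ the feature vector at time $k$ over $N$ channels and $\gamma, \beta$ trainable parameters):

$$\mathrm{cLN}(f_k) = \frac{f_k - \mathrm{E}[f_{t\leq k}]}{\sqrt{\mathrm{Var}[f_{t\leq k}] + \epsilon}} \odot \gamma + \beta, \qquad \mathrm{E}[f_{t\leq k}] = \frac{1}{Nk}\sum_{i,\,t\leq k} f_{i,t}$$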
Right. Might I suggest simply materializing the statistics (mean, var) per channel instead of the entire row, and furthermore storing only the latest accumulated values (not the history)? These can then be applied to a given channel at time T and reused in the computation for T+1.
You would then run a loop over the range of T inside each kernel. In every loop iteration, you will load the un-normalized rows for T, compute the new cumulative stats, apply them to obtain the normalized rows, and store the normalized rows for T.
Further, may I suggest trying to parallelize the kernel across channels via the use of `program_id`?
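A minimal sketch of that loop structure, assuming a contiguous (B, C, T) layout, one program per batch element, channels vectorized rather than looped, and cross-channel statistics; all names and the `BLOCK_C` padding are mine, so treat it as a starting point rather than a drop-in kernel:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def cln_kernel(
    input_ptr, output_ptr,
    batch_stride,
    num_channels: tl.constexpr,
    seq_length: tl.constexpr,
    BLOCK_C: tl.constexpr,  # next power of two >= num_channels (tl.arange needs a power of two)
    eps: tl.constexpr,
):
    pid = tl.program_id(axis=0)  # one program per batch element
    base = pid * batch_stride
    ch = tl.arange(0, BLOCK_C)
    ch_mask = ch < num_channels

    running_sum = 0.0  # sum of f over all channels and all t' <= t
    running_sq = 0.0   # sum of f**2 over all channels and all t' <= t
    for t in range(0, seq_length):
        # contiguous (C, T) layout per batch element: element (i, t) sits at i * seq_length + t
        f = tl.load(input_ptr + base + ch * seq_length + t, mask=ch_mask, other=0.0)
        running_sum += tl.sum(f, axis=0)
        running_sq += tl.sum(f * f, axis=0)
        count = num_channels * (t + 1)
        mean = running_sum / count               # z_t
        var = running_sq / count - mean * mean   # gamma_t - z_t^2
        out = (f - mean) / tl.sqrt(var + eps)
        tl.store(output_ptr + base + ch * seq_length + t, out, mask=ch_mask)
```

Only the latest running sums are carried across iterations, which is the "store only the latest accumulated" idea above. A hypothetical launch, assuming a contiguous (B, C, T) CUDA tensor `x`:

```python
x = torch.randn((80, 64, 301), device="cuda")
out = torch.empty_like(x)
cln_kernel[(x.shape[0],)](
    x, out, x.stride(0), x.shape[1], x.shape[2],
    BLOCK_C=triton.next_power_of_2(x.shape[1]), eps=1e-8,
)
```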
I have an easy way to compute the running mean and variance, as follows:
I denote by $(f_{i,t})_{i,t}$ the input, with $i$ the channel index ($1 \leq i \leq N$) and $t$ the time index.
I denote by $(z_t)_t$ the temporal cumulative mean, defined by:
$$z_T = \frac{1}{NT}\sum_{i,\,t\leq T} f_{i,t}$$
I also define $\gamma_T$:
$$\gamma_T = \frac{1}{NT}\sum_{i,\,t\leq T} f_{i,t}^2$$
I want to compute the cumulative variance, defined by:
$$V_T = \frac{1}{NT}\sum_{i,\,t\leq T} (f_{i,t} - z_T)^2$$
Through some computations, one can derive that
$$V_T = \gamma_T - z_T^2$$
And, using the formulas for $z_T$ and $\gamma_T$, one can compute them by recurrence:
$$z_{T+1} = \frac{T}{T+1}\times z_T + \frac{1}{N(T+1)}\sum_i f_{i,T+1}$$
$$\gamma_{T+1} = \frac{T}{T+1}\times \gamma_T + \frac{1}{N(T+1)}\sum_i f_{i,T+1}^2$$
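As a quick numerical sanity check of this recurrence (a throwaway PyTorch sketch; the tensor names are mine):

```python
import torch

N, T = 4, 7
f = torch.randn(N, T)

# direct definition: z_T = (1/(N*T)) * sum over i and t <= T of f[i, t]
counts = torch.arange(1, T + 1, dtype=f.dtype)
z_direct = f.cumsum(dim=1).sum(dim=0) / (N * counts)

# recurrence: z_{T+1} = T/(T+1) * z_T + (1/(N*(T+1))) * sum_i f[i, T+1]
z_rec = torch.empty(T)
z_rec[0] = f[:, 0].sum() / N
for t in range(1, T):
    z_rec[t] = t / (t + 1) * z_rec[t - 1] + f[:, t].sum() / (N * (t + 1))

assert torch.allclose(z_direct, z_rec)  # gamma_T works identically, applied to f**2
```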
So, I agree with you that we could loop over time and compute only the current mean and variance.
What I don't know is how to do it in Triton.
If I understand correctly, I first create a grid of B blocks (B being the batch size), because the computations are done independently for each batch element; it's fully parallelizable along the batch dimension. Then, inside the kernel, I'm manipulating a (C, T) tensor. That's where I'm a bit lost: how do I create and manipulate $z_T$ and $\gamma_T$? Will this kernel need a double loop (one over the channel index and one over the time index)?
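(One common Triton pattern that avoids the double loop: keep only the time loop and cover the channel dimension with a masked `tl.arange` vector, reducing over it with `tl.sum` at each step, as in the sketch above.)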
I'm trying to implement a cumulative layer norm using Triton. I'm manipulating 3D tensors of shape (batch_size, channels, seq_length), and to begin with I just want to implement the "identity" and store an intermediate variable in shared memory (see the comments inside the code):
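For what such an identity pass might look like without materializing a (C, T) buffer, here is a minimal sketch under the same (batch_size, channels, seq_length) layout; the names are mine, and note that Triton manages shared memory itself, so per-row values can simply stay in registers between the load and the store:

```python
import triton
import triton.language as tl

@triton.jit
def identity_kernel(
    input_ptr, output_ptr,
    batch_stride,
    num_channels: tl.constexpr,
    seq_length: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,  # next power of two >= seq_length
):
    pid = tl.program_id(axis=0)  # one program per batch element
    base = pid * batch_stride
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < seq_length
    for channel_idx in range(0, num_channels):
        row = tl.load(input_ptr + base + channel_idx * seq_length + offsets, mask=mask)
        # no intermediate (C, T) buffer: the row goes straight back out
        tl.store(output_ptr + base + channel_idx * seq_length + offsets, row, mask=mask)
```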