sustcsonglin / flash-linear-attention

Efficient implementations of state-of-the-art linear attention models in Pytorch and Triton
MIT License

benchmark_training_throughput and bugs #32

Closed · rakkit closed 3 months ago

rakkit commented 4 months ago

Hi, thanks for your great work. I ran benchmarks to test the throughput and memory usage of all models with flash-linear-attention/benchmarks/benchmark_training_throughput.py. Unfortunately, some of them failed.

WORK:

FAILED

(For Mamba, I noticed that FLA does not require mamba_ssm and causal_conv1d, but it also does not raise any warning that it is running the slow forward path. A sketch of such a check is shown below.)
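A minimal sketch of the kind of import-time warning being suggested (module names are taken from the comment above; FLA's actual fallback logic may differ):

```python
import warnings

try:
    # Fused CUDA kernels used by the fast Mamba path.
    import causal_conv1d  # noqa: F401
    import mamba_ssm      # noqa: F401
except ImportError:
    warnings.warn(
        "mamba_ssm / causal_conv1d not found: Mamba falls back to the slow "
        "PyTorch forward path, so its benchmark numbers may not be representative."
    )
```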

Here are the benchmark results; the error info for the failed runs is attached at the end.

Environment:

NVIDIA A100-SXM4-40GB
NVIDIA-SMI 550.54.15             
Driver Version: 550.54.15      
CUDA Version: 12.4 
torch                    2.3.1
accelerate               0.32.1
transformers             4.42.4
triton                   2.3.1
nvidia-cublas-cu12       12.1.3.1
nvidia-cuda-cupti-cu12   12.1.105
nvidia-cuda-nvrtc-cu12   12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12        8.9.2.26
nvidia-cufft-cu12        11.0.2.54
nvidia-curand-cu12       10.3.2.106
nvidia-cusolver-cu12     11.4.5.107
nvidia-cusparse-cu12     12.1.0.106
nvidia-nccl-cu12         2.20.5
nvidia-nvjitlink-cu12    12.5.82
nvidia-nvtx-cu12         12.1.105
| Model | Batch-Size | Seq_len | Max-memory (GB) | Throughput (tokens/s) |
|---|---|---|---|---|
| GLA | 8 | 512 | 14.77 | 14959.27 |
| GLA | 8 | 1024 | 22.88 | 17467.75 |
| GLA | 8 | 2048 | OOM | |
| GSA | 8 | 512 | 16.07 | 14960.99 |
| GSA | 8 | 1024 | 24.35 | 17674.00 |
| GSA | 8 | 2048 | OOM | |
| HGRN | 8 | 512 | 16.9 | 16382.00 |
| HGRN | 8 | 1024 | 26.15 | 19500.58 |
| HGRN | 8 | 2048 | OOM | |
| retnet | 8 | 512 | 15.13 | 13369.01 |
| retnet | 8 | 1024 | 22.66 | 15437.14 |
| retnet | 8 | 2048 | 37.75 | 16445.58 |
| transformer | 8 | 512 | 13.98 | 17851.42 |
| transformer | 8 | 1024 | 20.30 | 20994.52 |
| transformer | 8 | 2048 | 32.96 | 21807.02 |
| Mamba | 8 | 512 | 15.40 | 10385.55 |
| Mamba | 8 | 1024 | 22.72 | 11230.94 |
| Mamba | 8 | 2048 | 37.36 | 12151.64 |
| Samba | 8 | 512 | 13.77 | 16475.11 |
| Samba | 8 | 1024 | 19.56 | 18470.86 |
| Samba | 8 | 2048 | 31.18 | 19850.30 |

Delta-Net

Traceback (most recent call last):
  File "system/fla-bench/flash-linear-attention/benchmarks/benchmark_training_throughput.py", line 100, in <module>
    profile(args.name, args.batch_size, args.seq_len, args.warmup_steps, args.steps)
  File "system/fla-bench/flash-linear-attention/benchmarks/benchmark_training_throughput.py", line 66, in profile
    outputs = model(tokens, labels=tokens)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/accelerate/utils/operations.py", line 819, in forward
    return model_forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/accelerate/utils/operations.py", line 807, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/models/delta_net/modeling_delta_net.py", line 385, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/models/delta_net/modeling_delta_net.py", line 263, in forward
    hidden_states, attentions, past_key_values = layer(
                                                 ^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/models/delta_net/modeling_delta_net.py", line 117, in forward
    hidden_states, attentions, past_key_values = self.attn(
                                                 ^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/layers/delta_net.py", line 228, in forward
    o, recurrent_state = chunk_delta_rule(q, k, v, beta, self.chunk_size, state, output_final_state=use_cache)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/ops/delta_rule/chunk.py", line 543, in chunk_delta_rule
    o, final_state = ChunkDeltaRuleFunction.apply(q, k, v, beta, BT,  initial_state, output_final_state)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/autograd/function.py", line 598, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/cuda/amp/autocast_mode.py", line 115, in decorate_fwd
    return fwd(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/utils.py", line 11, in wrapper
    return fn(ctx,
           ^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/ops/delta_rule/chunk.py", line 503, in forward
    h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state)        
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/ops/delta_rule/chunk.py", line 394, in chunk_fwd_h_fn
    assert BK <= 256, "current kernel does not support head dimension larger than 256."
           ^^^^^^^^^
AssertionError: current kernel does not support head dimension larger than 256.
----------------------------------------------

HGRN2

Traceback (most recent call last):
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 1222, in ast_to_ttir
    generator.visit(fn.parse())
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 1105, in visit
    ret = super().visit(node)
          ^^^^^^^^^^^^^^^^^^^
  File "/p/software/machine/stages/2024/software/Python/3.11.3-GCCcore-12.3.0/lib/python3.11/ast.py", line 418, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 303, in visit_Module
    ast.NodeVisitor.generic_visit(self, node)
  File "/p/software/machine/stages/2024/software/Python/3.11.3-GCCcore-12.3.0/lib/python3.11/ast.py", line 426, in generic_visit
    self.visit(item)
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 1105, in visit
    ret = super().visit(node)
          ^^^^^^^^^^^^^^^^^^^
  File "/p/software/machine/stages/2024/software/Python/3.11.3-GCCcore-12.3.0/lib/python3.11/ast.py", line 418, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 376, in visit_FunctionDef
    self.visit_compound_statement(node.body)
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 298, in visit_compound_statement
    ret_type = self.visit(stmt)
               ^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 1105, in visit
    ret = super().visit(node)
          ^^^^^^^^^^^^^^^^^^^
  File "/p/software/machine/stages/2024/software/Python/3.11.3-GCCcore-12.3.0/lib/python3.11/ast.py", line 418, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 885, in visit_For
    self.visit_compound_statement(node.body)
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 298, in visit_compound_statement
    ret_type = self.visit(stmt)
               ^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 1105, in visit
    ret = super().visit(node)
          ^^^^^^^^^^^^^^^^^^^
  File "/p/software/machine/stages/2024/software/Python/3.11.3-GCCcore-12.3.0/lib/python3.11/ast.py", line 418, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 448, in visit_AugAssign
    self.visit(assign)
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 1105, in visit
    ret = super().visit(node)
          ^^^^^^^^^^^^^^^^^^^
  File "/p/software/machine/stages/2024/software/Python/3.11.3-GCCcore-12.3.0/lib/python3.11/ast.py", line 418, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 428, in visit_Assign
    values = self.visit(node.value)
             ^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 1105, in visit
    ret = super().visit(node)
          ^^^^^^^^^^^^^^^^^^^
  File "/p/software/machine/stages/2024/software/Python/3.11.3-GCCcore-12.3.0/lib/python3.11/ast.py", line 418, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 477, in visit_BinOp
    rhs = self.visit(node.right)
          ^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 1105, in visit
    ret = super().visit(node)
          ^^^^^^^^^^^^^^^^^^^
  File "/p/software/machine/stages/2024/software/Python/3.11.3-GCCcore-12.3.0/lib/python3.11/ast.py", line 418, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 1027, in visit_Call
    return fn(*args, **extra_kwargs, **kws)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/language/core.py", line 27, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/language/core.py", line 1018, in dot
    return semantic.dot(input, other, acc, allow_tf32, max_num_imprecise_acc, out_dtype, _builder)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/language/semantic.py", line 1207, in dot
    assert_dtypes_valid(lhs.dtype, rhs.dtype, builder.options)
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/language/semantic.py", line 1183, in assert_dtypes_valid
    assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!"
           ^^^^^^^^^^^^^^^^^^^^^^
AssertionError: First input (fp32) and second input (bf16) must have the same dtype!

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "system/fla-bench/flash-linear-attention/benchmarks/benchmark_training_throughput.py", line 100, in <module>
    profile(args.name, args.batch_size, args.seq_len, args.warmup_steps, args.steps)
  File "system/fla-bench/flash-linear-attention/benchmarks/benchmark_training_throughput.py", line 66, in profile
    outputs = model(tokens, labels=tokens)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/accelerate/utils/operations.py", line 819, in forward
    return model_forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/accelerate/utils/operations.py", line 807, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/models/hgrn2/modeling_hgrn2.py", line 372, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/models/hgrn2/modeling_hgrn2.py", line 248, in forward
    hidden_states, attentions, past_key_values = layer(
                                                 ^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/models/hgrn2/modeling_hgrn2.py", line 97, in forward
    hidden_states, attentions, past_key_values = self.attn(
                                                 ^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/layers/hgrn2.py", line 151, in forward
    o, recurrent_state = chunk_gla(q, k, i, g, initial_state=recurrent_state, output_final_state=use_cache)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/ops/gla/chunk.py", line 733, in chunk_gla
    o, final_state = ChunkGLAFunction.apply(q, k, v, g, scale, initial_state, output_final_state, checkpoint_level)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/autograd/function.py", line 598, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/utils.py", line 11, in wrapper
    return fn(ctx,
           ^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/ops/gla/chunk.py", line 541, in forward
    h = fwd_inner(
        ^^^^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/ops/gla/chunk.py", line 514, in fwd_inner
    chunk_gla_fwd_kernel_h[grid](
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 167, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 416, in run
    self.cache[device][key] = compile(
                              ^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/compiler/compiler.py", line 191, in compile
    module = src.make_ir(options)
             ^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/compiler/compiler.py", line 117, in make_ir
    return ast_to_ttir(self.fn, self, options=options)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 1231, in ast_to_ttir
    raise CompilationError(fn.src, node, repr(e)) from e
triton.compiler.errors.CompilationError: at 53:27:        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BK, BT]
        b_g = tl.load(p_g, boundary_check=(0, 1))
        if i_t < NT - 1:
            # [BK,]
            b_gn = tl.load(p_gn, boundary_check=(0,))
        else:
            b_gn = tl.min(b_g, axis=1)
        b_h *= tl.exp(b_gn)[:, None]
        b_k = (b_k * tl.exp(b_gn[:, None] - b_g)).to(b_k.dtype)
        b_h += tl.dot(b_k, b_v, allow_tf32=False)
                           ^
AssertionError('First input (fp32) and second input (bf16) must have the same dtype!')

Linear_attention

Traceback (most recent call last):
  File "system/fla-bench/flash-linear-attention/benchmarks/benchmark_training_throughput.py", line 100, in <module>
    profile(args.name, args.batch_size, args.seq_len, args.warmup_steps, args.steps)
  File "system/fla-bench/flash-linear-attention/benchmarks/benchmark_training_throughput.py", line 66, in profile
    outputs = model(tokens, labels=tokens)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/accelerate/utils/operations.py", line 819, in forward
    return model_forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/accelerate/utils/operations.py", line 807, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/models/linear_attn/modeling_linear_attn.py", line 389, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/models/linear_attn/modeling_linear_attn.py", line 268, in forward
    layer_outputs = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/models/linear_attn/modeling_linear_attn.py", line 105, in forward
    hidden_states = self.attn(hidden_states)
                    ^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/layers/linear_attn.py", line 129, in forward
    q = self.feature_map_q(q)
        ^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/flash-linear-attention/fla/modules/feature_map.py", line 130, in forward
    return self.layer1(x) * self.layer2(x)
           ^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "system/fla-bench/llm_env/venv/lib/python3.11/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: mat1 and mat2 shapes cannot be multiplied (8192x2048 and 512x512)

Rwkv6

OOM for BS=8

 File "system/fla-bench/flash-linear-attention/fla/layers/rwkv6.py", line 272, in forward
    return self.linear(x + delta * mu)
                       ~~^~~~~~~~~~~~
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 32.00 MiB. GPU 
rakkit commented 3 months ago

And an extra benchmark on an H100 (SXM5 94GB):

| Model | Batch-Size | Seq_len | Max-memory on A100 (GB) | Throughput on A100 (tokens/s) | Throughput on H100 (tokens/s) |
|---|---|---|---|---|---|
| GLA | 8 | 512 | 14.77 | 14959.27 | 31704.30 |
| GLA | 8 | 1024 | 22.88 | 17467.75 | 37256.27 |
| GLA | 8 | 2048 | OOM | | 41293.03 |
| GSA | 8 | 512 | 16.07 | 14960.99 | 31672.98 |
| GSA | 8 | 1024 | 24.35 | 17674.00 | 37659.39 |
| GSA | 8 | 2048 | OOM | | 41753.57 |
| HGRN | 8 | 512 | 16.9 | 16382.00 | 35731.41 |
| HGRN | 8 | 1024 | 26.15 | 19500.58 | 42272.09 |
| HGRN | 8 | 2048 | OOM | | 49234.70 |
| retnet | 8 | 512 | 15.13 | 13369.01 | 27383.20 |
| retnet | 8 | 1024 | 22.66 | 15437.14 | 31250.06 |
| retnet | 8 | 2048 | 37.75 | 16445.58 | 34090.15 |
| transformer | 8 | 512 | 13.98 | 17851.42 | 40468.24 |
| transformer | 8 | 1024 | 20.30 | 20994.52 | 47119.66 |
| transformer | 8 | 2048 | 32.96 | 21807.02 | 49627.74 |
| Mamba | 8 | 512 | 15.40 | 10385.55 | 19980.05 |
| Mamba | 8 | 1024 | 22.72 | 11230.94 | 21674.81 |
| Mamba | 8 | 2048 | 37.36 | 12151.64 | 23410.73 |
| Samba | 8 | 512 | 13.77 | 16475.11 | 33476.98 |
| Samba | 8 | 1024 | 19.56 | 18470.86 | 38945.52 |
| Samba | 8 | 2048 | 31.18 | 19850.30 | 42370.31 |


yzhangcs commented 3 months ago

@rakkit Hi, thank you for reporting the bugs. The new commits have fixed those you mentioned. However, HGRN2 and LinearAttn work fine for me. Please check your triton/causal-conv1d versions.
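For reference, a quick way to print the installed versions mentioned above (the package names are the usual PyPI distribution names; adjust if your environment differs):

```python
import importlib.metadata as metadata

# Print the installed versions of the packages relevant to this issue.
for pkg in ("triton", "causal-conv1d", "mamba-ssm"):
    try:
        print(pkg, metadata.version(pkg))
    except metadata.PackageNotFoundError:
        print(pkg, "not installed")
```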

sustcsonglin commented 3 months ago

> @rakkit Hi, thank you for reporting the bugs. The new commits have fixed those you mentioned. However, HGRN2 and LinearAttn work fine for me. Please check your triton/causal-conv1d versions.

I've fixed the HGRN2 bug in https://github.com/sustcsonglin/flash-linear-attention/commit/a7bb4b7f71bec43d72a7486436d4b837b44e4333: AMP silently converts the keys to float32 because of the sigmoid. It should be good now.
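For illustration only, a hypothetical sketch of the general fix pattern (not the code in the linked commit): cast q/k back to the value dtype before calling the chunked kernel so that tl.dot never sees mixed fp32/bf16 inputs.

```python
def align_dtypes(q, k, v):
    # Elementwise ops such as sigmoid may be computed in float32 under AMP,
    # silently upcasting q/k while v stays bf16; casting back avoids the
    # "First input (fp32) and second input (bf16) must have the same dtype!"
    # assertion raised by tl.dot in the Triton kernel.
    return q.to(v.dtype), k.to(v.dtype), v
```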

rakkit commented 3 months ago

Hi, @yzhangcs @sustcsonglin. Thanks for fixing the bug so quickly. I can confirm all models are working now.

There is a minor issue in the benchmark code: seq_len is not passed to the config, so config.max_position_embeddings stays at its default. Models such as the transformer will then fail in long-sequence benchmarks (seq_len > max_position_embeddings).
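A minimal sketch of one possible fix, assuming the benchmark builds its config from a model type registered with transformers' AutoConfig (the function and names below are illustrative, not the script's actual code):

```python
from transformers import AutoConfig

def build_config(model_type: str, seq_len: int):
    # Build the default config for a model type registered with AutoConfig.
    config = AutoConfig.for_model(model_type)
    # Keep position embeddings at least as long as the benchmarked sequence,
    # otherwise position-embedding models fail for large seq_len.
    if getattr(config, "max_position_embeddings", 0) < seq_len:
        config.max_position_embeddings = seq_len
    return config
```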

The following are the full benchmark results for NVIDIA A100-SXM4-40GB and NVIDIA H100 PCIe-80GB.

(I plot BS=1 only here.)

Model  BS  Seq_len  A100 (tokens/s)  H100 (tokens/s)
samba 1 32768   28904.69
retnet 1 32768   22108.37
delta-net 1 32768   20419.72
mamba 1 32768   18535.8
transformer 1 32768   12216.43
samba 2 16384   27565.2
retnet 2 16384   23154.44
delta-net 2 16384   21761.01
transformer 2 16384   18647.61
mamba 2 16384   16121.99
hgrn 1 16384   30473.15
samba 1 16384 20055.28 27931.16
gsa 1 16384   25442.21
hgrn2 1 16384   24075.14
retnet 1 16384   21964.05
linear-attn 1 16384   21414
delta-net 1 16384   20194.1
rwkv6 1 16384   18773.85
transformer 1 16384   18242.55
mamba 1 16384 13811.72 18017
samba 4 8192   27906.56
transformer 4 8192   25004.98
retnet 4 8192   24045.87
delta-net 4 8192   22723.17
mamba 4 8192   16203.86
hgrn 2 8192   31345.5
samba 2 8192 19210.89 27332.59
gsa 2 8192   26957.65
hgrn2 2 8192   24674.07
transformer 2 8192   24419.43
retnet 2 8192 16247.82 22721.79
linear-attn 2 8192   22010.4
delta-net 2 8192   21369.47
rwkv6 2 8192   18945.29
mamba 2 8192 11941.12 15749.19
hgrn 1 8192 18675.9 27612.78
samba 1 8192 18699.51 26528.8
gsa 1 8192 16001.26 23310.63
transformer 1 8192   22464.27
hgrn2 1 8192 15501.05 22117.39
retnet 1 8192 14617.58 20519.9
linear-attn 1 8192 14021.22 19380.21
delta-net 1 8192 13352.6 18804.9
rwkv6 1 8192   17398.72
mamba 1 8192 13288.27 17014.34
transformer 8 4096   30320.49
samba 8 4096   28164.83
retnet 8 4096   24162.65
delta-net 8 4096   22879.71
mamba 8 4096   16240.19
hgrn 4 4096   31571.14
transformer 4 4096   29370.23
gsa 4 4096   27980.19
samba 4 4096 19486.16 27797.24
hgrn2 4 4096   24948.31
retnet 4 4096 16406.06 23430.21
linear-attn 4 4096   22187.43
delta-net 4 4096   22106.26
rwkv6 4 4096   19107.01
mamba 4 4096 12146.53 15883.88
hgrn 2 4096 19252.6 28482.19
transformer 2 4096   26838.27
samba 2 4096 18017.55 25800.91
gsa 2 4096 17027 24598.84
hgrn2 2 4096 15847.4 22766.22
retnet 2 4096 15268.96 21436.01
linear-attn 2 4096 14599.34 20324.97
delta-net 2 4096 14053.8 19777.87
rwkv6 2 4096   17706.11
mamba 2 4096 11631.5 15203.35
hgrn 1 4096 15768.39 23586.86
samba 1 4096 16493.58 23347.85
transformer 1 4096   23195.83
hgrn2 1 4096 13383.31 19303.13
retnet 1 4096 12748.79 18549.01
gsa 1 4096 13751.03 18382.57
linear-attn 1 4096 12239.79 17267.57
delta-net 1 4096 11686 16246.66
rwkv6 1 4096 10452.14 15948.46
mamba 1 4096 11841.48 15932.17
transformer 16 2048   33974.39
samba 16 2048   28634.4
retnet 16 2048   24150
delta-net 16 2048   22971.42
mamba 16 2048   16222.58
transformer 8 2048 21955.1 32656.87
hgrn 8 2048   31816.45
gsa 8 2048   28157.97
samba 8 2048 19838.32 27929.99
hgrn2 8 2048   25083.31
retnet 8 2048 16564.31 23576.51
linear-attn 8 2048   22438.17
delta-net 8 2048   22272.51
rwkv6 8 2048   19198.61
mamba 8 2048 12176.65 15861.45
transformer 4 2048 20264.96 29556.21
hgrn 4 2048 19429.87 28679.4
samba 4 2048 18297.54 26065.31
gsa 4 2048 17566.37 25574.87
hgrn2 4 2048 15941.95 22907.09
linear-attn 4 2048 14663.66 20819.64
delta-net 4 2048 14558.07 20432.8
retnet 4 2048 15405.69 20195.56
rwkv6 4 2048   17712.64
mamba 4 2048 11674.9 15177.81
transformer 2 2048 16961.03 25506.14
hgrn 2 2048 16207.16 24051.77
samba 2 2048 16119.07 22902.87
gsa 2 2048 14500.91 21552.52
hgrn2 2 2048 13645 19613.37
retnet 2 2048 13257.27 19379.32
linear-attn 2 2048 12667.81 17915.54
delta-net 2 2048 12222.73 17606.42
rwkv6 2 2048 10511.13 16026.85
mamba 2 2048 10533.79 14058.06
transformer 1 2048 13270.76 20094.04
hgrn 1 2048 12584.37 18807.96
samba 1 2048 13280.36 18106.8
retnet 1 2048 10484.34 14925.46
linear-attn 1 2048 9974.69 14619.53
gsa 1 2048 10429.46 13602.88
mamba 1 2048 9746.15 12487
hgrn2 1 2048 10932.83 12090.52
delta-net 1 2048 9735.89 11350.95
rwkv6 1 2048 8658.36 7733.45
transformer 32 1024   36088.91
samba 32 1024   29441.18
retnet 32 1024   23971.72
delta-net 32 1024   23042.36
mamba 32 1024   16431.25
transformer 16 1024 22978.64 34667.33
hgrn 16 1024   31794.96
samba 16 1024 19938.37 28706.31
gsa 16 1024   28277.14
hgrn2 16 1024   25243.74
retnet 16 1024 16574.38 23523.83
linear-attn 16 1024   22535.4
delta-net 16 1024   22520.52
rwkv6 16 1024   19232
mamba 16 1024   16063.34
transformer 8 1024 21132.73 31095.91
hgrn 8 1024 19370.2 28536.46
samba 8 1024 18429.84 26916.36
gsa 8 1024 17816.01 25694.83
hgrn2 8 1024 16036.75 23084.95
retnet 8 1024 15528.16 22135.69
linear-attn 8 1024 14763.62 20849.1
delta-net 8 1024 14605.84 20556.72
rwkv6 8 1024   17765.78
mamba 8 1024 11251.76 15380.6
transformer 4 1024 17611.25 26917.37
hgrn 4 1024 16157.12 23891.95
samba 4 1024 16246.94 23271.67
hgrn2 4 1024 13711.21 20287.58
gsa 4 1024 14901.37 19226.52
retnet 4 1024 13333.89 19125.69
linear-attn 4 1024 12735.54 18367.85
delta-net 4 1024 12592.95 18226.03
rwkv6 4 1024 10564.47 16148.39
mamba 4 1024 10229.23 14218.26
samba 2 1024 13023.25 17978.55
hgrn2 2 1024 11099.64 16062.79
hgrn 2 1024 12677.8 15324.84
delta-net 2 1024 9957.74 14141.77
linear-attn 2 1024 10246.18 13032.14
retnet 2 1024 10804.88 13003.05
rwkv6 2 1024 8714.04 12569.48
mamba 2 1024 8589.29 11972.97
gsa 2 1024 10758.39 10844.75
samba 1 1024 9545.75 9991.64
retnet 1 1024 6998.4 9003.62
hgrn 1 1024 7720.77 7824.65
linear-attn 1 1024 7059.6 7195.38
hgrn2 1 1024 6435.73 6144.78
mamba 1 1024 6892.87 6058.78
delta-net 1 1024 5650.93 5780.75
rwkv6 1 1024 4814.52 4774.23
gsa 1 1024 5319.96 4432.29
transformer 32 512 23491.27 35695.48
hgrn 32 512   31913.77
samba 32 512 20188.05 28949.67
gsa 32 512   28433.56
hgrn2 32 512   25231.68
retnet 32 512 16544.75 23569.82
delta-net 32 512   22614.96
linear-attn 32 512   22573.86
rwkv6 32 512   19212.5
mamba 32 512   16269.5
transformer 16 512 21583.01 31914.11
hgrn 16 512 19453.76 28698.38
samba 16 512 18714.74 27012.99
gsa 16 512 17800.15 25769.91
hgrn2 16 512 16091.78 23190.33
retnet 16 512 15507.6 22123.85
linear-attn 16 512 14759.27 20859.9
delta-net 16 512 14658.35 20794.48
rwkv6 16 512   17757.09
mamba 16 512   15530.05
transformer 8 512 17959.34 26990.47
hgrn 8 512 16262.18 23961.92
samba 8 512 16445.8 23652.83
hgrn2 8 512 13781.91 20332.45
retnet 8 512 13427.81 19785.93
linear-attn 8 512 12819.35 18371.43
gsa 8 512 15056.07 17805.26
delta-net 8 512 12641.08 17178.97
rwkv6 8 512 10603.23 15986.87
mamba 8 512 10396.3 14451.54
samba 4 512 13159.76 18081.48
transformer 4 512 13941.52 17548.94
retnet 4 512 10853.05 15515.85
hgrn 4 512 12714.64 15507.94
linear-attn 4 512 10287.85 15261.04
delta-net 4 512 10165.85 14407.68
hgrn2 4 512 11159.17 12823.84
rwkv6 4 512 8760.7 12710.08
mamba 4 512 8676.28 12335.67
gsa 4 512 10788.66 10845.14
hgrn 2 512 7804.57 9850.05
transformer 2 512 8624.15 8912.99
linear-attn 2 512 6974.66 7168.91
retnet 2 512 6943.15 7048.41
mamba 2 512 6607.24 6737.07
hgrn2 2 512 6440.37 5409.62
delta-net 2 512 5715.89 5383.6
gsa 2 512 5339.63 5336.48
rwkv6 2 512 4889.76 4554
samba 1 512 4964.62 5028.2
hgrn2 1 512 3215.79 4092.26
transformer 1 512 4371.82 3979.84
hgrn 1 512 3903.4 3925.22
linear-attn 1 512 3556.24 3670.65
retnet 1 512 3531.6 3574.41
mamba 1 512 3484.71 3459.5
delta-net 1 512 2884.54 2886.16
gsa 1 512 2666.12 2700.98
rwkv6 1 512 2428.59 1955.17
yzhangcs commented 3 months ago

@rakkit Thanks for your detailed benchmarks.

> There is a minor issue in the benchmark code: seq_len is not passed to the config, so config.max_position_embeddings stays at its default. Models such as the transformer will then fail in long-sequence benchmarks (seq_len > max_position_embeddings).

We will fix it soon, :>