mlc-ai / mlc-llm

Universal LLM Deployment Engine with ML Compilation
https://llm.mlc.ai/
Apache License 2.0

[Bug] RWKV v6 models fail to compile with latest mlc_llm #2873

Open MollySophia opened 2 months ago

MollySophia commented 2 months ago

šŸ› Bug

RWKV v6 models fail to compile with latest mlc_llm.

Edit: It also seems that CI currently has a compile test only for RWKV v5. Should RWKV v6 be added to the CI tests as well?

To Reproduce

Steps to reproduce the behavior:

  1. Get https://huggingface.co/RWKV/rwkv-6-world-1b6
  2. $ mlc_llm convert_weight rwkv-6-world-1b6 --quantization q4f16_1 -o rwkv-6-world-1b6-MLC
  3. $ mlc_llm gen_config rwkv-6-world-1b6 --quantization q4f16_1 --conv-template rwkv_world -o rwkv-6-world-1b6-MLC
  4. $ mlc_llm compile rwkv-6-world-1b6-MLC/mlc-chat-config.json --device metal --host arm64-apple-darwin -o rwkv-6-world-1b6-MLC/libs/rwkv-6-world-1b6-MLC-q4f16-metal.so (or build with other hosts)
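The reproduction steps can also be scripted, which may be convenient when re-testing against newer `mlc_llm` builds. This is a hypothetical helper, not part of mlc_llm. It assumes the Hugging Face repo from step 1 has already been downloaded to `./rwkv-6-world-1b6`, and it only runs the three CLI commands from steps 2 to 4 with the same arguments. The `device` and `host` parameters exist only to make trying other targets easier.

```python
import shutil
import subprocess
import sys

# Assumption: local checkout of https://huggingface.co/RWKV/rwkv-6-world-1b6
MODEL = "rwkv-6-world-1b6"
OUT = f"{MODEL}-MLC"


def repro_commands(device: str = "metal", host: str = "arm64-apple-darwin") -> list[list[str]]:
    """Return the convert_weight / gen_config / compile invocations from the steps."""
    lib = f"{OUT}/libs/{MODEL}-MLC-q4f16-{device}.so"
    return [
        ["mlc_llm", "convert_weight", MODEL, "--quantization", "q4f16_1", "-o", OUT],
        ["mlc_llm", "gen_config", MODEL, "--quantization", "q4f16_1",
         "--conv-template", "rwkv_world", "-o", OUT],
        ["mlc_llm", "compile", f"{OUT}/mlc-chat-config.json",
         "--device", device, "--host", host, "-o", lib],
    ]


def run_repro(device: str = "metal", host: str = "arm64-apple-darwin") -> int:
    """Run each step in order; return the first non-zero exit code, or 0 if all succeed."""
    for cmd in repro_commands(device, host):
        print("$", " ".join(cmd), flush=True)
        result = subprocess.run(cmd)
        if result.returncode != 0:
            return result.returncode
    return 0


# Only execute when run as a script and the mlc_llm CLI is actually on PATH.
if __name__ == "__main__" and shutil.which("mlc_llm"):
    sys.exit(run_repro())
```

With the report's settings, the failing step is the third command, so `run_repro()` would return the non-zero exit code of `mlc_llm compile`.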
The compile step fails with the following error:

```
$ mlc_llm compile rwkv-6-world-1b6-MLC/mlc-chat-config.json --device metal --host arm64-apple-darwin -o rwkv-6-world-1b6-MLC/libs/rwkv-6-world-1b6-MLC-q4f16-metal.so
[2024-09-02 10:57:53] INFO auto_config.py:70: Found model configuration: rwkv-6-world-1b6-MLC/mlc-chat-config.json
[2024-09-02 10:57:54] INFO auto_device.py:79: Found device: metal:0
[2024-09-02 10:57:54] INFO auto_target.py:78: Found configuration of target device "metal:0": {"thread_warp_size": runtime.BoxInt(32), "max_threads_per_block": runtime.BoxInt(1024), "max_function_args": runtime.BoxInt(31), "max_num_threads": runtime.BoxInt(256), "kind": "metal", "max_shared_memory_per_block": runtime.BoxInt(32768), "tag": "", "keys": ["metal", "gpu"]}
[2024-09-02 10:57:54] INFO auto_target.py:114: Using LLVM triple specified by --host: arm64-apple-darwin
[2024-09-02 10:57:54] INFO auto_config.py:154: Found model type: rwkv6. Use `--model-type` to override.
Compiling with arguments:
  --config RWKV6Config(hidden_size=2048, intermediate_size=7168, num_hidden_layers=24, vocab_size=65536, model_version='6_0', tensor_parallel_shards=1, rescale_every=6, head_size=64, layer_norm_epsilon=1e-05, context_window_size=-1, prefill_chunk_size=4096, num_heads=32, max_batch_size=80, kwargs={})
  --quantization GroupQuantize(name='q4f16_1', kind='group-quant', group_size=32, quantize_dtype='int4', storage_dtype='uint32', model_dtype='float16', linear_weight_layout='NK', quantize_embedding=True, quantize_final_fc=True, num_elem_per_storage=8, num_storage_per_group=4, max_int_value=7, tensor_parallel_shards=0)
  --model-type rwkv6
  --target {"thread_warp_size": runtime.BoxInt(32), "host": {"kind": "llvm", "tag": "", "keys": ["arm_cpu", "cpu"], "mtriple": "arm64-apple-darwin"}, "max_threads_per_block": runtime.BoxInt(1024), "max_function_args": runtime.BoxInt(31), "max_num_threads": runtime.BoxInt(256), "kind": "metal", "max_shared_memory_per_block": runtime.BoxInt(32768), "tag": "", "keys": ["metal", "gpu"]}
  --opt flashinfer=0;cublas_gemm=0;faster_transformer=0;cudagraph=0;cutlass=0;ipc_allreduce_strategy=NONE
  --system-lib-prefix ""
  --output rwkv-6-world-1b6-MLC/libs/rwkv-6-world-1b6-MLC-q4f16-metal.so
  --overrides context_window_size=None;sliding_window_size=None;prefill_chunk_size=None;attention_sink_size=None;max_batch_size=None;tensor_parallel_shards=None;pipeline_parallel_stages=None
[2024-09-02 10:57:54] INFO compile.py:140: Creating model from: RWKV6Config(hidden_size=2048, intermediate_size=7168, num_hidden_layers=24, vocab_size=65536, model_version='6_0', tensor_parallel_shards=1, rescale_every=6, head_size=64, layer_norm_epsilon=1e-05, context_window_size=-1, prefill_chunk_size=4096, num_heads=32, max_batch_size=80, kwargs={})
[2024-09-02 10:57:54] INFO compile.py:158: Exporting the model to TVM Unity compiler
[2024-09-02 10:57:57] INFO compile.py:164: Running optimizations using TVM Unity
[2024-09-02 10:57:57] INFO compile.py:185: Registering metadata: {'model_type': 'rwkv6', 'quantization': 'q4f16_1', 'context_window_size': -1, 'sliding_window_size': -1, 'attention_sink_size': -1, 'prefill_chunk_size': 4096, 'tensor_parallel_shards': 1, 'pipeline_parallel_stages': 1, 'kv_state_kind': 'rnn_state', 'max_batch_size': 80}
[2024-09-02 10:57:57] INFO pipeline.py:54: Running TVM Relax graph-level optimizations
[2024-09-02 10:57:59] INFO pipeline.py:54: Lowering to TVM TIR kernels
[2024-09-02 10:58:04] INFO pipeline.py:54: Running TVM TIR-level optimizations
[2024-09-02 10:58:22] INFO pipeline.py:54: Running TVM Dlight low-level optimizations
[2024-09-02 10:58:27] INFO pipeline.py:54: Lowering to VM bytecode
[2024-09-02 10:58:30] INFO estimate_memory_usage.py:58: [Memory usage] Function `alloc_embedding_tensor`: 16.00 MB
[2024-09-02 10:58:30] INFO estimate_memory_usage.py:58: [Memory usage] Function `batch_decode`: 106.57 MB
[2024-09-02 10:58:30] INFO estimate_memory_usage.py:58: [Memory usage] Function `batch_prefill`: 293.50 MB
[2024-09-02 10:58:30] INFO estimate_memory_usage.py:58: [Memory usage] Function `batch_verify`: 273.75 MB
[2024-09-02 10:58:30] INFO estimate_memory_usage.py:58: [Memory usage] Function `create_rnn_state`: 0.00 MB
[2024-09-02 10:58:30] INFO estimate_memory_usage.py:58: [Memory usage] Function `decode`: 1.32 MB
[2024-09-02 10:58:30] INFO estimate_memory_usage.py:58: [Memory usage] Function `embed`: 16.00 MB
[2024-09-02 10:58:31] INFO estimate_memory_usage.py:58: [Memory usage] Function `prefill`: 273.75 MB
[2024-09-02 10:58:31] INFO estimate_memory_usage.py:58: [Memory usage] Function `softmax_with_temperature`: 0.00 MB
[2024-09-02 10:58:32] INFO pipeline.py:54: Compiling external modules
[2024-09-02 10:58:32] INFO pipeline.py:54: Compilation complete! Exporting to disk
Traceback (most recent call last):
  File "/Users/molly/miniconda3/envs/mlc-llm-latest/bin/mlc_llm", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/Users/molly/miniconda3/envs/mlc-llm-latest/lib/python3.11/site-packages/mlc_llm/__main__.py", line 33, in main
    cli.main(sys.argv[2:])
  File "/Users/molly/miniconda3/envs/mlc-llm-latest/lib/python3.11/site-packages/mlc_llm/cli/compile.py", line 129, in main
    compile(
  File "/Users/molly/miniconda3/envs/mlc-llm-latest/lib/python3.11/site-packages/mlc_llm/interface/compile.py", line 243, in compile
    _compile(args, model_config)
  File "/Users/molly/miniconda3/envs/mlc-llm-latest/lib/python3.11/site-packages/mlc_llm/interface/compile.py", line 188, in _compile
    args.build_func(
  File "/Users/molly/miniconda3/envs/mlc-llm-latest/lib/python3.11/site-packages/mlc_llm/support/auto_target.py", line 311, in build
    relax.build(
  File "/Users/molly/miniconda3/envs/mlc-llm-latest/lib/python3.11/site-packages/tvm/relax/vm_build.py", line 341, in build
    return _vmlink(
           ^^^^^^^^
  File "/Users/molly/miniconda3/envs/mlc-llm-latest/lib/python3.11/site-packages/tvm/relax/vm_build.py", line 247, in _vmlink
    lib = tvm.build(
          ^^^^^^^^^^
  File "/Users/molly/miniconda3/envs/mlc-llm-latest/lib/python3.11/site-packages/tvm/driver/build_module.py", line 297, in build
    rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "tvm/_ffi/_cython/./packed_func.pxi", line 339, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 270, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 259, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 185, in tvm._ffi._cy3.core.CHECK_CALL
  File "/Users/molly/miniconda3/envs/mlc-llm-latest/lib/python3.11/site-packages/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm._ffi.base.TVMError: Traceback (most recent call last):
  Did you forget to bind?
    Variable `compute` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `add768` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `compute` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `add768` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
  File "/Users/catalyst/Workspace/mlc-ai-package-self-runner/_work/package/package/tvm/src/tir/analysis/verify_memory.cc", line 205
RuntimeError: Memory verification failed with the following errors:
# from tvm.script import tir as T

@T.prim_func
def fused_matmul4_tir_tanh2(p_add768: T.handle, model_blocks_0_attention_time_maa_w13: T.Buffer((2048, 160), "float16"), p_output0: T.handle):
    T.func_attr({"target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "arm64-apple-darwin", "tag": ""}, "keys": ["metal", "gpu"], "kind": "metal", "max_function_args": 31, "max_num_threads": 256, "max_shared_memory_per_block": 32768, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.HoistIfThenElseExprWithBlock": 1, "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
    seq_len = T.int32()
    add768 = T.match_buffer(p_add768, (1, seq_len, 2048), "float16")
    compute_intermediate = T.match_buffer(p_output0, (1, seq_len, 160), "float16")
    add768_pad = T.allocate([(seq_len + 1) // 2 * 4096], "float16", "global")
    matmul_intermediate_pad_rf = T.allocate([(seq_len + 1) // 2 * 10240], "float16", "global")
    matmul_intermediate_pad_rf_1 = T.allocate([(seq_len + 1) // 2 * 1280], "float16", "global")
    matmul_intermediate_pad = T.allocate([(seq_len + 1) // 2 * 320], "float16", "global")
    add768_pad_1 = T.allocate([(seq_len + 3) // 4 * 8192], "float16", "global")
    matmul_intermediate_pad_rf_2 = T.allocate([(seq_len + 3) // 4 * 20480], "float16", "global")
    matmul_intermediate_pad_rf_3 = T.allocate([(seq_len + 3) // 4 * 2560], "float16", "global")
    matmul_intermediate_pad_1 = T.allocate([(seq_len + 3) // 4 * 640], "float16", "global")
    add768_1 = T.Buffer((seq_len * 2048,), "float16", data=add768.data)
    model_blocks_0_attention_time_maa_w13_1 = T.Buffer((327680,), "float16", data=model_blocks_0_attention_time_maa_w13.data)
    compute_intermediate_1 = T.Buffer((seq_len * 160,), "float16", data=compute_intermediate.data)
    if T.tvm_thread_invariant(seq_len <= 2):
        add768_pad_2 = T.Buffer(((seq_len + 1) // 2 * 4096,), "float16", data=add768_pad)
        for ax0 in range((seq_len + 1) // 2 * 2):
            if ax0 < seq_len:
                for ax1 in range(2048):
                    cse_var_1: T.int32 = ax0 * 2048 + ax1
                    add768_pad_2[cse_var_1] = add768_1[cse_var_1]
            else:
                for ax1 in range(2048):
                    add768_pad_2[ax0 * 2048 + ax1] = T.float16(0.0)
        matmul_intermediate_pad_rf_4 = T.Buffer(((seq_len + 1) // 2 * 10240,), "float16", data=matmul_intermediate_pad_rf)
        with T.launch_thread("blockIdx.y", (seq_len + 1) // 2) as blockIdx_y:
            blockIdx_x = T.launch_thread("blockIdx.x", 3)
            threadIdx_x = T.launch_thread("threadIdx.x", 64)
            threadIdx_y = T.launch_thread("threadIdx.y", 4)
            for ax2_fused_0, ax2_fused_2, ax0_1, ax2_fused_1_ax2_fused_3_fused_1_0 in T.grid(16, 4, 2, 2):
                if blockIdx_x * 2 + threadIdx_x // 32 < 5:
                    if ax2_fused_0 == 0 and ax2_fused_2 == 0:
                        matmul_intermediate_pad_rf_4[T.Ramp(threadIdx_y * ((seq_len + 1) // 2) * 2560 + ax2_fused_1_ax2_fused_3_fused_1_0 * ((seq_len + 1) // 2) * 1280 + blockIdx_y * 320 + ax0_1 * 160 + blockIdx_x * 64 + threadIdx_x, (seq_len + 1) // 2 * 320, 4)] = T.Broadcast(T.float16(0.0), 4)
                    matmul_intermediate_pad_rf_4[T.Ramp(threadIdx_y * ((seq_len + 1) // 2) * 2560 + ax2_fused_1_ax2_fused_3_fused_1_0 * ((seq_len + 1) // 2) * 1280 + blockIdx_y * 320 + ax0_1 * 160 + blockIdx_x * 64 + threadIdx_x, (seq_len + 1) // 2 * 320, 4)] = matmul_intermediate_pad_rf_4[T.Ramp(threadIdx_y * ((seq_len + 1) // 2) * 2560 + ax2_fused_1_ax2_fused_3_fused_1_0 * ((seq_len + 1) // 2) * 1280 + blockIdx_y * 320 + ax0_1 * 160 + blockIdx_x * 64 + threadIdx_x, (seq_len + 1) // 2 * 320, 4)] + add768_pad_2[blockIdx_y * 4096 + ax0_1 * 2048 + ax2_fused_0 * 128 + threadIdx_y * 32 + ax2_fused_2 * 8 + ax2_fused_1_ax2_fused_3_fused_1_0 * 4:blockIdx_y * 4096 + ax0_1 * 2048 + ax2_fused_0 * 128 + threadIdx_y * 32 + ax2_fused_2 * 8 + ax2_fused_1_ax2_fused_3_fused_1_0 * 4 + 4] * model_blocks_0_attention_time_maa_w13_1[ax2_fused_0 * 20480 + threadIdx_y * 5120 + ax2_fused_2 * 1280 + ax2_fused_1_ax2_fused_3_fused_1_0 * 640 + blockIdx_x * 64 + threadIdx_x:ax2_fused_0 * 20480 + threadIdx_y * 5120 + ax2_fused_2 * 1280 + ax2_fused_1_ax2_fused_3_fused_1_0 * 640 + blockIdx_x * 64 + threadIdx_x + 640:160]
        matmul_intermediate_pad_rf_5 = T.Buffer(((seq_len + 1) // 2 * 1280,), "float16", data=matmul_intermediate_pad_rf_1)
        for ax0_0, ax1_fused_0, ax1_fused_1, ax0_1, ax2_fused_1_ax2_fused_3_fused_0, ax2_fused_1_ax2_fused_3_fused_1 in T.grid((seq_len + 1) // 2, 3, 64, 2, 4, 8):
            cse_var_4: T.int32 = ax0_0 * 320
            cse_var_3: T.int32 = ax0_1 * 160
            cse_var_2: T.int32 = ax1_fused_0 * 64
            if ax2_fused_1_ax2_fused_3_fused_1 == 0:
                matmul_intermediate_pad_rf_5[cse_var_4 + ax2_fused_1_ax2_fused_3_fused_0 * ((seq_len + 1) // 2) * 320 + cse_var_3 + cse_var_2 + ax1_fused_1] = T.float16(0.0)
            matmul_intermediate_pad_rf_5[cse_var_4 + ax2_fused_1_ax2_fused_3_fused_0 * ((seq_len + 1) // 2) * 320 + cse_var_3 + cse_var_2 + ax1_fused_1] = matmul_intermediate_pad_rf_5[cse_var_4 + ax2_fused_1_ax2_fused_3_fused_0 * ((seq_len + 1) // 2) * 320 + cse_var_3 + cse_var_2 + ax1_fused_1] + matmul_intermediate_pad_rf_4[ax2_fused_1_ax2_fused_3_fused_0 * ((seq_len + 1) // 2) * 2560 + cse_var_4 + ax2_fused_1_ax2_fused_3_fused_1 * ((seq_len + 1) // 2) * 320 + cse_var_3 + cse_var_2 + ax1_fused_1]
        matmul_intermediate_pad_2 = T.Buffer(((seq_len + 1) // 2 * 320,), "float16", data=matmul_intermediate_pad)
        for ax0_0, ax1_fused_0, ax1_fused_1, ax0_1, ax2_fused_1_ax2_fused_3_fused_0 in T.grid((seq_len + 1) // 2, 3, 64, 2, 4):
            cse_var_8: T.int32 = ax0_0 * 320
            cse_var_7: T.int32 = ax0_1 * 160
            cse_var_6: T.int32 = ax1_fused_0 * 64
            cse_var_5: T.int32 = cse_var_8 + cse_var_7 + cse_var_6 + ax1_fused_1
            if ax2_fused_1_ax2_fused_3_fused_0 == 0:
                matmul_intermediate_pad_2[cse_var_5] = T.float16(0.0)
            matmul_intermediate_pad_2[cse_var_5] = matmul_intermediate_pad_2[cse_var_5] + matmul_intermediate_pad_rf_5[cse_var_8 + ax2_fused_1_ax2_fused_3_fused_0 * ((seq_len + 1) // 2) * 320 + cse_var_7 + cse_var_6 + ax1_fused_1]
        for ax0, ax1 in T.grid(seq_len, 160):
            cse_var_9: T.int32 = ax0 * 160 + ax1
            compute_intermediate_1[cse_var_9] = T.tanh(matmul_intermediate_pad_2[cse_var_9])
    else:
        if T.tvm_thread_invariant(seq_len <= 8):
            add768_pad_2 = T.Buffer(((seq_len + 3) // 4 * 8192,), "float16", data=add768_pad_1)
            for ax0 in range((seq_len + 3) // 4 * 4):
                if ax0 < seq_len:
                    for ax1 in range(2048):
                        cse_var_10: T.int32 = ax0 * 2048 + ax1
                        add768_pad_2[cse_var_10] = add768_1[cse_var_10]
                else:
                    for ax1 in range(2048):
                        add768_pad_2[ax0 * 2048 + ax1] = T.float16(0.0)
            matmul_intermediate_pad_rf_4 = T.Buffer(((seq_len + 3) // 4 * 20480,), "float16", data=matmul_intermediate_pad_rf_2)
            with T.launch_thread("blockIdx.y", (seq_len + 3) // 4) as blockIdx_y:
                blockIdx_x = T.launch_thread("blockIdx.x", 3)
                threadIdx_x = T.launch_thread("threadIdx.x", 64)
                threadIdx_y = T.launch_thread("threadIdx.y", 4)
                for ax2_fused_0, ax2_fused_2, ax0_1, ax2_fused_1_ax2_fused_3_fused_1_0 in T.grid(16, 4, 4, 2):
                    if blockIdx_x * 2 + threadIdx_x // 32 < 5:
                        if ax2_fused_0 == 0 and ax2_fused_2 == 0:
                            matmul_intermediate_pad_rf_4[T.Ramp(threadIdx_y * ((seq_len + 3) // 4) * 5120 + ax2_fused_1_ax2_fused_3_fused_1_0 * ((seq_len + 3) // 4) * 2560 + blockIdx_y * 640 + ax0_1 * 160 + blockIdx_x * 64 + threadIdx_x, (seq_len + 3) // 4 * 640, 4)] = T.Broadcast(T.float16(0.0), 4)
                        matmul_intermediate_pad_rf_4[T.Ramp(threadIdx_y * ((seq_len + 3) // 4) * 5120 + ax2_fused_1_ax2_fused_3_fused_1_0 * ((seq_len + 3) // 4) * 2560 + blockIdx_y * 640 + ax0_1 * 160 + blockIdx_x * 64 + threadIdx_x, (seq_len + 3) // 4 * 640, 4)] = matmul_intermediate_pad_rf_4[T.Ramp(threadIdx_y * ((seq_len + 3) // 4) * 5120 + ax2_fused_1_ax2_fused_3_fused_1_0 * ((seq_len + 3) // 4) * 2560 + blockIdx_y * 640 + ax0_1 * 160 + blockIdx_x * 64 + threadIdx_x, (seq_len + 3) // 4 * 640, 4)] + add768_pad_2[blockIdx_y * 8192 + ax0_1 * 2048 + ax2_fused_0 * 128 + threadIdx_y * 32 + ax2_fused_2 * 8 + ax2_fused_1_ax2_fused_3_fused_1_0 * 4:blockIdx_y * 8192 + ax0_1 * 2048 + ax2_fused_0 * 128 + threadIdx_y * 32 + ax2_fused_2 * 8 + ax2_fused_1_ax2_fused_3_fused_1_0 * 4 + 4] * model_blocks_0_attention_time_maa_w13_1[ax2_fused_0 * 20480 + threadIdx_y * 5120 + ax2_fused_2 * 1280 + ax2_fused_1_ax2_fused_3_fused_1_0 * 640 + blockIdx_x * 64 + threadIdx_x:ax2_fused_0 * 20480 + threadIdx_y * 5120 + ax2_fused_2 * 1280 + ax2_fused_1_ax2_fused_3_fused_1_0 * 640 + blockIdx_x * 64 + threadIdx_x + 640:160]
            matmul_intermediate_pad_rf_5 = T.Buffer(((seq_len + 3) // 4 * 2560,), "float16", data=matmul_intermediate_pad_rf_3)
            for ax0_0, ax1_fused_0, ax1_fused_1, ax0_1, ax2_fused_1_ax2_fused_3_fused_0, ax2_fused_1_ax2_fused_3_fused_1 in T.grid((seq_len + 3) // 4, 3, 64, 4, 4, 8):
                cse_var_13: T.int32 = ax0_0 * 640
                cse_var_12: T.int32 = ax0_1 * 160
                cse_var_11: T.int32 = ax1_fused_0 * 64
                if ax2_fused_1_ax2_fused_3_fused_1 == 0:
                    matmul_intermediate_pad_rf_5[cse_var_13 + ax2_fused_1_ax2_fused_3_fused_0 * ((seq_len + 3) // 4) * 640 + cse_var_12 + cse_var_11 + ax1_fused_1] = T.float16(0.0)
                matmul_intermediate_pad_rf_5[cse_var_13 + ax2_fused_1_ax2_fused_3_fused_0 * ((seq_len + 3) // 4) * 640 + cse_var_12 + cse_var_11 + ax1_fused_1] = matmul_intermediate_pad_rf_5[cse_var_13 + ax2_fused_1_ax2_fused_3_fused_0 * ((seq_len + 3) // 4) * 640 + cse_var_12 + cse_var_11 + ax1_fused_1] + matmul_intermediate_pad_rf_4[ax2_fused_1_ax2_fused_3_fused_0 * ((seq_len + 3) // 4) * 5120 + cse_var_13 + ax2_fused_1_ax2_fused_3_fused_1 * ((seq_len + 3) // 4) * 640 + cse_var_12 + cse_var_11 + ax1_fused_1]
            matmul_intermediate_pad_2 = T.Buffer(((seq_len + 3) // 4 * 640,), "float16", data=matmul_intermediate_pad_1)
            for ax0_0, ax1_fused_0, ax1_fused_1, ax0_1, ax2_fused_1_ax2_fused_3_fused_0 in T.grid((seq_len + 3) // 4, 3, 64, 4, 4):
                cse_var_17: T.int32 = ax0_0 * 640
                cse_var_16: T.int32 = ax0_1 * 160
                cse_var_15: T.int32 = ax1_fused_0 * 64
                cse_var_14: T.int32 = cse_var_17 + cse_var_16 + cse_var_15 + ax1_fused_1
                if ax2_fused_1_ax2_fused_3_fused_0 == 0:
                    matmul_intermediate_pad_2[cse_var_14] = T.float16(0.0)
                matmul_intermediate_pad_2[cse_var_14] = matmul_intermediate_pad_2[cse_var_14] + matmul_intermediate_pad_rf_5[cse_var_17 + ax2_fused_1_ax2_fused_3_fused_0 * ((seq_len + 3) // 4) * 640 + cse_var_16 + cse_var_15 + ax1_fused_1]
            for ax0, ax1 in T.grid(seq_len, 160):
                cse_var_18: T.int32 = ax0 * 160 + ax1
                compute_intermediate_1[cse_var_18] = T.tanh(matmul_intermediate_pad_2[cse_var_18])
        else:
            blockIdx_z = T.launch_thread("blockIdx.z", 1)
            matmul_intermediate_reindex_pad_metal_simdgroup = T.allocate([256], "float16", "metal.simdgroup")
            add768_reindex_pad_shared = T.allocate([512], "float16", "shared")
            model_blocks_0_attention_time_maa_w13_reindex_pad_shared = T.allocate([2048], "float16", "shared")
            add768_reindex_pad_shared_metal_simdgroup = T.allocate([128], "float16", "metal.simdgroup")
            model_blocks_0_attention_time_maa_w13_reindex_pad_shared_metal_simdgroup = T.allocate([128], "float16", "metal.simdgroup")
            blockIdx_x = T.launch_thread("blockIdx.x", (seq_len + 15) // 16)
            blockIdx_y = T.launch_thread("blockIdx.y", 3)
            threadIdx_x = T.launch_thread("threadIdx.x", 32)
            threadIdx_y = T.launch_thread("threadIdx.y", 1)
            threadIdx_z = T.launch_thread("threadIdx.z", 4)
            for ax1_2_init, ax2_2_init in T.grid(2, 2):
                T.make_filled_simdgroup_matrix(matmul_intermediate_reindex_pad_metal_simdgroup, ax1_2_init * 2 + ax2_2_init, T.float32(0.0), 8, 8)
            for ax3_0 in range(64):
                add768_reindex_pad_shared_1 = T.Buffer((512,), "float16", data=add768_reindex_pad_shared, scope="shared")
                add768_reindex_pad_shared_1[threadIdx_z * 128 + threadIdx_x * 4:threadIdx_z * 128 + threadIdx_x * 4 + 4] = T.if_then_else(blockIdx_x * 16 + threadIdx_z * 4 + threadIdx_x // 8 < seq_len, add768_1[blockIdx_x * 32768 + threadIdx_z * 8192 + threadIdx_x // 8 * 2048 + ax3_0 * 32 + threadIdx_x % 8 * 4:blockIdx_x * 32768 + threadIdx_z * 8192 + threadIdx_x // 8 * 2048 + ax3_0 * 32 + threadIdx_x % 8 * 4 + 4], T.Broadcast(T.float16(0.0), 4))
                for ax1_ax2_fused_0 in range(4):
                    model_blocks_0_attention_time_maa_w13_reindex_pad_shared_1 = T.Buffer((2048,), "float16", data=model_blocks_0_attention_time_maa_w13_reindex_pad_shared, scope="shared")
                    model_blocks_0_attention_time_maa_w13_reindex_pad_shared_1[ax1_ax2_fused_0 * 512 + threadIdx_z * 128 + threadIdx_x * 4:ax1_ax2_fused_0 * 512 + threadIdx_z * 128 + threadIdx_x * 4 + 4] = T.if_then_else(blockIdx_y * 2 + ax1_ax2_fused_0 // 2 < 5, model_blocks_0_attention_time_maa_w13_1[ax3_0 * 5120 + threadIdx_x % 8 * 640 + blockIdx_y * 64 + ax1_ax2_fused_0 * 16 + threadIdx_z * 4 + threadIdx_x // 8:ax3_0 * 5120 + threadIdx_x % 8 * 640 + blockIdx_y * 64 + ax1_ax2_fused_0 * 16 + threadIdx_z * 4 + threadIdx_x // 8 + 640:160], T.Broadcast(T.float16(0.0), 4))
                for ax3_1 in range(4):
                    for ax0_0 in range(2):
                        T.simdgroup_load(add768_reindex_pad_shared_metal_simdgroup, ax0_0, T.tvm_access_ptr(T.type_annotation("float16"), add768_reindex_pad_shared, ax0_0 * 256 + ax3_1 * 8, 256, 1), 32, 8, 8, T.bool(False))
                    for ax0_0 in range(2):
                        T.simdgroup_load(model_blocks_0_attention_time_maa_w13_reindex_pad_shared_metal_simdgroup, ax0_0, T.tvm_access_ptr(T.type_annotation("float16"), model_blocks_0_attention_time_maa_w13_reindex_pad_shared, threadIdx_z * 512 + ax0_0 * 256 + ax3_1 * 8, 256, 1), 32, 8, 8, T.bool(True))
                    for ax1_2, ax2_2 in T.grid(2, 2):
                        cse_var_19: T.int32 = ax1_2 * 2 + ax2_2
                        T.simdgroup_multiply_accumulate(matmul_intermediate_reindex_pad_metal_simdgroup, cse_var_19, add768_reindex_pad_shared_metal_simdgroup, ax1_2, model_blocks_0_attention_time_maa_w13_reindex_pad_shared_metal_simdgroup, ax2_2, matmul_intermediate_reindex_pad_metal_simdgroup, cse_var_19)
            for ax1_0, ax2_0 in T.grid(2, 2):
                T.simdgroup_store(matmul_intermediate_reindex_pad_metal_simdgroup, ax1_0 * 2 + ax2_0, T.tvm_access_ptr(T.type_annotation("float16"), model_blocks_0_attention_time_maa_w13_reindex_pad_shared, ax1_0 * 512 + threadIdx_z * 16 + ax2_0 * 8, 512, 2), 64, 8, 8, T.bool(False))
            for ax1_ax2_fused_0 in range(2):
                if blockIdx_x * 16 + ax1_ax2_fused_0 * 8 + threadIdx_z * 2 + threadIdx_x // 16 < seq_len and blockIdx_y * 2 + threadIdx_x % 16 // 8 < 5:
                    model_blocks_0_attention_time_maa_w13_reindex_pad_shared_1 = T.Buffer((1024,), "float16", data=model_blocks_0_attention_time_maa_w13_reindex_pad_shared, scope="shared")
                    compute_intermediate_1[blockIdx_x * 2560 + ax1_ax2_fused_0 * 1280 + threadIdx_z * 320 + threadIdx_x // 16 * 160 + blockIdx_y * 64 + threadIdx_x % 16 * 4:blockIdx_x * 2560 + ax1_ax2_fused_0 * 1280 + threadIdx_z * 320 + threadIdx_x // 16 * 160 + blockIdx_y * 64 + threadIdx_x % 16 * 4 + 4] = T.tanh(model_blocks_0_attention_time_maa_w13_reindex_pad_shared_1[ax1_ax2_fused_0 * 512 + threadIdx_z * 128 + threadIdx_x * 4:ax1_ax2_fused_0 * 512 + threadIdx_z * 128 + threadIdx_x * 4 + 4])
```
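For anyone triaging this: the failing `fused_matmul4_tir_tanh2` computes `tanh(add768 @ W)`, where `add768` has shape `(1, seq_len, 2048)` and `W` has shape `(2048, 160)`. Reading the dump, the scheduled kernel appears to pick one of three code paths by `seq_len` (at most 2, at most 8, otherwise) and zero-pads the sequence to a multiple of 2, 4, or 16 rows. In the two small-`seq_len` paths, the padding copy from `add768` and the final `tanh` write into `compute_intermediate` seem to sit outside any `T.launch_thread`, which would be consistent with the "directly accessed by host memory" errors for `add768` and `compute`. That the weight is RWKV6's `time_maa_w1` (data-dependent token-shift projection, 5 × 32 = 160 columns) is an assumption based on the RWKV6 reference model. The NumPy sketch below is a hypothetical reference for the math and the padding arithmetic only. `padded_rows` and `fused_matmul_tanh_reference` are made-up names, and the sketch accumulates in float32 for simplicity; it says nothing about how the Dlight schedule should be fixed.

```python
import numpy as np

HIDDEN = 2048  # input features, from add768's shape (1, seq_len, 2048)
OUT = 160      # output features, from the weight shape (2048, 160)


def padded_rows(seq_len: int) -> int:
    """Rows the dumped kernel appears to pad seq_len to, per branch of the TIR."""
    if seq_len <= 2:
        return (seq_len + 1) // 2 * 2    # add768_pad: (seq_len + 1) // 2 * 4096 elements
    if seq_len <= 8:
        return (seq_len + 3) // 4 * 4    # add768_pad_1: (seq_len + 3) // 4 * 8192 elements
    return (seq_len + 15) // 16 * 16     # blockIdx.x extent (seq_len + 15) // 16, 16 rows each


def fused_matmul_tanh_reference(x: np.ndarray, w: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy reference for fused_matmul4_tir_tanh2: tanh(x @ w)."""
    _, seq_len, hidden = x.shape
    assert hidden == w.shape[0], "inner dimensions must match"
    rows = padded_rows(seq_len)
    # Zero-pad the sequence axis, mirroring the *_pad buffers in the dump.
    x_pad = np.zeros((rows, hidden), dtype=np.float32)
    x_pad[:seq_len] = x[0].astype(np.float32)
    # Matmul + tanh over the padded rows, then drop the padding before output.
    out = np.tanh(x_pad @ w.astype(np.float32))[:seq_len]
    return out.astype(np.float16)[None, :, :]


rng = np.random.default_rng(0)
x = rng.standard_normal((1, 5, HIDDEN)).astype(np.float16)
w = (rng.standard_normal((HIDDEN, OUT)) * 0.02).astype(np.float16)
y = fused_matmul_tanh_reference(x, w)
print(y.shape)                                             # (1, 5, 160)
print([padded_rows(n) for n in (1, 2, 3, 8, 9, 17)])       # [2, 2, 4, 8, 16, 32]
```

Comparing such a reference against the compiled kernel for `seq_len` values in each branch (for example 1, 5, and 17) could help confirm a fix covers all three code paths.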

Expected behavior

Model lib successfully compiles.

Environment

MasterJH5574 commented 2 months ago

Thank you for reporting. We'll look into this.