mlc-ai / mlc-llm

Universal LLM Deployment Engine with ML Compilation
https://llm.mlc.ai/
Apache License 2.0

[Bug] LLama-13B with ROCm backend generates different results each time (using MI210) #903

Closed: tiandi111 closed this issue 11 months ago

tiandi111 commented 1 year ago

## 🐛 Bug

I'm compiling the Llama-13b-hf model with the MLC-LLM ROCm backend and found that the compiled model outputs different values each time. We found that the error is introduced mainly by matmul ops.

## To Reproduce

Steps to reproduce the behavior:

  1. Compile the model with the following script:
import mlc_llm

build_args = mlc_llm.BuildArgs(
    model="llama-13b-hf",
    artifact_path="/home/tiandi05",
    quantization="q0f16",
    target="rocm",
    debug_dump=True,
    build_model_only=False,
    use_cache=0)

lib_path, model_path, chat_config_path = mlc_llm.build_model(build_args)
  2. Change the `artifact_path` and `model` variables in `run_model()` below to your local paths, then run the following script to get a layer-by-layer output check report; the script writes a report file named `log_dump1_dump2`.
# Used as reference

import json
import os
import sys
from typing import List

import numpy as np
import torch
import tvm
from transformers import AutoTokenizer, LlamaTokenizer  # type: ignore[import]
from tvm import relax
from tvm.relax.testing.lib_comparator import LibCompareVMInstrument

sys.path.append('../')
from mlc_llm import utils

def check(path1, path2, log_dir="."):
    # Compare every dumped .npy file in path1 against its counterpart in path2.
    log_path = os.path.join(log_dir, f"log_{os.path.basename(path1)}_{os.path.basename(path2)}")
    with open(log_path, "w") as log:
        for fname in os.listdir(path1):
            fpath1 = os.path.join(path1, fname)
            fpath2 = os.path.join(path2, fname)
            if os.path.isfile(fpath1) and os.path.isfile(fpath2):
                arr1 = np.load(fpath1)
                arr2 = np.load(fpath2)

                equal = np.allclose(arr1, arr2, rtol=1e-5, atol=1e-5)

                if equal:
                    log.write(f"Array {fname} passed!\n")
                else:
                    log.write(f"Array {fname} did not pass!\n")
            else:
                log.write(f"File {fname} missing from one of the dump directories!\n")

class LibCompare(LibCompareVMInstrument):
    def __init__(self, mod, device, dump_path="."):
        super().__init__(mod, device, verbose=False)
        self.dump_path = dump_path
        self._enable_dump = False
        self._dump_index = 0
        if not os.path.exists(self.dump_path):
            os.makedirs(self.dump_path)

    def enable_dump(self, enable):
        self._enable_dump = enable

    def compare(
        self,
        name: str,
        ref_args: List[tvm.nd.NDArray],
        new_args: List[tvm.nd.NDArray],
        ret_indices: List[int],
    ):
        # Skip shape functions; we only care about compute kernels.
        if name.startswith("shape_func"):
            return
        if self._enable_dump:
            # Dump every argument of every instrumented kernel call so the
            # two runs can be compared file by file afterwards.
            for i, arg in enumerate(ref_args):
                path = os.path.join(self.dump_path, f"{str(self._dump_index).zfill(5)}_{name}_arg{i}.npy")
                with open(path, 'wb') as f:
                    np.save(f, arg.numpy())
                self._dump_index += 1

def run_model() -> None:
    artifact_path = "/home/mlc-llm/dist/llama-13b-2l-hf-q0f16"
    model = "/home/mlc-llm/dist"

    # Load models
    device = tvm.device("rocm")
    const_params = utils.load_params(artifact_path, device)
    ex = tvm.runtime.load_module(
        f"{artifact_path}/llama-13b-2l-hf-q0f16-rocm.so")
    vm = relax.VirtualMachine(ex, device)

    # Initialize configs
    with open(
        os.path.join(artifact_path, "params", "mlc-chat-config.json"),
        "r",
        encoding="utf-8",
    ) as f:
        config = json.load(f)

    if config["model_category"] == "llama":
        tokenizer = LlamaTokenizer.from_pretrained(
            os.path.join(artifact_path, "params"), trust_remote_code=True
        )
    else:
        tokenizer = AutoTokenizer.from_pretrained(
            os.path.join(artifact_path, "params"), trust_remote_code=True
        )

    # Execute inference
    print("Tokenizing...")
    batch = 1
    input_token_len = 2048
    inputs = tvm.nd.array(
        torch.full((batch, input_token_len), 1).to(torch.int32).numpy(),
        device,
    )
    first_sampled_token = tvm.nd.array(np.array([[6234]]).astype("int32"), device)
    seq_len_shape = tvm.runtime.ShapeTuple([inputs.shape[1]])
    second_seq_len_shape = tvm.runtime.ShapeTuple([inputs.shape[1] + 1])

    # Run twice
    def run_once(dump_path):
        cmp_instrument = LibCompare(ex, device, dump_path)
        vm.set_instrument(cmp_instrument)
        cmp_instrument.enable_dump(True)
        kv_caches = vm["create_kv_cache"]()
        logits, kv_caches = vm["prefill"](inputs, seq_len_shape, kv_caches, const_params)
        device.sync()
        logits, kv_caches = vm["decode"](
            first_sampled_token, second_seq_len_shape, kv_caches, const_params
        )
        device.sync()

    run_once("./dump1")
    run_once("./dump2")

    check("./dump1", "./dump2")

if __name__ == "__main__":
    run_model()
  3. An example report I got:
    
    Array 00000_take1_arg0.npy passed!
    Array 00001_take1_arg1.npy passed!
    Array 00002_take1_arg2.npy passed!
    Array 00003_extend_te_arg0.npy passed!
    Array 00004_extend_te_arg1.npy passed!
    Array 00005_rms_norm_arg0.npy passed!
    Array 00006_rms_norm_arg1.npy passed!
    Array 00007_rms_norm_arg2.npy passed!
    Array 00008_NT_matmul6_arg0.npy passed!
    Array 00009_NT_matmul6_arg1.npy passed!
    Array 00010_NT_matmul6_arg2.npy passed!
    Array 00011_split1_arg0.npy passed!
    Array 00012_split1_arg1.npy passed!
    Array 00013_split1_arg2.npy passed!
    Array 00014_split1_arg3.npy passed!
    Array 00015_transpose7_arg0.npy passed!
    Array 00016_transpose7_arg1.npy passed!
    Array 00017_transpose7_arg0.npy passed!
    Array 00018_transpose7_arg1.npy passed!
    Array 00019_transpose7_arg0.npy passed!
    Array 00020_transpose7_arg1.npy passed!
    Array 00021_fused_NT_matmul7_divide2_maximum1_minimum1_cast3_arg0.npy passed!
    Array 00022_fused_NT_matmul7_divide2_maximum1_minimum1_cast3_arg1.npy passed!
    Array 00023_fused_NT_matmul7_divide2_maximum1_minimum1_cast3_arg2.npy passed!
    Array 00024_fused_NT_matmul7_divide2_maximum1_minimum1_cast3_arg3.npy passed!
    Array 00025_fused_softmax2_cast4_arg0.npy passed!
    Array 00026_fused_softmax2_cast4_arg1.npy passed!
    Array 00027_matmul10_arg0.npy passed!
    Array 00028_matmul10_arg1.npy passed!
    Array 00029_matmul10_arg2.npy passed!
    Array 00030_transpose9_arg0.npy passed!
    Array 00031_transpose9_arg1.npy passed!
    Array 00032_fused_NT_matmul8_add1_arg0.npy passed!
    Array 00033_fused_NT_matmul8_add1_arg1.npy passed!
    Array 00034_fused_NT_matmul8_add1_arg2.npy passed!
    Array 00035_fused_NT_matmul8_add1_arg3.npy passed!
    Array 00036_rms_norm_arg0.npy passed!
    Array 00037_rms_norm_arg1.npy passed!
    Array 00038_rms_norm_arg2.npy passed!
    Array 00039_NT_matmul9_arg0.npy passed!
    Array 00040_NT_matmul9_arg1.npy passed!
    Array 00041_NT_matmul9_arg2.npy passed!
    Array 00042_fused_split2_silu1_multiply1_arg0.npy passed!
    Array 00043_fused_split2_silu1_multiply1_arg1.npy passed!
    Array 00044_fused_NT_matmul10_add1_arg0.npy passed!
    Array 00045_fused_NT_matmul10_add1_arg1.npy passed!
    Array 00046_fused_NT_matmul10_add1_arg2.npy passed!
    Array 00047_fused_NT_matmul10_add1_arg3.npy passed!
    Array 00048_rms_norm_arg0.npy passed!
    Array 00049_rms_norm_arg1.npy passed!
    Array 00050_rms_norm_arg2.npy passed!
    Array 00051_slice_arg0.npy passed!
    Array 00052_slice_arg1.npy passed!
    Array 00053_fused_NT_matmul5_cast2_arg0.npy passed!
    Array 00054_fused_NT_matmul5_cast2_arg1.npy passed!
    Array 00055_fused_NT_matmul5_cast2_arg2.npy did not pass!
    Array 00056_take_arg0.npy passed!
    Array 00057_take_arg1.npy passed!
    Array 00058_take_arg2.npy passed!
    Array 00059_rms_norm1_arg0.npy passed!
    Array 00060_rms_norm1_arg1.npy passed!
    Array 00061_rms_norm1_arg2.npy passed!
    Array 00062_NT_matmul_arg0.npy passed!
    Array 00063_NT_matmul_arg1.npy passed!
    Array 00064_NT_matmul_arg2.npy did not pass!
    Array 00065_transpose7_arg0.npy did not pass!
    Array 00066_transpose7_arg1.npy did not pass!
    Array 00067_fused_full_NT_matmul1_divide1_maximum_minimum_cast_arg0.npy did not pass!
    Array 00068_fused_full_NT_matmul1_divide1_maximum_minimum_cast_arg1.npy did not pass!
    Array 00069_fused_full_NT_matmul1_divide1_maximum_minimum_cast_arg2.npy did not pass!
    Array 00070_transpose7_arg0.npy did not pass!
    Array 00071_transpose7_arg1.npy did not pass!
    Array 00072_fused_softmax1_cast1_arg0.npy did not pass!
    Array 00073_fused_softmax1_cast1_arg1.npy did not pass!
    Array 00074_matmul9_arg0.npy did not pass!
    Array 00075_matmul9_arg1.npy did not pass!
    Array 00076_matmul9_arg2.npy did not pass!
    Array 00077_fused_NT_matmul2_add_arg0.npy did not pass!
    Array 00078_fused_NT_matmul2_add_arg1.npy passed!
    Array 00079_fused_NT_matmul2_add_arg2.npy passed!
    Array 00080_fused_NT_matmul2_add_arg3.npy did not pass!
    Array 00081_rms_norm1_arg0.npy did not pass!
    Array 00082_rms_norm1_arg1.npy passed!
    Array 00083_rms_norm1_arg2.npy did not pass!
    Array 00084_NT_matmul3_arg0.npy did not pass!
    Array 00085_NT_matmul3_arg1.npy passed!
    Array 00086_NT_matmul3_arg2.npy did not pass!
    Array 00087_fused_split_silu_multiply_arg0.npy did not pass!
    Array 00088_fused_split_silu_multiply_arg1.npy did not pass!
    Array 00089_fused_NT_matmul4_add_arg0.npy did not pass!
    Array 00090_fused_NT_matmul4_add_arg1.npy passed!
    Array 00091_fused_NT_matmul4_add_arg2.npy did not pass!
    Array 00092_fused_NT_matmul4_add_arg3.npy did not pass!
    Array 00093_rms_norm1_arg0.npy did not pass!
    Array 00094_rms_norm1_arg1.npy passed!
    Array 00095_rms_norm1_arg2.npy did not pass!
    Array 00096_fused_NT_matmul5_cast2_arg0.npy did not pass!
    Array 00097_fused_NT_matmul5_cast2_arg1.npy passed!
    Array 00098_fused_NT_matmul5_cast2_arg2.npy did not pass!

## Expected behavior

The compiled model should output exactly the same values each time.
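
For reference, below is a minimal determinism smoke test distilled from the repro script above (paths, entry points, and input shapes are copied from that script; this is a sketch, not a separate reproducer):

```python
# Minimal determinism smoke test: run prefill twice on identical inputs and
# compare the logits bitwise. Paths/entry points follow the repro script above.
import numpy as np
import tvm
from tvm import relax

from mlc_llm import utils

artifact_path = "/home/mlc-llm/dist/llama-13b-2l-hf-q0f16"
device = tvm.device("rocm")
const_params = utils.load_params(artifact_path, device)
ex = tvm.runtime.load_module(f"{artifact_path}/llama-13b-2l-hf-q0f16-rocm.so")
vm = relax.VirtualMachine(ex, device)

# Same fixed prompt as the repro script: batch 1, 2048 copies of token id 1.
inputs = tvm.nd.array(np.full((1, 2048), 1, dtype="int32"), device)
seq_len_shape = tvm.runtime.ShapeTuple([inputs.shape[1]])

outs = []
for _ in range(2):
    kv_caches = vm["create_kv_cache"]()
    logits, kv_caches = vm["prefill"](inputs, seq_len_shape, kv_caches, const_params)
    device.sync()
    outs.append(logits.numpy())

# A deterministic build should make this print True.
print("bitwise identical:", np.array_equal(outs[0], outs[1]))
```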

## Environment

 - Platform: ROCm
 - Operating system: Ubuntu 22.04.3 LTS
 - Device: MI210
 - How you installed MLC-LLM: source
 - How you installed TVM-Unity: source
 - Python version: 3.10.12
 - GPU driver version: 5.18.13
 - CUDA/cuDNN version: N/A
 - TVM Unity Commit: be21b378b284eeab5b6a7721bf417ad7445fddf0 (we cloned from https://github.com/mlc-ai/relax)
 - TVM Unity Hash Tag:

    USE_GTEST: AUTO
    SUMMARIZE: OFF
    USE_IOS_RPC: OFF
    USE_MSC: OFF
    USE_ETHOSU: OFF
    CUDA_VERSION: NOT-FOUND
    USE_LIBBACKTRACE: AUTO
    DLPACK_PATH: 3rdparty/dlpack/include
    USE_TENSORRT_CODEGEN: OFF
    USE_THRUST: OFF
    USE_TARGET_ONNX: OFF
    USE_AOT_EXECUTOR: ON
    BUILD_DUMMY_LIBTVM: OFF
    USE_CUDNN: OFF
    USE_TENSORRT_RUNTIME: OFF
    USE_ARM_COMPUTE_LIB_GRAPH_EXECUTOR: OFF
    USE_CCACHE: AUTO
    USE_ARM_COMPUTE_LIB: OFF
    USE_CPP_RTVM: OFF
    USE_OPENCL_GTEST: /path/to/opencl/gtest
    USE_MKL: OFF
    USE_PT_TVMDSOOP: OFF
    MLIR_VERSION: NOT-FOUND
    USE_CLML: OFF
    USE_STACKVM_RUNTIME: OFF
    USE_GRAPH_EXECUTOR_CUDA_GRAPH: OFF
    ROCM_PATH: /opt/rocm
    USE_DNNL: OFF
    USE_VITIS_AI: OFF
    USE_MLIR: OFF
    USE_RCCL: OFF
    USE_LLVM: llvm-config --ignore-libllvm --link-static
    USE_VERILATOR: OFF
    USE_TF_TVMDSOOP: OFF
    USE_THREADS: ON
    USE_MSVC_MT: OFF
    BACKTRACE_ON_SEGFAULT: OFF
    USE_GRAPH_EXECUTOR: ON
    USE_NCCL: OFF
    USE_ROCBLAS: ON
    GIT_COMMIT_HASH: 392222e0216032509a635996b047f5c232b54402
    USE_VULKAN: OFF
    USE_RUST_EXT: OFF
    USE_CUTLASS: OFF
    USE_CPP_RPC: OFF
    USE_HEXAGON: OFF
    USE_CUSTOM_LOGGING: OFF
    USE_UMA: OFF
    USE_FALLBACK_STL_MAP: OFF
    USE_SORT: ON
    USE_RTTI: ON
    GIT_COMMIT_TIME: 2023-09-11 14:30:10 +0800
    USE_HEXAGON_SDK: /path/to/sdk
    USE_BLAS: none
    USE_ETHOSN: OFF
    USE_LIBTORCH: OFF
    USE_RANDOM: ON
    USE_CUDA: OFF
    USE_COREML: OFF
    USE_AMX: OFF
    BUILD_STATIC_RUNTIME: OFF
    USE_CMSISNN: OFF
    USE_KHRONOS_SPIRV: OFF
    USE_CLML_GRAPH_EXECUTOR: OFF
    USE_TFLITE: OFF
    USE_HEXAGON_GTEST: /path/to/hexagon/gtest
    PICOJSON_PATH: 3rdparty/picojson
    USE_OPENCL_ENABLE_HOST_PTR: OFF
    INSTALL_DEV: OFF
    USE_PROFILER: ON
    USE_NNPACK: OFF
    LLVM_VERSION: 16.0.6
    USE_OPENCL: OFF
    COMPILER_RT_PATH: 3rdparty/compiler-rt
    RANG_PATH: 3rdparty/rang/include
    USE_SPIRV_KHR_INTEGER_DOT_PRODUCT: OFF
    USE_OPENMP: none
    USE_BNNS: OFF
    USE_CUBLAS: OFF
    USE_METAL: OFF
    USE_MICRO_STANDALONE_RUNTIME: OFF
    USE_HEXAGON_EXTERNAL_LIBS: OFF
    USE_ALTERNATIVE_LINKER: AUTO
    USE_BYODT_POSIT: OFF
    USE_HEXAGON_RPC: OFF
    USE_MICRO: OFF
    DMLC_PATH: 3rdparty/dmlc-core/include
    INDEX_DEFAULT_I64: ON
    USE_RELAY_DEBUG: OFF
    USE_RPC: ON
    USE_TENSORFLOW_PATH: none
    TVM_CLML_VERSION:
    USE_MIOPEN: OFF
    USE_ROCM: ON
    USE_PAPI: OFF
    USE_CURAND: OFF
    TVM_CXX_COMPILER_PATH: /opt/rocm/bin/hipcc
    HIDE_PRIVATE_SYMBOLS: ON

Hzfengsy commented 1 year ago

It may be due to floating-point precision. For fp16, setting atol = rtol = 1e-3 or 1e-2 is usually enough.
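
For example, one quick way to see how many elements survive each tolerance is to sweep rtol/atol over a dumped pair (a sketch; the file name is taken from the report above, and the dump directories come from the repro script):

```python
# Tolerance sweep over one dumped array pair; adjust paths as needed.
import numpy as np

a = np.load("dump1/00055_fused_NT_matmul5_cast2_arg2.npy")
b = np.load("dump2/00055_fused_NT_matmul5_cast2_arg2.npy")

for tol in (1e-5, 1e-3, 1e-2):
    bad = np.count_nonzero(~np.isclose(a, b, rtol=tol, atol=tol))
    print(f"rtol=atol={tol}: {bad}/{a.size} elements mismatched")
```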

tiandi111 commented 1 year ago

It seems that the precision error is huge for some elements. I compared the two 00055_fused_NT_matmul5_cast2_arg2.npy files and got the following output:

Traceback (most recent call last):
  File "/home/mlc-llm/benchmark/acc_check.py", line 48, in <module>
    check_single(
  File "/home/mlc-llm/benchmark/acc_check.py", line 43, in check_single
    np.testing.assert_allclose(arr1, arr2, rtol=1e-2, atol=1e-2, verbose=True)
  File "/usr/local/lib/python3.10/dist-packages/numpy/testing/_private/utils.py", line 1504, in assert_allclose
    assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/usr/local/lib/python3.10/dist-packages/numpy/testing/_private/utils.py", line 797, in assert_array_compare
    raise AssertionError(msg)
AssertionError: 
Not equal to tolerance rtol=0.01, atol=0.01

Mismatched elements: 129 / 32000 (0.403%)
Max absolute difference: 3.0947266
Max relative difference: 1527.6666
 x: array([[[ 0.031006,  0.15625 ,  0.862305, ..., -1.228516, -0.04248 ,
         -1.756836]]], dtype=float32)
 y: array([[[ 0.031006,  0.15625 ,  0.862305, ..., -1.228516, -0.04248 ,
         -1.756836]]], dtype=float32)

Also, the set of mismatched arrays is always the same; hence, I suspect there is a bug in the generated HIP code.
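
One way to test the "same arrays always mismatch" observation is to produce more than two dumps and check whether the set of failing files stays stable; a minimal sketch, assuming run_once() from the repro script has been called for dump1 through dump3:

```python
# Check which dumped files fail across several runs; a stable failing set
# points at specific kernels rather than at random nondeterminism.
import os

import numpy as np

dumps = ["./dump1", "./dump2", "./dump3"]
ref = dumps[0]

for fname in sorted(os.listdir(ref)):
    base = np.load(os.path.join(ref, fname))
    status = [
        np.allclose(base, np.load(os.path.join(d, fname)), rtol=1e-5, atol=1e-5)
        for d in dumps[1:]
    ]
    if not all(status):
        print(fname, status)
```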

junrushao commented 1 year ago

A 0.403% mismatch usually doesn't bother me a lot, but the absolute/relative diffs look a bit scary.

Hzfengsy commented 1 year ago

Could you try atol = rtol = 1e-3? The max absolute and max relative diffs may come from different elements.
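
To check whether the max absolute and max relative diffs indeed come from different elements, something like the sketch below could be run on a failing pair (the file name is just an example taken from the report above):

```python
# Locate the elements responsible for the max absolute and max relative
# differences; they often turn out to be different elements.
import numpy as np

a = np.load("dump1/00064_NT_matmul_arg2.npy")
b = np.load("dump2/00064_NT_matmul_arg2.npy")

abs_diff = np.abs(a - b)
rel_diff = abs_diff / np.maximum(np.abs(b), np.finfo(b.dtype).tiny)

print("max abs diff:", abs_diff.max(), "at", np.unravel_index(abs_diff.argmax(), a.shape))
print("max rel diff:", rel_diff.max(), "at", np.unravel_index(rel_diff.argmax(), a.shape))
```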

junrushao commented 11 months ago

Our ROCm backend has been rapidly evolving and has stabilized since mid-September; it now supports both single- and multi-GPU inference. With on-device compilation, it also supports a broader range of AMD devices with different GFX architectures without having to rely on prebuilt models. More info: https://github.com/mlc-ai/llm-perf-bench#mlc-llm.

Let me know if it works on your end.

tiandi111 commented 11 months ago

> Our ROCm backend has been rapidly evolving and has stabilized since mid-September; it now supports both single- and multi-GPU inference. With on-device compilation, it also supports a broader range of AMD devices with different GFX architectures without having to rely on prebuilt models. More info: https://github.com/mlc-ai/llm-perf-bench#mlc-llm.
>
> Let me know if it works on your end.

Hi Junru, thanks for your reply, and sorry for not updating the status from our side. This bug was actually solved a few weeks ago, after we synced our local tvm repo with the GitHub one.

Also, we've successfully run Llama-13B, 30B, and 65B with the support of the tvm and mlc repos. Multi-GPU inference support is awesome. I can't wait to share more information, as well as make more contributions to the community, once the confidentiality period has passed.

We highly appreciate the efforts made by the TVM and MLC teams.