IBM / triton-dejavu

Framework to reduce autotune overhead to zero for well known deployments.
Apache License 2.0

Cached Autotuned Result Inconsistency #1

Open KeremTurgutlu opened 1 month ago

KeremTurgutlu commented 1 month ago

I am getting different outputs than expected when using triton_dejavu.autotune.

To reproduce, install hqq from https://github.com/mobiusml/hqq.git and build gemlite locally from https://github.com/AnswerDotAI/gemlite/tree/dejavu.

For example, if no cache (cache.json) exists yet, the Python script below runs fine, but once the cache has been written it fails. When the cache is used, even the outputs of two consecutive kernel launches differ (gemlite_output1 and gemlite_output2 below).

import torch
from hqq.core.quantize import HQQLinear, BaseQuantizeConfig, Quantizer
from gemlite.core import DType, GemLiteLinear

size = (2048,2048)
bitblas_dtype = torch.float16
GROUPSIZE = 128
NBITS = 4

quant_config = BaseQuantizeConfig(nbits=NBITS,
                                    group_size=GROUPSIZE, 
                                    quant_zero=False,
                                    quant_scale=False,
                                    offload_meta=False,
                                    view_as_float=False, 
                                    axis=1)

m = torch.nn.Linear(*size, bias=False, dtype=torch.bfloat16)
hqq_linear = HQQLinear(m, quant_config, compute_dtype=torch.bfloat16)
W_est = hqq_linear.dequantize().clone()#.to(bitblas_dtype)
W_q_unpacked = Quantizer.unpack[hqq_linear.meta['packing']](hqq_linear.W_q)
scale, zero, shape = hqq_linear.meta['scale'], hqq_linear.meta['zero'], hqq_linear.meta['shape']
scale = scale.to(bitblas_dtype)
zero = zero.to(bitblas_dtype)

gemlite_linear = GemLiteLinear(
    NBITS, #supported: [8, 4, 2, 1]
    group_size=GROUPSIZE, # any group_size divisible by 32
    in_features=size[0], # input size
    out_features=size[1], # output size
    input_dtype=DType.FP16, #FP16 or BF16
    output_dtype=DType.FP16, #FP16 or BF16
    acc_dtype=DType.FP16, #FP16 or FP32 
)

#Packing: we follow the same format as hqq (https://github.com/mobiusml/hqq/)
gemlite_linear.pack(W_q_unpacked, scale, zero)

# set torch seed
torch.manual_seed(42)
x = torch.randn(2, size[0]).to(bitblas_dtype).cuda()
hqq_output = x.to(torch.bfloat16) @ W_est.t()

gemlite_output1 = gemlite_linear.forward_manual(x, matmul_type="GEMV")
try:
    assert torch.allclose(gemlite_output1, hqq_output.to(bitblas_dtype), atol=1e-2, rtol=1e-2)
except AssertionError:
    print("Original vs Triton are different")
    print("Actual: ", gemlite_output1)
    print("Expected: ", hqq_output)
    print()

gemlite_output2 = gemlite_linear.forward_manual(x, matmul_type="GEMV")
try:
    assert torch.allclose(gemlite_output2, hqq_output.to(bitblas_dtype), atol=1e-2, rtol=1e-2)
except AssertionError:
    print("Original vs Triton are different")
    print("Actual: ", gemlite_output2)
    print("Expected: ", hqq_output)
    print()

try: 
    assert torch.allclose(gemlite_output1, gemlite_output2, atol=1e-2, rtol=1e-2)
except AssertionError:
    print("Dejavu outputs are different")
    print(gemlite_output1)
    print(gemlite_output2)
    print()

The resulting cache.json:
{
  "signature": "JITFunction(gemlite.triton_kernels.gemv_A16fWnO16f_int32packing:gemv_A16fWnO16f_int32packing_kernel)",
  "total_bench_time_s": 33.92143940925598,
  "evaluated_configs": 16,
  "cache": {
    "('2', '2048', '2048', '128', '4', 'torch.float16', 'torch.int32', 'torch.float16', 'torch.float16', 'torch.float16')": "BLOCK_SIZE_M: 1, BLOCK_SIZE_N: 256, BLOCK_SIZE_K: 32, num_warps: 2, num_ctas: 1, num_stages: 1, maxnreg: None"
  },
  "timings": {
    "('2', '2048', '2048', '128', '4', 'torch.float16', 'torch.int32', 'torch.float16', 'torch.float16', 'torch.float16')": {
      "values": [
        0.018432000651955605,
        0.018432000651955605,
        0.01945599913597107
      ],
      "lables": [
        "ms",
        "min_ms",
        "max_ms"
      ],
      "rep_t_ms": 100,
      "warmup_t_ms": 200
    }
  }
}
bringlein commented 1 month ago

Hi Kerem,

thanks for reporting this, I'll look into it.

It would help me if you could post the output of your script when run with TRITON_DEJAVU_DEBUG=1. Also, does it work as expected if you use the normal Triton autotuning?

KeremTurgutlu commented 1 month ago

@bringlein thanks for looking into it! Here are the outputs of two consecutive runs of the script: the first time there is no cache and the tests pass; the second time the cache is used and the tests fail.

root@# python debug_dejavu.py 
[Triton Dejavu:WARNING] use of 'prune_configs_by' could influence the autotuner decision in a way not visible to triton-dejavu. Please ensure that configs could be reused.

root@# TRITON_DEJAVU_DEBUG=1 python debug_dejavu.py 
[Triton Dejavu:WARNING] use of 'prune_configs_by' could influence the autotuner decision in a way not visible to triton-dejavu. Please ensure that configs could be reused.
[triton-dejavu] restored 1 configurations for gemv_A16fWnO16f_int32packing_kernel/autotune_config-d024ac9abd3a8098c64c928822f4a703adddd8b7084df0769a5844813a43ca59/kernel_configs-abeac919de3416fafac42ccb8ee5ae745764eca2e3ebf61898e0391df4129068/code_version-deea26f1e06c1d69417f61b1c2778feb4cad86408585b8ef32a0a7e7ffa8be47/default.
Original vs Triton are different
Actual:  tensor([[ 2.0586,  2.3301,  1.8613,  ..., -1.2754, -1.3750, -1.2500],
        [-3.3027, -0.5957,  0.9829,  ..., -1.7705,  1.0059, -0.9121]],
       device='cuda:0', dtype=torch.float16)
Expected:  tensor([[ 0.0737,  0.3887,  0.0112,  ..., -0.0162,  0.4570,  0.4238],
        [-1.5234,  0.7617, -0.6680,  ..., -0.2578, -0.8555,  0.1533]],
       device='cuda:0', dtype=torch.bfloat16)

Original vs Triton are different
Actual:  tensor([[ 0.0723,  1.7871,  0.0120,  ...,  1.7715,  0.4561, -1.0928],
        [-1.5264, -1.0244, -0.6675,  ...,  1.4697, -0.8569,  1.8643]],
       device='cuda:0', dtype=torch.float16)
Expected:  tensor([[ 0.0737,  0.3887,  0.0112,  ..., -0.0162,  0.4570,  0.4238],
        [-1.5234,  0.7617, -0.6680,  ..., -0.2578, -0.8555,  0.1533]],
       device='cuda:0', dtype=torch.bfloat16)

Dejavu outputs are different
tensor([[ 2.0586,  2.3301,  1.8613,  ..., -1.2754, -1.3750, -1.2500],
        [-3.3027, -0.5957,  0.9829,  ..., -1.7705,  1.0059, -0.9121]],
       device='cuda:0', dtype=torch.float16)
tensor([[ 0.0723,  1.7871,  0.0120,  ...,  1.7715,  0.4561, -1.0928],
        [-1.5264, -1.0244, -0.6675,  ...,  1.4697, -0.8569,  1.8643]],
       device='cuda:0', dtype=torch.float16)

Also, does it work as you expect it, if you use the normal triton autotuning?

Yes, regular autotuning works.

bringlein commented 1 month ago

Hi @KeremTurgutlu, sorry, I only had time today to look into this again.

Unfortunately, I can't reproduce your issue because the local build of gemlite always fails with:

File "/usr/local/lib/python3.10/dist-packages/torch/utils/cpp_extension.py", line 1985, in _get_cuda_arch_flags
      arch_list[-1] += '+PTX'
IndexError: list index out of range

(This happens even with torch nightly and other variations, building on nvidia/cuda:12.4.1-devel-ubuntu22.04.) Could you perhaps share your Dockerfile?

However, I analyzed your code and I think I found the cause: you define pre_hook=init_to_zero("c_ptr") individually for every config (gemv_A16fWnO16f_int32packing.py#L54). Right now, triton-dejavu can only restore pre/post hooks if they are defined at the autotuner level, using the optional arguments in autotuner.py#L323-L331.

So, if this is the cause, then you could avoid the problem by adding pre_hook=init_to_zero("c_ptr") to your autotuner definition in line 65 (it is always the same function anyway in your case). Can you confirm this?
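A minimal pure-Python sketch of the failure mode described above, under stated assumptions: `Config` and `Autotuner` are hypothetical stand-ins for `triton.Config` and the triton-dejavu autotuner, and the cache string is a simplified version of the cache.json entry (without `num_ctas` and `maxnreg`). It shows that a per-config `pre_hook` does not survive a round trip through a textual cache entry, while a hook attached to the autotuner itself is still applied to the restored config.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional

# Hypothetical stand-ins for triton.Config and the triton-dejavu autotuner;
# they only model *where* a pre_hook is stored, not how kernels are launched.
Hook = Callable[[Dict[str, Any]], None]


@dataclass
class Config:
    kwargs: Dict[str, int]
    num_warps: int = 4
    num_stages: int = 1
    pre_hook: Optional[Hook] = None  # per-config hook: lives only on this object

    def to_cache_str(self) -> str:
        # Only plain values are serialized (like the cache.json entry);
        # a Python callable such as pre_hook has no textual form here.
        parts = [f"{k}: {v}" for k, v in self.kwargs.items()]
        parts += [f"num_warps: {self.num_warps}", f"num_stages: {self.num_stages}"]
        return ", ".join(parts)

    @staticmethod
    def from_cache_str(entry: str) -> "Config":
        values = dict(item.split(": ") for item in entry.split(", "))
        num_warps = int(values.pop("num_warps"))
        num_stages = int(values.pop("num_stages"))
        kwargs = {k: int(v) for k, v in values.items()}
        return Config(kwargs, num_warps, num_stages)  # pre_hook defaults to None


@dataclass
class Autotuner:
    pre_hook: Optional[Hook] = None  # autotuner-level hook: not part of any cache entry

    def launch(self, config: Config, nargs: Dict[str, Any]) -> None:
        # Run every hook that is still reachable; the kernel itself is omitted.
        if config.pre_hook is not None:
            config.pre_hook(nargs)
        if self.pre_hook is not None:
            self.pre_hook(nargs)


def init_to_zero(name: str) -> Hook:
    # Assumed behavior of the init_to_zero helper: zero the named output buffer.
    def hook(nargs: Dict[str, Any]) -> None:
        nargs[name][:] = [0.0] * len(nargs[name])
    return hook


tuned = Config({"BLOCK_SIZE_M": 1, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
               num_warps=2, num_stages=1, pre_hook=init_to_zero("c_ptr"))
restored = Config.from_cache_str(tuned.to_cache_str())

# Hook defined per config only: it is lost when the config comes from the cache.
c = [7.0, 7.0]
Autotuner().launch(restored, {"c_ptr": c})
print(c)  # [7.0, 7.0] (stale values are not cleared)

# Same hook defined at autotuner level: still applied to the restored config.
c = [7.0, 7.0]
Autotuner(pre_hook=init_to_zero("c_ptr")).launch(restored, {"c_ptr": c})
print(c)  # [0.0, 0.0]
```

In the real kernel, this corresponds to passing `pre_hook=init_to_zero("c_ptr")` to the autotuner decorator instead of to each `triton.Config`, as suggested in the comment.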

I will add support for individual config pre_hooks in the next version of triton-dejavu in the coming weeks.

KeremTurgutlu commented 1 month ago

@bringlein thanks for looking into this! I tried the following changes: https://github.com/mobiusml/gemlite/commit/1bedc1ba2a4f9f9c24996d613e2fc1654ccd79a9

This time, when I run the script with a cached config, it sometimes works and sometimes doesn't, so there is still some inconsistency. But I think the changes are in the right direction, because this issue doesn't happen with the GEMM kernel, which doesn't have a pre_hook.

bringlein commented 4 weeks ago

Good to hear!

But IMHO, the remaining inconsistencies are then most likely because your Triton kernel does not completely write its output array (i.e., the result partially depends on what was in the array before)? Depending on those previous contents, the result differs. Autotuning executes the kernel many times, so there is a higher chance that the "correct" result ends up in the output array. If so, this would be unrelated to the autotune caching.
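A toy illustration of an output that depends on its previous contents. It assumes (this is not confirmed in the thread) that the GEMV kernel accumulates partial sums into its output, e.g. with atomic adds, which is the situation a zeroing `pre_hook` such as `init_to_zero("c_ptr")` guards against. `accumulate_gemv` is a hypothetical pure-Python stand-in for the kernel.

```python
def accumulate_gemv(x, w, out, block_k=2):
    # Hypothetical toy "kernel": out[j] += sum_i w[j][i] * x[i], computed one
    # K-block at a time and *added* into out (standing in for atomic adds).
    # It never overwrites out, so stale contents leak into the result.
    k = len(x)
    for start in range(0, k, block_k):
        stop = min(start + block_k, k)
        for j, row in enumerate(w):
            out[j] += sum(row[i] * x[i] for i in range(start, stop))
    return out


x = [1.0, 2.0, 3.0, 4.0]
w = [[1.0, 0.0, 0.0, 0.0],
     [0.0, 1.0, 1.0, 0.0]]

# Zero-initialized output (what a pre_hook like init_to_zero would ensure).
print(accumulate_gemv(x, w, [0.0, 0.0]))   # [1.0, 5.0]

# Output buffers holding leftover data (like torch.empty): each launch
# returns "true result + whatever was there", so two launches disagree.
print(accumulate_gemv(x, w, [0.5, -1.0]))  # [1.5, 4.0]
print(accumulate_gemv(x, w, [2.0, 0.25]))  # [3.0, 5.25]
```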

bringlein commented 3 weeks ago

Hi @KeremTurgutlu, as promised, I implemented support for individual pre_hooks last week: https://github.com/IBM/triton-dejavu/commit/d5254f94a825779372944a4ac11e82af2374e9bf

Maybe you could give it a try.
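One hypothetical way an autotune cache could support per-config hooks is to match each restored config back to the declared config with the same tunable values and reuse that config's hook. This is a sketch of the idea only; it is not necessarily how the linked commit implements it, and `Config`, `reattach_hooks`, and `zero_c` are illustrative names.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional

Hook = Callable[[Dict[str, Any]], None]


@dataclass
class Config:
    # Hypothetical stand-in for triton.Config.
    kwargs: Dict[str, int]
    num_warps: int
    pre_hook: Optional[Hook] = None


def reattach_hooks(restored: Config, declared: List[Config]) -> Config:
    # A config parsed from the cache carries only plain values. Find the
    # declared config with identical values and reuse its per-config hook.
    for cfg in declared:
        if cfg.kwargs == restored.kwargs and cfg.num_warps == restored.num_warps:
            return Config(dict(restored.kwargs), restored.num_warps, cfg.pre_hook)
    # No match (e.g. the config list changed after caching): leave it hookless.
    return restored


def zero_c(nargs: Dict[str, Any]) -> None:
    nargs["c_ptr"][:] = [0.0] * len(nargs["c_ptr"])


declared = [
    Config({"BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, num_warps=2, pre_hook=zero_c),
    Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64}, num_warps=4),  # no hook
]
# What a cache restore would produce: same values, no hook.
restored = Config({"BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, num_warps=2)

fixed = reattach_hooks(restored, declared)
print(fixed.pre_hook is zero_c)  # True
```

Matching on values only works while the declared config list stays the same between runs; if it changes, this sketch simply leaves the restored config without a hook.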