ridgerchu / matmulfreellm

Implementation for MatMul-free LM.
Apache License 2.0

Ternary weight values #18

Closed · dmahurin closed this 2 weeks ago

dmahurin commented 2 weeks ago

Looking at the weight values, we see that they are bfloat16. Further, conversion to ternary is done at run-time (in FusedBitLinear).

To see whether the model still works with ternary weights, I rewrote model.safetensors (using the script below) to pre-quantize with the weight_quant function, still storing the result in bfloat16 (I also tried float16 and float32). The model still worked when only the attention projection weights were converted, and separately when only the mlp projection weights were converted, but when both the attention and mlp weights were converted the model output was much worse.

The weight_quant function converts the weights to three values (scaled ternary). Thinking the conversion might be affected by precision, I tried converting to float32 first, but this did not change the result.

Do you have suggestions on pre-converting to ternary? Such pre-conversion could eventually allow the weights to be stored in a true ternary encoding to reduce storage/RAM use. Are there any plans to follow such a path?

from safetensors import safe_open
from safetensors.torch import save_file

def weight_quant(w):
    # Scaled-ternary quantization: scale by the mean absolute value,
    # round, clamp to {-1, 0, +1}, then rescale back.
    scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
    u = (w * scale).round().clamp_(-1, 1) / scale
    return u

out = {}
with safe_open("model.safetensors.orig", framework="pt") as f:
    for key in f.keys():
        w = f.get_tensor(key)
        # Pre-quantize the attention projection weights (the mlp projection
        # weights were converted the same way in a separate run).
        if "attn" in key and "proj" in key and "norm" not in key:
            w = weight_quant(w)
        out[key] = w
save_file(out, "model.safetensors.new", metadata={"format": "pt"})
yzhangcs commented 2 weeks ago

@dmahurin Hi, did you disable this line after pre-quant? https://github.com/ridgerchu/matmulfreellm/blob/master/mmfreelm/modules/layernorm.py#L670

dmahurin commented 2 weeks ago

@yzhangcs,

I had previously commented that line out, but since I was not quantizing all of the weights and assumed that weight_quant would produce the same values when run again, I left it uncommented.

My assumption was not correct... The values change if re-quantized.
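A quick standalone check (a minimal sketch, reusing the weight_quant from the script above) shows why: after the first pass the tensor only contains {-m, 0, +m} with m = w.abs().mean(), so a second pass recomputes the scale from a tensor whose mean absolute value is smaller (the zeros pull it down), and the nonzero magnitudes drift:

import torch

def weight_quant(w):
    # Same function as in the conversion script above.
    scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
    u = (w * scale).round().clamp_(-1, 1) / scale
    return u

torch.manual_seed(0)
w = torch.randn(8, 8)

once = weight_quant(w)
twice = weight_quant(once)

print(torch.equal(once, twice))             # almost always False
print(once.abs().max(), twice.abs().max())  # the nonzero magnitude shrinks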

Instead of commenting the line out, the following change works; it skips quantization if the weights are already ternary:

-        linear_weight = weight_quant(linear_weight).to(dtype)
+        if 3 != linear_weight.unique().numel():
+            linear_weight = weight_quant(linear_weight).to(dtype)

Thanks @yzhangcs

radna0 commented 2 weeks ago

@dmahurin From what I understand, you're trying to pre-quantize the model's bf16 layers to ternary values because that would reduce RAM usage more than quantizing at runtime. Would you need to convert back after training is finished, or is the model intended to remain in its quantized form for deployment?

dmahurin commented 2 weeks ago

@radna0 I first wanted to confirm that the model still works as expected with pre-quantized weights (it does).

Pre-quantization by itself (as done above) has no impact on RAM or storage, as the model is still stored with bfloat16 weights on disk and in memory. To get an actual reduction in storage/RAM, the ternary values would need to be stored in a true ternary encoding, for example 5 weights (trits) packed into 8 bits (possible since 3^5 = 243 fits in a byte), or another encoding.

If such an encoding were done, then the 2.7B model could consume <600 megabytes on disk and in memory.
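For illustration only (nothing in the repo does this yet, and the helper names below are made up), here is a rough sketch of the 5-trits-per-byte packing using base-3 digits; the per-tensor scale produced by weight_quant would still need to be stored alongside the packed bytes:

import torch

def pack_trits(ternary):
    # Pack values in {-1, 0, +1} into bytes, 5 trits per byte (base-3).
    t = (ternary.flatten() + 1).to(torch.uint8)        # {-1,0,1} -> {0,1,2}
    pad = (-t.numel()) % 5
    t = torch.cat([t, t.new_zeros(pad)]).view(-1, 5)   # pad to a multiple of 5
    weights = torch.tensor([81, 27, 9, 3, 1], dtype=torch.uint8)
    packed = (t * weights).sum(dim=1).to(torch.uint8)  # base-3 value in [0, 242]
    return packed, ternary.numel()

def unpack_trits(packed, numel):
    # Inverse of pack_trits: recover the {-1, 0, +1} values.
    p = packed.to(torch.int64)
    digits = []
    for d in (81, 27, 9, 3, 1):
        digits.append(p // d)
        p = p % d
    t = torch.stack(digits, dim=1).flatten()[:numel]
    return t.to(torch.int8) - 1                        # {0,1,2} -> {-1,0,1}

# Round-trip check on random ternary values.
t = torch.randint(-1, 2, (1000,), dtype=torch.int8)
packed, n = pack_trits(t)
print(torch.equal(unpack_trits(packed, n), t))    # True
print(packed.numel(), "bytes for", n, "weights")  # 200 bytes for 1000 weights

At 8 bits per 5 weights this is 1.6 bits per parameter, so 2.7B parameters come to roughly 540 MB plus one scale per tensor, which is where the <600 MB figure comes from.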

radna0 commented 2 weeks ago

@dmahurin Thanks for the clarification. Given that an efficient encoding of the ternary values is needed to reduce RAM and storage usage, could you provide more details on how this encoding could be implemented? Are there specific methods or tools you would recommend for encoding ternary values (e.g., packing 5 trits into 8 bits)?

Additionally, what is the expected impact on model performance, both in terms of inference speed and accuracy, when using these encoded ternary values compared to the original bfloat16 weights?

Lastly, could you elaborate on the use cases where pre-quantization would be most beneficial, especially if we proceed with the efficient encoding approach?