ridgerchu / matmulfreellm

Implementation for MatMul-free LM.
Apache License 2.0

Why does FusedBitLinear.forward() use F.linear() with float16 inputs? #19

Open AACengineer opened 3 months ago

AACengineer commented 3 months ago

```python
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import mmfreelm
from transformers import AutoModelForCausalLM, AutoTokenizer

name = '/mnt/workspace/MMfreeLM-370M'
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name).cuda().half()

input_prompt = "In a shocking finding, scientist discovered a herd of unicorns living in a remote, "
input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids.cuda()
outputs = model.generate(input_ids, max_length=32, do_sample=True, top_p=0.4, temperature=0.6)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
```

"The FusedBitLinear.forward() function calls the LayerNormLinearQuantFn.forward() function. Why are both x and w in the F.linear() function float16? Shouldn't x be int8 and w be within the set {-1, 0, 1}?"

ridgerchu commented 3 months ago

Hi, this is for speed. We found that bf16 gives the fastest throughput for these operations, so we keep it. If you inspect the actual values, you will find the activations are INT8 and the weights are ternary. This is so-called fake quantization: the tensors are stored in a high-precision data type, but their values have already been restricted to the low-precision grid.
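To make the fake-quantization idea concrete, here is a minimal sketch (my own illustration following the BitNet-style recipe, not the repository's actual FusedBitLinear code): both tensors stay in bf16, but their values are snapped to the int8 and ternary grids before F.linear is called.

```python
import torch
import torch.nn.functional as F

def activation_quant(x: torch.Tensor) -> torch.Tensor:
    # Per-token absmax quantization to the int8 grid [-128, 127], values kept in x.dtype.
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    return (x * scale).round().clamp_(-128, 127) / scale

def weight_quant(w: torch.Tensor) -> torch.Tensor:
    # Absmean ternarization: weights rounded to {-1, 0, +1} times a scalar scale.
    scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
    return (w * scale).round().clamp_(-1, 1) / scale

x = torch.randn(2, 4, 8, dtype=torch.bfloat16)
w = torch.randn(16, 8, dtype=torch.bfloat16)

# Both operands remain bf16, so the fast GEMM path can be used,
# even though the values themselves are quantized ("fake quantization").
y = F.linear(activation_quant(x), weight_quant(w))
```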

AACengineer commented 3 months ago

As you mentioned, "you will find the activation is INT8 and weight is ternary": both inputs to F.linear() are quantized values stored as float16. However, F.linear() still performs multiplications, which is not entirely consistent with the matmul-free concept. Is it possible to implement the functionality of F.linear() using only add/sub and similar operators in a GPU environment?

ridgerchu commented 3 months ago

Yes, for training, using matmul is the most efficient approach, and a matmul-free layer can be seen as a special case of matmul, so we still use F.linear here. To the best of my knowledge, it is quite hard to implement matmul-free operations efficiently in a GPU environment.
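To illustrate that "special case" view, here is a hypothetical reference implementation (the function name and loop structure are mine, not the repo's kernel) showing that a ternary-weight linear layer needs only additions and subtractions; writing a fused GPU kernel that does this faster than a GEMM is the hard part, which is why training code keeps F.linear.

```python
import torch
import torch.nn.functional as F

def ternary_linear_addsub(x: torch.Tensor, w_ternary: torch.Tensor) -> torch.Tensor:
    """Naive 'multiplication-free' linear: only additions and subtractions."""
    batch, in_features = x.shape
    out_features = w_ternary.shape[0]
    out = torch.zeros(batch, out_features, dtype=x.dtype)
    for o in range(out_features):
        for i in range(in_features):
            if w_ternary[o, i] == 1:
                out[:, o] += x[:, i]   # +1 weight: accumulate the input
            elif w_ternary[o, i] == -1:
                out[:, o] -= x[:, i]   # -1 weight: subtract the input
            # 0 weight: skip entirely
    return out

x = torch.randn(4, 8)
w = torch.randint(-1, 2, (16, 8)).float()  # ternary weights in {-1, 0, 1}

# Matches the ordinary matmul path taken by F.linear:
torch.testing.assert_close(ternary_linear_addsub(x, w), F.linear(x, w))
```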