CryVeck / QuaRot

Code for the NeurIPS 2024 paper QuaRot: end-to-end 4-bit inference of large language models.
https://arxiv.org/abs/2404.00456
Apache License 2.0

Rotation hidden size #1

Closed: CryVeck closed this issue 5 days ago

CryVeck commented 5 days ago

There is currently an issue with the hidden size used to reshape the rotated K in QKRotationWrapper.

In "rotation_utils.py":

class QKRotationWrapper(torch.nn.Module):
    # [...]

    def forward(self, *args, **kwargs):
        q, k = self.func(*args, **kwargs)
        print(f"q shape before : {q.shape}")
        print(f"k shape before : {k.shape}")
        dtype = q.dtype
        q = hadamard_transform(q.float(), scale=1/math.sqrt(q.shape[-1])).to(dtype)
        k = hadamard_transform(k.float(), scale=1/math.sqrt(k.shape[-1])).to(dtype)
        (bsz, num_heads, seq_len, head_dim) = k.shape

        print(f"q shape : {q.shape}")
        print(f"k shape : {k.shape}")

        if self.k_groupsize == -1: #token-wise 
            token_wise_k = k.transpose(1, 2).reshape(-1, self.config.hidden_size) # <---- Error on shape here
            self.k_quantizer.find_params(token_wise_k)
            k = self.k_quantizer(token_wise_k).reshape((bsz, seq_len, num_heads, head_dim)).transpose(1, 2).to(q)
        else: #head-wise quantization
            per_head_k = k.view(-1, head_dim)
            self.k_quantizer.find_params(per_head_k)
            k = self.k_quantizer(per_head_k).reshape((bsz, num_heads, seq_len, head_dim)).to(q)

        self.k_quantizer.free()

        return q, k

self.config.hidden_size no longer corresponds to the K and V dimension: its value is 3072, while the K hidden size is only 1024 (num_key_value_heads * head_dim). This does not cause any bug when the two coincide (models without grouped-query attention), but here it leads to the reshape error marked above.
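
For reference, a minimal standalone check (hypothetical snippet, using the Llama-3.2-3B values from the configs posted below: hidden_size=3072, num_key_value_heads=8, head_dim=128) reproduces the mismatch:

import torch

bsz, seq_len = 1, 16
hidden_size = 3072          # config.hidden_size
num_key_value_heads = 8     # config.num_key_value_heads
head_dim = 128              # config.head_dim

# k after k_proj and head splitting: (bsz, num_key_value_heads, seq_len, head_dim)
k = torch.randn(bsz, num_key_value_heads, seq_len, head_dim)

# What the current code attempts: 1 * 16 * 8 * 128 = 16384 elements
# cannot be reshaped into rows of width 3072.
try:
    k.transpose(1, 2).reshape(-1, hidden_size)
except RuntimeError as e:
    print("reshape with hidden_size fails:", e)

# The actual K width is num_key_value_heads * head_dim = 1024, and that reshape works.
kv_hidden_size = num_key_value_heads * head_dim
print(k.transpose(1, 2).reshape(-1, kv_hidden_size).shape)  # torch.Size([16, 1024])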

WARNING: LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 3072)
    (layers): ModuleList(
      (0-27): 28 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): ActQuantWrapper(
            Input Quantizer Bits: 16
            Output Quantizer Bits: 16
            (module): Linear(in_features=3072, out_features=3072, bias=False)
            (quantizer): ActQuantizer()
            (out_quantizer): ActQuantizer()
          )
          (k_proj): ActQuantWrapper(
            Input Quantizer Bits: 16
            Output Quantizer Bits: 16
            (module): Linear(in_features=3072, out_features=1024, bias=False)
            (quantizer): ActQuantizer()
            (out_quantizer): ActQuantizer()
          )
          (v_proj): ActQuantWrapper(
            Input Quantizer Bits: 16
            Output Quantizer Bits: 16
            (module): Linear(in_features=3072, out_features=1024, bias=False)
            (quantizer): ActQuantizer()
            (out_quantizer): ActQuantizer()
          )
          (o_proj): ActQuantWrapper(
            Input Quantizer Bits: 16
            Output Quantizer Bits: 16
            (module): Linear(in_features=3072, out_features=3072, bias=False)
            (quantizer): ActQuantizer()
            (out_quantizer): ActQuantizer()
          )
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): ActQuantWrapper(
            Input Quantizer Bits: 16
            Output Quantizer Bits: 16
            (module): Linear(in_features=3072, out_features=8192, bias=False)
            (quantizer): ActQuantizer()
            (out_quantizer): ActQuantizer()
          )
          (up_proj): ActQuantWrapper(
            Input Quantizer Bits: 16
            Output Quantizer Bits: 16
            (module): Linear(in_features=3072, out_features=8192, bias=False)
            (quantizer): ActQuantizer()
            (out_quantizer): ActQuantizer()
          )
          (down_proj): ActQuantWrapper(
            Input Quantizer Bits: 16
            Output Quantizer Bits: 16
            (module): Linear(in_features=8192, out_features=3072, bias=False)
            (quantizer): ActQuantizer()
            (out_quantizer): ActQuantizer()
          )
          (act_fn): SiLU()
        )
        (input_layernorm): RMSN()
        (post_attention_layernorm): RMSN()
      )
    )
    (norm): RMSN()
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): ActQuantWrapper(
    Input Quantizer Bits: 16
    Output Quantizer Bits: 16
    (module): Linear(in_features=3072, out_features=128256, bias=False)
    (quantizer): ActQuantizer()
    (out_quantizer): ActQuantizer()
  )
)
CryVeck commented 5 days ago

Llama-3.2-1B config

LlamaConfig {
  "_name_or_path": "meta-llama/Llama-3.2-1B",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "head_dim": 64,
  "hidden_act": "silu",
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": 8192,
  "max_position_embeddings": 131072,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 16,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": {
    "factor": 32.0,
    "high_freq_factor": 4.0,
    "low_freq_factor": 1.0,
    "original_max_position_embeddings": 8192,
    "rope_type": "llama3"
  },
  "rope_theta": 500000.0,
  "tie_word_embeddings": true,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.43.1",
  "use_cache": true,
  "vocab_size": 128256
}

Llama-3.2-3B config

LlamaConfig {
  "_name_or_path": "meta-llama/Llama-3.2-3B",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 3072,
  "initializer_range": 0.02,
  "intermediate_size": 8192,
  "max_position_embeddings": 131072,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 24,
  "num_hidden_layers": 28,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": {
    "factor": 32.0,
    "high_freq_factor": 4.0,
    "low_freq_factor": 1.0,
    "original_max_position_embeddings": 8192,
    "rope_type": "llama3"
  },
  "rope_theta": 500000.0,
  "tie_word_embeddings": true,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.43.1",
  "use_cache": true,
  "vocab_size": 128256
}
CryVeck commented 5 days ago

The corresponding value should be

config.hidden_size * config.num_key_value_heads / config.num_attention_heads

But in the case of Llama 2 this goes unnoticed, because num_attention_heads and num_key_value_heads have the same value, so the expression reduces to hidden_size.
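
A minimal sketch of what the fix could look like inside QKRotationWrapper.forward (untested, shown only as an illustration; it takes the K width from k.shape instead of self.config.hidden_size, which is equivalent to the expression above since num_heads here is num_key_value_heads):

# After the Hadamard rotation of q and k:
(bsz, num_heads, seq_len, head_dim) = k.shape   # num_heads == config.num_key_value_heads
kv_hidden_size = num_heads * head_dim           # 8 * 128 = 1024 for Llama-3.2-3B

if self.k_groupsize == -1:  # token-wise quantization
    token_wise_k = k.transpose(1, 2).reshape(-1, kv_hidden_size)
    self.k_quantizer.find_params(token_wise_k)
    k = self.k_quantizer(token_wise_k).reshape((bsz, seq_len, num_heads, head_dim)).transpose(1, 2).to(q)
else:  # head-wise quantization, unchanged
    per_head_k = k.view(-1, head_dim)
    self.k_quantizer.find_params(per_head_k)
    k = self.k_quantizer(per_head_k).reshape((bsz, num_heads, seq_len, head_dim)).to(q)

This would keep models without grouped-query attention working exactly as before, since kv_hidden_size then equals hidden_size.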