unslothai / unsloth

Finetune Llama 3.1, Mistral, Phi & Gemma LLMs 2-5x faster with 80% less memory
https://unsloth.ai
Apache License 2.0

Is fast rope exactly equivalent to llama's apply_rotary_pos_emb? #775

Open starstream opened 1 month ago

starstream commented 1 month ago

Is fast rope exactly equivalent to llama's apply_rotary_pos_emb? I constructed a test case and found that the results are not exactly equivalent. Is there anything wrong with my test case?

code:

import torch

# test configuration
BS = 2
seq_length = 4
head_num_q = 2
head_num_k = head_num_q
head_dims = 4

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

# reference implementation from transformers' modeling_llama
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

query_states = torch.rand(BS, head_num_q, seq_length, head_dims, device=0)
key_states = torch.rand(BS, head_num_k, seq_length, head_dims, device=0)
cos = torch.rand(seq_length, head_dims, device=0)  # random cos / sin tables for the test
sin = torch.rand(seq_length, head_dims, device=0)
position_ids = torch.arange(0, seq_length, device=0)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)

q_emb, k_emb = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
from unsloth.kernels.rope_embedding import fast_rope_embedding
q_emb_fast, k_emb_fast = fast_rope_embedding(query_states, key_states, cos, sin)
print(q_emb)
print(q_emb_fast)

output:

tensor([[[[-0.4640, -0.2159,  1.2940,  0.5061],
          [ 0.2510,  0.6659,  0.7157,  0.7676],
          [-0.3357, -0.1534,  1.0423,  0.2901],
          [-0.0551,  0.4544,  0.3271,  0.4292]],

         [[ 0.1047, -0.0284,  0.7272,  0.6612],
          [ 0.1477,  0.5076,  0.7133,  0.6428],
          [ 0.2858,  0.3088,  0.1674,  0.5372],
          [ 0.5481,  0.7310,  0.1625,  0.6781]]],

        [[[-0.1462,  0.1645,  0.4502,  0.0084],
          [ 0.1149,  0.6139,  0.4423,  0.7918],
          [-0.6824,  0.1163,  0.8660,  0.2474],
          [-0.0140,  0.4709,  0.1222,  0.4574]],

         [[-0.6453,  0.1273,  0.5891,  0.0220],
          [-0.3913,  0.2752,  0.2875,  0.6957],
          [ 0.2637,  0.5442,  0.7180,  0.6723],
          [ 0.6071,  0.3514,  0.2053,  0.3471]]]], device='cuda:0')

tensor([[[[-0.4640, -0.2159,  0.8979,  0.2751],
          [ 0.2510,  0.6659,  0.8333,  0.5259],
          [-0.3357, -0.1534,  1.1098,  0.5500],
          [-0.0551,  0.4544,  0.5717,  0.2453]],

         [[ 0.1047, -0.0284,  0.6065,  0.5705],
          [ 0.1477,  0.5076,  0.8688,  0.4856],
          [ 0.2858,  0.3088,  0.3744,  0.3917],
          [ 0.5481,  0.7310,  0.3588,  0.3084]]],

        [[[-0.1462,  0.1645,  0.3166,  0.1445],
          [ 0.1149,  0.6139,  0.5300,  0.6084],
          [-0.6824,  0.1163,  0.6892,  0.2078],
          [-0.0140,  0.4709,  0.2145,  0.3419]],

         [[-0.6453,  0.1273,  0.2877,  0.1257],
          [-0.3913,  0.2752,  0.5189,  0.7731],
          [ 0.2637,  0.5442,  1.0501,  0.3234],
          [ 0.6071,  0.3514,  0.4423,  0.2948]]]], device='cuda:0')
starstream commented 1 month ago

I changed Q = Q.view(batch*seq_len, n_heads*head_dim) --> Q = Q.reshape(batch*seq_len, n_heads*head_dim)

because of this error:

File "/opt/conda/lib/python3.10/site-packages/unsloth/kernels/rope_embedding.py", line 79, in forward
    Q = Q.view(batch*seq_len, n_heads*head_dim)
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

Is that the reason? But I think the two approaches should be equivalent.
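
A minimal sketch of why the swap should be value-preserving - .reshape falls back to making a copy when .view cannot reinterpret the existing strides, so both produce the same numbers (the shapes below are illustrative assumptions, not Unsloth's actual ones):

import torch

x = torch.rand(2, 3, 4, 5).transpose(1, 2)  # non-contiguous, shape (2, 4, 3, 5)
# x.view(8, 15) would raise the same RuntimeError; reshape silently copies instead
y = x.reshape(2 * 4, 3 * 5)
assert torch.equal(y, x.contiguous().view(2 * 4, 3 * 5))  # identical values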

danielhanchen commented 1 month ago

It should be equivalent - the issue is

cos = torch.rand(seq_length, head_dims, device=0)
sin = torch.rand(seq_length, head_dims, device=0)

is not actually correct - in Llama the cos and sin tables are built by duplicating the half-dimension frequencies along the last axis (emb = torch.cat((freqs, freqs), dim=-1)), so the two halves of each row are identical. Random cos and sin don't have that structure, hence you'll get different numbers if you naively call the fast kernel with them.
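
A minimal sketch of that construction, following transformers' LlamaRotaryEmbedding (base = 10000.0 is the usual default, assumed here):

import torch

base = 10000.0
seq_length, head_dims = 4, 4

# inverse frequencies cover only half the head dimension
inv_freq = 1.0 / (base ** (torch.arange(0, head_dims, 2).float() / head_dims))
t = torch.arange(seq_length).float()
freqs = torch.outer(t, inv_freq)          # (seq_length, head_dims // 2)
emb = torch.cat((freqs, freqs), dim=-1)   # the two halves are duplicated
cos, sin = emb.cos(), emb.sin()           # (seq_length, head_dims)
assert torch.equal(cos[..., : head_dims // 2], cos[..., head_dims // 2 :])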

One way to check is to manually edit Unsloth's code where the RoPE kernel is applied, call the slow and fast methods together on the same inputs, then use torch.dist to print the difference.
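
For example, rerunning the original snippet with the structured cos / sin from the sketch above (an assumption on my part - the check Daniel describes edits Unsloth's source instead) would look like this, continuing from the variables defined earlier:

cos, sin = cos.to(0), sin.to(0)
# clone the inputs in case the fast kernel writes in place
q_ref, k_ref = apply_rotary_pos_emb(query_states.clone(), key_states.clone(), cos, sin, position_ids)
q_fast, k_fast = fast_rope_embedding(query_states.clone(), key_states.clone(), cos, sin)
print(torch.dist(q_ref, q_fast))  # expected to be ~0 with structured tables
print(torch.dist(k_ref, k_fast))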