eole-nlp / eole

Open language modeling toolkit based on PyTorch
https://eole-nlp.github.io/eole
MIT License
24 stars 6 forks source link

MHA refac: rope without complex operations + query only as input of the forward #20

Closed: vince62s closed this 1 month ago.

vince62s commented 1 month ago

Regarding the RoPE refactor, here is the legacy (complex-number) logic, together with the equivalent cos/sin-only formulation used to check it:

import torch

# Demo configuration for the rotary tables.
dim = 16  # channels rotated per head; must equal query's last dim (pairs -> dim // 2 freqs)
base = 10_000  # base of the geometric inverse-frequency schedule
maxseqlen = 2048  # number of positions precomputed in the tables

def rotate_half(x):
    """Swap the two halves of the last axis, negating the moved upper half.

    With ``h = x.shape[-1] // 2``, returns ``cat(-x[..., h:], x[..., :h])``.
    Element ``i`` is therefore paired with element ``i + h``. This differs
    from the adjacent-pair grouping used by the reshape(..., -1, 2) code.
    For an odd last dimension the upper part is one element longer.
    """
    half = x.shape[-1] // 2
    lower = x[..., :half]
    upper = x[..., half:]
    return torch.cat((-upper, lower), dim=-1)

# --- Rotary tables -----------------------------------------------------------
# One inverse frequency per channel pair: inv_freq[i] = base ** (-2 * i / dim).
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
tmax = torch.arange(maxseqlen, device=inv_freq.device)
# Angle table: rope[pos, i] = pos * inv_freq[i]  ->  [maxseqlen, dim // 2].
rope = torch.outer(tmax, inv_freq).float()

# Real tables used by the complex-free formulation.
cos_emb = torch.cos(rope)
sin_emb = torch.sin(rope)

# Legacy complex table exp(j * rope) (unit modulus), duplicated along the
# channel axis -> [maxseqlen, dim] complex. This name is not rebound below,
# so the full table stays available for other start positions.
rope1 = torch.polar(torch.ones_like(rope), rope)
rope1 = torch.cat((rope1, rope1), dim=1)
print(rope1.size())  # [2048, 16]
start_pos = 4

# Positions are indexed along axis 2. start_pos + query.size(2) must not
# exceed maxseqlen, otherwise the table slices below come back short and the
# shapes no longer match.
query = torch.randn(8, 4, 7, 16)

# --- Legacy path: complex multiplication (llama logic) -----------------------
# Group consecutive channels (x[2i], x[2i+1]) and view each pair as the
# complex number x[2i] + j * x[2i+1]. The shape is taken from ``query``
# rather than hard-coded, so this stays in sync with the cos/sin path.
query_ = query.float().reshape(*query.shape[:-1], -1, 2)  # [8, 4, 7, 8, 2]
print(query_.size(), query_)
query_ = torch.view_as_complex(query_)
print(query_.size(), query_)  # [8, 4, 7, 8], each entry a + bj
print(rope1.size())
seqlen = query_.size(2)
# Rows for positions [start_pos, start_pos + seqlen). The first dim // 2
# columns of the duplicated table hold the distinct frequencies. The leading
# size-1 dims broadcast over query's first two axes.
rope_slice = rope1[start_pos : start_pos + seqlen, : rope1.size(1) // 2].view(
    1, 1, seqlen, query_.size(3)
)
print(rope_slice.size())  # [1, 1, 7, 8]
rotated = query_ * rope_slice  # computed once, reused for print and output
print(rotated)
# Back to real (re, im) pairs, then merge the pairs into the channel axis.
query_out = torch.view_as_real(rotated).flatten(3)
print(query_out.size(), query_out)  # [8, 4, 7, 16]

# --- Same maths with real cos/sin only ---------------------------------------
query_interleaved = query.reshape(*query.shape[:-1], -1, 2)
print(query_interleaved.size())
positions = slice(start_pos, start_pos + query_interleaved.size(2))
cos_pos = cos_emb[positions]  # [7, 8]; broadcasts over the leading axes
sin_pos = sin_emb[positions]
print(cos_pos.size(), sin_pos.size())

# (a + jb) * (cos + j sin), expanded into real arithmetic:
#   real part = a * cos - b * sin
#   imag part = a * sin + b * cos
q_even = query_interleaved[..., 0]
q_odd = query_interleaved[..., 1]
q_embed_cos = q_even * cos_pos - q_odd * sin_pos
q_embed_sin = q_even * sin_pos + q_odd * cos_pos

# Re-interleave (real, imag) as the last axis, then merge back into channels.
q_embed = torch.stack((q_embed_cos, q_embed_sin), dim=-1)
query_out1 = q_embed.flatten(3)
print(query_out1.size(), query_out1)

# The two paths take cos/sin from different ops (torch.polar vs
# torch.cos/torch.sin), which may disagree in the last float32 bits.
# Rounding in order-1 intermediate products (~1e-7) can then exceed
# allclose's default atol=1e-8 + rtol*|out| for near-zero outputs.
# Use an explicit tolerance, and also report the worst-case deviation.
print(torch.allclose(query_out, query_out1, atol=1e-6))
print((query_out - query_out1).abs().max())