Closed: vince62s closed this 1 month ago.
Regarding the RoPE refactor, I'm posting the logic for the legacy implementation here:
"""Check that three formulations of rotary position embeddings (RoPE) agree.

1. Legacy LLaMA-style path: complex multiplication on interleaved (even, odd) pairs.
2. The same maths on interleaved pairs written with real cos/sin only.
3. The half-split ``rotate_half`` form, applied after permuting the interleaved
   pairs into [evens | odds] and permuting the result back.
"""
import torch

dim = 16
base = 10000
maxseqlen = 2048


def rotate_half(x):
    """Rotates half the hidden dims of the input: [x1 | x2] -> [-x2 | x1]."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def _positions(table, start_pos, seqlen):
    """Return rows [start_pos, start_pos + seqlen) of ``table``.

    Raises:
        ValueError: if the requested positions fall outside the table. Plain
            slicing would silently return fewer rows and only fail later with
            an obscure shape error.
    """
    end = start_pos + seqlen
    if start_pos < 0 or end > table.size(0):
        raise ValueError(
            f"positions [{start_pos}, {end}) are outside a table of {table.size(0)} rows"
        )
    return table[start_pos:end]


def apply_rope_complex(query, rope_table, start_pos):
    """Legacy path: rotate interleaved pairs of ``query`` by complex multiplication.

    Args:
        query: 4-D real tensor with positions along dim 2 and an even-sized
            last dim (presumably [batch, heads, seq, head_dim] -- confirm with
            callers).
        rope_table: complex table [maxseqlen, head_dim] in the legacy layout,
            i.e. the per-pair rotations duplicated along dim 1; only the first
            half of each row is used.
        start_pos: position of ``query[:, :, 0]``.

    Returns:
        Real tensor with the same shape as ``query``.
    """
    seqlen = query.size(2)
    # Pair consecutive elements (x[2i], x[2i+1]) into complex numbers x[2i] + i*x[2i+1];
    # .contiguous() guarantees the unit last-dim stride view_as_complex requires.
    q = torch.view_as_complex(query.float().contiguous().reshape(*query.shape[:-1], -1, 2))
    rows = _positions(rope_table, start_pos, seqlen)
    freqs = rows[:, : rope_table.size(1) // 2].reshape(1, 1, seqlen, q.size(3))
    return torch.view_as_real(q * freqs).flatten(3)


def apply_rope_cos_sin(query, cos_emb, sin_emb, start_pos):
    """Same rotation as ``apply_rope_complex``, written with real cos/sin only.

    Args:
        query: 4-D tensor laid out as for ``apply_rope_complex``.
        cos_emb, sin_emb: [maxseqlen, head_dim // 2] tables holding cos/sin of
            the rotation angles.
        start_pos: position of ``query[:, :, 0]``.

    Returns:
        Tensor with the same shape as ``query``.
    """
    seqlen = query.size(2)
    pairs = query.reshape(*query.shape[:-1], -1, 2)
    even, odd = pairs[..., 0], pairs[..., 1]
    # [seqlen, head_dim // 2] slices broadcast over the two leading dims.
    cos_pos = _positions(cos_emb, start_pos, seqlen)
    sin_pos = _positions(sin_emb, start_pos, seqlen)
    # (even + i*odd) * (cos + i*sin), expanded into real and imaginary parts.
    rotated = torch.stack(
        (even * cos_pos - odd * sin_pos, even * sin_pos + odd * cos_pos), dim=-1
    )
    return rotated.flatten(3)


def apply_rope_rotate_half(query, cos_emb, sin_emb, start_pos):
    """Same rotation via the half-split form ``q * cos + rotate_half(q) * sin``.

    ``rotate_half`` pairs element i with element i + head_dim // 2, so the
    interleaved pairs are first gathered into [evens | odds] and the result is
    interleaved back, making the output directly comparable with the other paths.
    Arguments are as for ``apply_rope_cos_sin``.
    """
    seqlen = query.size(2)
    halves = torch.cat((query[..., 0::2], query[..., 1::2]), dim=-1)
    cos_pos = _positions(cos_emb, start_pos, seqlen)
    sin_pos = _positions(sin_emb, start_pos, seqlen)
    # Both members of a pair use the same angle, hence the duplicated tables.
    cos_full = torch.cat((cos_pos, cos_pos), dim=-1)
    sin_full = torch.cat((sin_pos, sin_pos), dim=-1)
    out = halves * cos_full + rotate_half(halves) * sin_full
    half = out.size(-1) // 2
    return torch.stack((out[..., :half], out[..., half:]), dim=-1).flatten(3)


inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
# Positions in the same float dtype, so the outer product needs no int/float promotion.
tmax = torch.arange(maxseqlen, device=inv_freq.device, dtype=inv_freq.dtype)
rope = torch.outer(tmax, inv_freq).float()  # rope is now matrix [maxseqlen, dim/2]
cos_emb = torch.cos(rope)
sin_emb = torch.sin(rope)

# Legacy table: unit complex numbers e^{i*angle}, duplicated along dim 1.
rope1 = torch.polar(torch.ones_like(rope), rope)
rope1 = torch.cat((rope1, rope1), dim=1)
print(rope1.size())  # [2048, 16]

start_pos = 4
torch.manual_seed(0)  # reproducible query, so the check below is deterministic
query = torch.randn(8, 4, 7, 16)

query_out = apply_rope_complex(query, rope1, start_pos)
query_out1 = apply_rope_cos_sin(query, cos_emb, sin_emb, start_pos)
query_out2 = apply_rope_rotate_half(query, cos_emb, sin_emb, start_pos)
print(query_out.size(), query_out1.size(), query_out2.size())  # each [8, 4, 7, 16]

# The complex product and the explicit cos/sin expansion may round differently
# by a few ulp, so compare with a small absolute tolerance.
same_cos_sin = torch.allclose(query_out, query_out1, atol=1e-5)
same_rotate_half = torch.allclose(query_out, query_out2, atol=1e-5)
print(same_cos_sin)
print(same_rotate_half)
assert same_cos_sin and same_rotate_half
Regarding the RoPE refactor, I'm posting the logic for the legacy implementation here: