nighting0le01 opened 4 months ago
FWIW: in some of our internal experiments we found that computing the RoPE sinusoids on the fly wasn't optimal (especially via complex numbers, which the stack decomposes into real tensors anyway), so we pre-compute them in the torch `Module`'s `__init__`:

```python
# shape: [max_sequence_len, head_dim]
self.register_buffer("cos_cached", torch.tensor(np.cos(emb, dtype=np.float32)))
self.register_buffer("sin_cached", torch.tensor(np.sin(emb, dtype=np.float32)))
```
and then apply them as follows:

```python
cos, sin = self.cos_cached[token_indices], self.sin_cached[token_indices]
rope_embedded = x * cos + rotate_half(x) * sin
```
where `rotate_half` is something like:

```python
def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.concat((-x2, x1), dim=-1)
```
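For reference, the snippets above can be assembled into a minimal runnable module. The frequency computation (`inv_freq`, `base=10000.0`, the `outer`/`concatenate` steps) is not shown in this thread, so the version below assumes the standard RoPE recipe; only the buffer registration and the apply step come from the comment itself.

```python
import numpy as np
import torch


def rotate_half(x):
    # Split the last dim in half, swap the halves, and negate the second half.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.concat((-x2, x1), dim=-1)


class RotaryEmbedding(torch.nn.Module):
    def __init__(self, head_dim, max_sequence_len, base=10000.0):
        super().__init__()
        # Assumed standard RoPE frequencies (not part of the original comment).
        inv_freq = 1.0 / (base ** (np.arange(0, head_dim, 2) / head_dim))
        t = np.arange(max_sequence_len)
        freqs = np.outer(t, inv_freq)                  # [max_sequence_len, head_dim // 2]
        emb = np.concatenate((freqs, freqs), axis=-1)  # [max_sequence_len, head_dim]
        # Pre-computed once here instead of per forward pass.
        self.register_buffer("cos_cached", torch.tensor(np.cos(emb, dtype=np.float32)))
        self.register_buffer("sin_cached", torch.tensor(np.sin(emb, dtype=np.float32)))

    def forward(self, x, token_indices):
        cos, sin = self.cos_cached[token_indices], self.sin_cached[token_indices]
        return x * cos + rotate_half(x) * sin
```

At position 0 the angles are all zero (`cos = 1`, `sin = 0`), so the rotation is the identity there, which is a quick sanity check.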
@junpeiz since you have some context on the complex-numbers support, how difficult would it be to support `torch.polar` for convenience?
Shouldn't be very difficult, as constructing the complex result is straightforward: abs⋅cos(angle) + abs⋅sin(angle)⋅j.
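That decomposition can be checked directly in PyTorch; the snippet below just verifies that `torch.polar(abs, angle)` equals the real/imag pair a converter would emit:

```python
import torch

abs_t = torch.tensor([1.0, 2.0])
angle = torch.tensor([0.0, 3.14159265 / 2])

# torch.polar builds abs * exp(j * angle) as a complex tensor.
z = torch.polar(abs_t, angle)

# The real-valued decomposition: abs*cos(angle) + abs*sin(angle)*j.
real = abs_t * torch.cos(angle)
imag = abs_t * torch.sin(angle)
```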
@nighting0le01 Thank you for filing this feature request! We will add it based on priority and bandwidth. Meanwhile, feel free to try it on your side by following how `add` supports complex numbers: https://github.com/apple/coremltools/blob/0e292a072452db19d1e64b687a372c0c54704a90/coremltools/converters/mil/frontend/torch/ops.py#L917
(More examples can be found by searching for `complex` in that file.)
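For anyone trying the workaround, a useful sanity check is that the complex `torch.polar` path and the real-valued cos/sin path from earlier in the thread agree. The sketch below assumes the half-split RoPE convention used above (interleaved variants differ by a permutation of the last dimension):

```python
import torch


def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.concat((-x2, x1), dim=-1)


torch.manual_seed(0)
head_dim = 8
x = torch.randn(4, head_dim)
angle = torch.randn(4, head_dim // 2)  # one angle per (x1, x2) pair

# Complex path: pair the halves into a complex tensor and multiply by the
# unit-modulus rotation built with torch.polar.
x_c = torch.complex(x[..., : head_dim // 2], x[..., head_dim // 2 :])
rot = torch.polar(torch.ones_like(angle), angle)
out_c = x_c * rot
out_complex = torch.cat((out_c.real, out_c.imag), dim=-1)

# Real path: the cos/sin + rotate_half formulation from the thread.
emb = torch.cat((angle, angle), dim=-1)
out_real = x * torch.cos(emb) + rotate_half(x) * torch.sin(emb)
```

Expanding `(x1 + j·x2)(cos + j·sin)` gives `(x1·cos − x2·sin) + j·(x2·cos + x1·sin)`, which is exactly what the `rotate_half` formulation produces half by half.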
🌱 Describe your Feature Request

How can this feature be used?

In Stable Diffusion and LLMs that use RoPE.