apple / coremltools

Core ML tools contain supporting tools for Core ML model conversion, editing, and validation.
https://coremltools.readme.io
BSD 3-Clause "New" or "Revised" License

RoPE implementation: torch.polar op not available in MIL, requesting support #2258

Open nighting0le01 opened 4 months ago

nighting0le01 commented 4 months ago

🌱 Describe your Feature Request

srjoglekar246 commented 4 months ago

FWIW: In some of our internal experiments, we found that computing the RoPE sinusoids on the fly was inefficient (especially when using complex numbers, which the stack decomposes into real-valued tensors anyway), so we decided to precompute them in the torch Module's __init__:

```python
# emb: rotation angles (position * inverse frequency), shape [max_sequence_len, head_dim]
self.register_buffer("cos_cached", torch.tensor(np.cos(emb, dtype=np.float32)))
self.register_buffer("sin_cached", torch.tensor(np.sin(emb, dtype=np.float32)))
```
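The snippet leaves `emb` undefined. One way to build it is sketched below, assuming the common RoPE base of 10000 and the "half split" layout that the rotate_half helper in this comment expects (both are assumptions, not stated in the thread); `build_rope_cache` is a hypothetical helper name:

```python
import numpy as np
import torch


def build_rope_cache(max_sequence_len: int, head_dim: int, base: float = 10000.0):
    # Hypothetical helper. Assumes base 10000 and the "half split" layout:
    # frequency i drives both dim i and dim i + head_dim // 2.
    inv_freq = 1.0 / (base ** (np.arange(0, head_dim, 2, dtype=np.float64) / head_dim))
    positions = np.arange(max_sequence_len, dtype=np.float64)
    # Angle for every (position, frequency) pair: [max_sequence_len, head_dim // 2]
    freqs = np.outer(positions, inv_freq)
    # Duplicate across both halves to get emb: [max_sequence_len, head_dim]
    emb = np.concatenate([freqs, freqs], axis=-1)
    # Compute in float64, store as float32 tables
    cos_cached = torch.from_numpy(np.cos(emb).astype(np.float32))
    sin_cached = torch.from_numpy(np.sin(emb).astype(np.float32))
    return cos_cached, sin_cached


cos_cached, sin_cached = build_rope_cache(max_sequence_len=128, head_dim=64)
```

Inside a module's __init__, the two returned tensors would then be passed to self.register_buffer as in the snippet above, so the exported graph only sees constant lookup tables.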

and then apply them as follows:

```python
# token_indices: positions of the current tokens; x: [..., seq_len, head_dim]
cos, sin = self.cos_cached[token_indices], self.sin_cached[token_indices]
rope_embedded = x * cos + rotate_half(x) * sin
```

where rotate_half is something like:

```python
def rotate_half(x):
    # Split the last dim into halves [x1, x2] and return [-x2, x1]
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)
```
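To see that the precomputed cos/sin path computes the same rotation as a complex-number formulation, here is a self-contained sketch. It assumes base 10000 and treats the two halves of the last dimension as real and imaginary parts (the layout rotate_half implies); `CachedRotaryEmbedding` and `rope_complex_reference` are hypothetical names, and the reference uses torch.polar only for comparison in eager PyTorch:

```python
import torch
import torch.nn as nn


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # [x1, x2] -> [-x2, x1] along the last dimension
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def _inv_freq(head_dim: int, base: float) -> torch.Tensor:
    # One inverse frequency per rotated pair (assumed RoPE base)
    return 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float64) / head_dim))


class CachedRotaryEmbedding(nn.Module):
    # Precomputes cos/sin once, so forward() only uses real-valued
    # gather, mul, add, and concat ops (no complex numbers).
    def __init__(self, max_sequence_len: int, head_dim: int, base: float = 10000.0):
        super().__init__()
        positions = torch.arange(max_sequence_len, dtype=torch.float64)
        freqs = torch.outer(positions, _inv_freq(head_dim, base))
        emb = torch.cat((freqs, freqs), dim=-1)  # [max_sequence_len, head_dim]
        self.register_buffer("cos_cached", emb.cos().float())
        self.register_buffer("sin_cached", emb.sin().float())

    def forward(self, x: torch.Tensor, token_indices: torch.Tensor) -> torch.Tensor:
        # x: [..., seq_len, head_dim]; token_indices: [seq_len]
        cos, sin = self.cos_cached[token_indices], self.sin_cached[token_indices]
        return x * cos + rotate_half(x) * sin


def rope_complex_reference(x: torch.Tensor, token_indices: torch.Tensor, base: float = 10000.0):
    # Same rotation written with complex numbers via torch.polar
    head_dim = x.shape[-1]
    angles = torch.outer(token_indices.to(torch.float64), _inv_freq(head_dim, base)).float()
    rotation = torch.polar(torch.ones_like(angles), angles)  # e^{i * angle}
    z = torch.complex(x[..., : head_dim // 2], x[..., head_dim // 2 :])
    rotated = z * rotation
    return torch.cat((rotated.real, rotated.imag), dim=-1)


torch.manual_seed(0)
rope = CachedRotaryEmbedding(max_sequence_len=32, head_dim=8)
x = torch.randn(2, 4, 16, 8)
token_indices = torch.arange(16)
out = rope(x, token_indices)
```

The two agree because (x1 + i*x2)(cos + i*sin) = (x1*cos - x2*sin) + i*(x2*cos + x1*sin), which is exactly x * cos + rotate_half(x) * sin written half by half.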

@junpeiz since you have some context on the complex numbers stuff, how difficult would it be to support torch.polar for convenience?

junpeiz commented 4 months ago

Shouldn't be very difficult, as the complex result construction is straightforward: abs⋅cos(angle) + abs⋅sin(angle)⋅j.
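That decomposition can be checked in plain PyTorch. The sketch below is not coremltools converter code; `polar_decomposed` is a hypothetical helper that builds the complex result from two real-valued products, mirroring the formula above:

```python
import torch


def polar_decomposed(abs_: torch.Tensor, angle: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: real part = abs * cos(angle),
    # imaginary part = abs * sin(angle), combined into one complex tensor.
    return torch.complex(abs_ * torch.cos(angle), abs_ * torch.sin(angle))


torch.manual_seed(0)
abs_ = torch.rand(3, 5) + 0.5
angle = (torch.rand(3, 5) - 0.5) * 6.0
z = polar_decomposed(abs_, angle)
```

Because both parts are ordinary elementwise cos, sin, and mul ops, a converter-side translation would presumably only need to emit those and pair them as a complex value; the exact registration would follow the existing complex handling in the torch frontend rather than anything shown here.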

@nighting0le01 Thank you for filing this feature request! We will add it based on priority and bandwidth. Meanwhile, feel free to try implementing it on your side by following how the add op supports complex inputs: https://github.com/apple/coremltools/blob/0e292a072452db19d1e64b687a372c0c54704a90/coremltools/converters/mil/frontend/torch/ops.py#L917 (more examples can be found by searching for "complex" in that file).