j2kun opened this issue 2 months ago (status: Open)
Also cf. https://openreview.net/pdf?id=Hq16Jk2bVlp and https://eprint.iacr.org/2021/1688 which use these approximations.
The paper linked above 2020/834 does not release source code, but there is a reference implementation in LattiGo: https://github.com/tuneinsight/lattigo/blob/4cce9a48c1daaa2dd122921822f5ad70cd444156/he/hefloat/minimax_composite_polynomial.go#L124
The paper https://eprint.iacr.org/2019/1234 is a precursor to https://eprint.iacr.org/2020/834, and it seems to explain more of the motivation behind the composite polynomials.
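The composite idea (composing several low-degree polynomials so that the composition approximates sign) can be illustrated with a toy example. This is *not* the minimax polynomials from the papers above, just the classic Newton–Schulz iteration `g(x) = (3x - x^3)/2`, whose repeated self-composition converges to `sign(x)` for `0 < |x| < sqrt(3)`:

```python
# Toy illustration of the composite approach: compose a single low-degree
# polynomial with itself. g is the Newton-Schulz iteration for sign(x),
# NOT the minimax sub-polynomials from eprint 2019/1234 / 2020/834.
def g(x):
    return (3 * x - x**3) / 2

def sign_approx(x, depth=10):
    """Approximate sign(x) by composing g with itself `depth` times."""
    for _ in range(depth):
        x = g(x)
    return x

# Inputs near zero get pushed toward +/-1; the composition has total
# degree 3^depth but costs only `depth` cheap evaluations.
print(sign_approx(0.05), sign_approx(-0.3))
```

This is why the composite framing matters for FHE: a single minimax polynomial of comparable accuracy would need a far higher degree, while a composition of small polynomials keeps the multiplicative depth per stage low.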
An example of generating a well-fitting polynomial using lolremez: https://github.com/samhocevar/lolremez/issues/28#issuecomment-1913324892
Another tool: https://github.com/pychebfun/pychebfun
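For a sense of what these tools produce, here is a rough stand-in using plain numpy (not pychebfun or lolremez): a least-squares Chebyshev fit of `sign(x)` on `[-1, -0.25] ∪ [0.25, 1]`, i.e. excluding a dead zone around the discontinuity at 0, as the minimax constructions also do:

```python
import numpy as np

# Least-squares Chebyshev fit of sign(x) away from the discontinuity.
# A proper minimax (Remez) fit from lolremez/pychebfun would do better;
# this just shows the shape of the problem.
xs = np.concatenate([np.linspace(-1, -0.25, 500), np.linspace(0.25, 1, 500)])
ys = np.sign(xs)
coeffs = np.polynomial.chebyshev.chebfit(xs, ys, deg=15)

def approx(x):
    return np.polynomial.chebyshev.chebval(x, coeffs)

max_err = np.max(np.abs(approx(xs) - ys))
print("max error on fitted region:", max_err)
```

Shrinking the dead zone or tightening the error forces the degree up quickly, which is exactly the cost the composite-polynomial papers are attacking.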
Outline sketch: various improvements based on more recent research that would be worth splitting into separate tickets.
Thanks to Seonhong Min for sending me https://eprint.iacr.org/2018/462, which shows that BFV achieves polynomial approximations via a fixed-point approximation rather than a floating-point one. There is some added complexity there: evaluating a fixed-point polynomial approximation also requires a rounding step, though not always, see sec 2.5.
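To make the fixed-point-plus-rounding point concrete, here is a hedged sketch (in the spirit of that discussion, not the paper's actual scheme): coefficients and the input are scaled to integers by `2^f`, and each multiplication is followed by a rescaling/rounding step to bring the scale back down to `2^f`:

```python
# Sketch of fixed-point polynomial evaluation with an explicit rounding
# step after each multiply. Names and the scale choice are illustrative.
F = 16            # fractional bits
SCALE = 1 << F

def to_fixed(x):
    return round(x * SCALE)

def fixed_mul(a, b):
    # The product has scale 2^(2F); round back down to scale 2^F.
    return (a * b + (1 << (F - 1))) >> F

def eval_fixed_poly(coeffs, x):
    """Horner evaluation of a polynomial given float coeffs and float x."""
    acc = to_fixed(coeffs[-1])
    xf = to_fixed(x)
    for c in reversed(coeffs[:-1]):
        acc = fixed_mul(acc, xf) + to_fixed(c)
    return acc / SCALE

# e.g. p(x) = 0.5 + 0.5x, the mask from the ReLU identity with sign
# replaced by the identity for demonstration.
print(eval_fixed_poly([0.5, 0.5], 0.25))
```

The `>> F` in `fixed_mul` is the rounding step; whether it is needed at every multiply depends on how much headroom the integer representation has, which matches the "but not always" caveat.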
I also had an interest in this issue a while back, so I know a paper worth sharing: https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=10155408. It's a follow-up paper by the authors of https://eprint.iacr.org/2020/834, and there's an implementation: https://github.com/snu-ccl/approxCNN/tree/main. In the repo, you'll find that they've hardcoded the coefficients of polynomial approximations of the sign function for alpha = 4 to 14!
> The paper linked above 2020/834 does not release source code, but there is a reference implementation in LattiGo: https://github.com/tuneinsight/lattigo/blob/4cce9a48c1daaa2dd122921822f5ad70cd444156/he/hefloat/minimax_composite_polynomial.go#L124
Oh, I didn't know there was an implementation in Lattigo! In that case, these hardcoded coefficients might not be necessary.
After starting an implementation in https://github.com/google/heir/pull/665 (with a fixed approximation polynomial) and discussing in the HEIR meeting today, I have a few kinks to work out:

- The pass lowers to a `poly_ext.eval` op whose operands are the polynomial to apply and its input.
- ReLU is built from the sign approximation via `max(0, x) = (x + x * sign(x)) / 2`, which means we should support some internal polynomial arithmetic to construct these from the checked-in approximation. We meant to do this to support constant folding in the polynomial dialect, but never got around to it.
- The composite approximations require a `polynomial.compose` op, or else an attribute that supports composite sub-polynomials, and cannot limit (1.ii) above to a single static polynomial. I think I will start with a single static polynomial but try to avoid making it difficult to upgrade to a composite polynomial.

These folks do something slightly different, which is more holistic in re NN training: https://github.com/EfficientFHE/SmartPAF
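The "internal polynomial arithmetic" point above can be sketched with numpy's `Polynomial` class: given a checked-in sign approximation `p_sign`, the ReLU approximation `(x + x * p_sign(x)) / 2` is itself a polynomial the compiler could constant-fold ahead of time. The names and the toy cubic here are illustrative, not HEIR's actual API or coefficients:

```python
from numpy.polynomial import Polynomial

# Illustrative stand-in for a checked-in sign approximation; a real pass
# would load minimax coefficients instead of this toy cubic.
p_sign = Polynomial([0, 1.5, 0, -0.5])   # 1.5x - 0.5x^3

x = Polynomial([0, 1])                    # the identity polynomial

# ReLU(x) ~= (x + x * sign(x)) / 2, folded into a single polynomial
# via polynomial addition, multiplication, and scalar division.
p_relu = (x + x * p_sign) / 2

# Sanity check against direct evaluation of the unfolded expression.
t = 0.7
print(p_relu(t), (t + t * p_sign(t)) / 2)
```

For the composite case the same arithmetic is not enough: folding `p1(p2(x))` into one dense polynomial would blow up the degree, which is why a dedicated compose op (or a composite-polynomial attribute) seems necessary rather than eager folding.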
They pick small degree polynomial approximations and then do a variety of network fine-tuning to adapt the network to the replaced operations. This seems out of scope of the compiler, since it would require training data to be included.
I added an upstream RFC for the polynomial approximation pass https://discourse.llvm.org/t/rfc-a-polynomial-approximation-pass/79301
Given that `ReLU(x) = x * (0.5 + 0.5 * sgn(x))`, this reduces to approximating the sign function, and this paper appears to have the state of the art: https://eprint.iacr.org/2020/834

Also note:

`max(u, v) = ((u+v) + (u-v) * sign(u-v)) / 2`

`min(u, v) = -max(-u, -v) = ((u+v) - (v-u) * sign(v-u)) / 2`
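These identities are easy to check numerically, first with exact sign and then with sign swapped for an approximation (here a crude `tanh` surrogate, not the minimax polynomial an FHE backend would actually use):

```python
import math

def approx_max(u, v, sign):
    return ((u + v) + (u - v) * sign(u - v)) / 2

def approx_min(u, v, sign):
    return ((u + v) - (v - u) * sign(v - u)) / 2

def exact_sign(x):
    return math.copysign(1.0, x) if x != 0 else 0.0

def soft_sign(x):
    return math.tanh(50 * x)   # smooth stand-in for sign

u, v = 0.3, -0.8
print(approx_max(u, v, exact_sign))   # recovers max(u, v)
print(approx_max(u, v, soft_sign))    # close; error comes from the surrogate
```

With the exact sign the identities hold exactly; with an approximation the error of the min/max is controlled by the sign error scaled by `|u - v|`, which is why the papers fix the approximation interval and dead zone up front.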