google / heir

A compiler for homomorphic encryption
https://heir.dev/
Apache License 2.0

Implement initial ReLU/sign function via polynomial approximation #658

Open j2kun opened 2 months ago

j2kun commented 2 months ago

Since ReLU(x) = x (0.5 + 0.5 sgn(x)), the problem reduces to approximating the sign function, and this paper appears to be the state of the art: https://eprint.iacr.org/2020/834
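The identity can be checked with a small plaintext sketch. The helper names `exact_sign` and `relu_via_sign` are hypothetical and are not HEIR APIs. Any callable that approximates sgn on the input range can be passed as `sign_fn`. With the exact sign function, the identity reproduces max(0, x).

```python
def exact_sign(x: float) -> float:
    """Exact sign function: returns -1, 0, or 1."""
    return (x > 0) - (x < 0)


def relu_via_sign(x: float, sign_fn=exact_sign) -> float:
    """ReLU(x) = x * (0.5 + 0.5 * sgn(x)).

    sign_fn is any function approximating sgn (hypothetical hook for a polynomial).
    """
    return x * (0.5 + 0.5 * sign_fn(x))


if __name__ == "__main__":
    for x in (-2.0, -0.5, 0.0, 0.75, 3.0):
        print(x, relu_via_sign(x), max(0.0, x))
```

An error of e in the sign approximation becomes an error of |x| * e / 2 in ReLU. This is why the input range matters when choosing the approximation.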

Also note

j2kun commented 2 months ago

Also cf. https://openreview.net/pdf?id=Hq16Jk2bVlp and https://eprint.iacr.org/2021/1688 which use these approximations.

j2kun commented 2 months ago

The paper linked above 2020/834 does not release source code, but there is a reference implementation in Lattigo: https://github.com/tuneinsight/lattigo/blob/4cce9a48c1daaa2dd122921822f5ad70cd444156/he/hefloat/minimax_composite_polynomial.go#L124

j2kun commented 2 months ago

The paper https://eprint.iacr.org/2019/1234 is a precursor to https://eprint.iacr.org/2020/834, and it seems to explain more of the motivation behind composite polynomials.
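The appeal of composite polynomials can be illustrated with a plaintext sketch that composes the odd cubic f(x) = (3x - x^3)/2 with itself. This cubic is an assumption, chosen to stand in for the low-degree component polynomials in 2019/1234. It is not the optimized minimax composite from 2020/834 or Lattigo. The names `f1` and `composite_sign` are hypothetical.

```python
def f1(x: float) -> float:
    """Odd cubic (3x - x^3)/2: maps [-1, 1] into [-1, 1] and pushes values toward +/-1."""
    return (3.0 * x - x ** 3) / 2.0


def composite_sign(x: float, depth: int) -> float:
    """Approximate sgn(x) on [-1, 1] by composing f1 with itself `depth` times.

    The composite has degree 3**depth, but each level costs only two
    nonscalar multiplications (x*x, then (x*x)*x).
    """
    for _ in range(depth):
        x = f1(x)
    return x


if __name__ == "__main__":
    for d in (2, 4, 6, 8):
        print(d, 3 ** d, composite_sign(0.1, d), composite_sign(0.5, d))
```

With composition, the degree grows exponentially while the multiplication count grows only linearly. Inputs close to 0 converge slowly, because f1 multiplies small values by roughly 1.5 per level. This is why such methods leave a small interval around 0 where the approximation is allowed to be wrong.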

j2kun commented 2 months ago

An example of generating a well-fitting polynomial using lolremez: https://github.com/samhocevar/lolremez/issues/28#issuecomment-1913324892

Another tool: https://github.com/pychebfun/pychebfun
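lolremez computes minimax (Remez) approximations. A simpler technique is Chebyshev interpolation, which is the idea behind Chebfun-style tools such as pychebfun. The pure-Python sketch below uses it as a stand-in. It is neither the Remez algorithm nor either tool's API, and `cheb_coeffs` and `cheb_eval` are hypothetical names.

```python
import math


def cheb_coeffs(f, n: int) -> list:
    """Coefficients c_0..c_n of the degree-n interpolant of f at n+1 Chebyshev nodes on [-1, 1]."""
    m = n + 1
    nodes = [math.cos(math.pi * (j + 0.5) / m) for j in range(m)]
    values = [f(x) for x in nodes]
    coeffs = []
    for k in range(m):
        # Discrete orthogonality of T_k at the Chebyshev nodes of the first kind.
        s = sum(v * math.cos(math.pi * k * (j + 0.5) / m) for j, v in enumerate(values))
        coeffs.append(2.0 * s / m)
    coeffs[0] /= 2.0
    return coeffs


def cheb_eval(coeffs: list, x: float) -> float:
    """Evaluate sum_k c_k * T_k(x) with Clenshaw's recurrence."""
    b1 = b2 = 0.0
    for c in reversed(coeffs[1:]):
        b1, b2 = 2.0 * x * b1 - b2 + c, b1
    return x * b1 - b2 + coeffs[0]


if __name__ == "__main__":
    sign = lambda x: (x > 0) - (x < 0)
    c = cheb_coeffs(sign, 63)
    for x in (0.05, 0.5, 0.9):
        print(x, cheb_eval(c, x))
```

For the discontinuous sign function, the error stays large near 0 and shrinks away from it. This is the same reason the approximations above are defined only outside a small gap around 0.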

j2kun commented 2 months ago

Outline sketch:

Various improvements based on more recent research that would be worth splitting into separate tickets.

j2kun commented 2 months ago

Thanks to Seonhong Min for sending me https://eprint.iacr.org/2018/462, which shows that BFV achieves polynomial approximations via fixed-point rather than floating-point arithmetic. I think this adds some complexity, since evaluating a fixed-point polynomial approximation also requires a rounding step, though not always (see Sec. 2.5).
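The rounding step can be made concrete with a plaintext-only sketch of fixed-point Horner evaluation. Every value is an integer scaled by 2**scale_bits, and each product must be divided by the scale and rounded back. The scale, the round-half-up rule, and the function names are illustrative assumptions. This is not how BFV or any library implements it. On ciphertexts, this rounding is exactly the hard part.

```python
def to_fixed(x: float, scale_bits: int) -> int:
    """Encode a real number as an integer with scale 2**scale_bits."""
    return round(x * (1 << scale_bits))


def round_div_pow2(a: int, scale_bits: int) -> int:
    """Divide by 2**scale_bits, rounding to nearest (ties toward +infinity)."""
    return (a + (1 << (scale_bits - 1))) >> scale_bits


def fixed_point_horner(coeffs_high_to_low: list, x: float, scale_bits: int = 16) -> int:
    """Evaluate a polynomial at x using only integer arithmetic at scale 2**scale_bits.

    acc * x_fixed has scale 2**(2*scale_bits), so it is rescaled (with rounding)
    before the next coefficient is added. Returns the fixed-point result.
    """
    x_fixed = to_fixed(x, scale_bits)
    acc = to_fixed(coeffs_high_to_low[0], scale_bits)
    for c in coeffs_high_to_low[1:]:
        acc = round_div_pow2(acc * x_fixed, scale_bits) + to_fixed(c, scale_bits)
    return acc


if __name__ == "__main__":
    # (3x - x^3)/2, coefficients from highest to lowest degree.
    coeffs = [-0.5, 0.0, 1.5, 0.0]
    for x in (0.3, 0.5):
        print(x, fixed_point_horner(coeffs, x) / 2 ** 16, 1.5 * x - 0.5 * x ** 3)
```

Each rescale introduces up to half a unit of rounding error at the working scale. The accumulated error therefore grows with the degree, in addition to any coefficient-encoding error.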


Maokami commented 2 months ago

I was also interested in this issue a while back, so here is a paper worth sharing: https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=10155408 It's a follow-up by the authors of https://eprint.iacr.org/2020/834, and there's an implementation: https://github.com/snu-ccl/approxCNN/tree/main. In the repo, you'll find that they've hardcoded the coefficients of the polynomial approximations of the sign function for alpha=4 to 14!

Maokami commented 2 months ago

Quoting j2kun: The paper linked above 2020/834 does not release source code, but there is a reference implementation in Lattigo: https://github.com/tuneinsight/lattigo/blob/4cce9a48c1daaa2dd122921822f5ad70cd444156/he/hefloat/minimax_composite_polynomial.go#L124

Oh, I didn't know there was an implementation in Lattigo! In that case, these hardcoded coefficients might not be necessary.

j2kun commented 2 months ago

After starting an implementation in https://github.com/google/heir/pull/665 (with a fixed approximation polynomial) and discussing in the HEIR meeting today, I have a few kinks to work out:

  1. I want to separate the task of choosing a polynomial approximation from the optimizations around evaluating it. This implies:
    i. I need a floating-point representation of a polynomial in the IR, but PolynomialAttr currently only supports integer coefficients.
    ii. I need a new op, say poly_ext.eval, whose operands are the polynomial to apply and its input.
    iii. The approximation itself is for sign, but these operations are actually applied to max(0, x) = (x + x * sign(x)) / 2, which means we should support some internal polynomial arithmetic to construct these from the checked-in approximation. We meant to do this to support constant folding in the polynomial dialect, but never got around to it.
  2. (1) has a twist: many of the more advanced polynomial approximations are represented not literally but implicitly, as a composition of lower-degree polynomials. This implies I will need a polynomial.compose op, or else an attribute that supports composite sub-polynomials, and I cannot limit (1.ii) above to a single static polynomial. I will start with a single static polynomial but try not to make it difficult to upgrade to a composite polynomial.
  3. The approximating polynomial itself has a few quirks, because its coefficients must also be encoded in a scheme-specific way. For CKKS this is relatively straightforward, but it introduces additional error. For BGV/BFV this seems much harder, in part because the encodings are fixed-point and hence require rounding during polynomial evaluation, and rounding itself is hard (see above). There is also the question of which basis the polynomial is expressed in; cf. https://discord.com/channels/901152454077452399/1235349479482196049 for more on this.
  4. The above points expose a problem with "lowering a ReLU": at the TOSA level we don't yet know which scheme will be chosen, so the choice of polynomial approximation can't be scheme-aware or encoding-aware. I think the right solution is to attach extra metadata to the polynomial stating which function it approximates, so that we can re-approximate it at lower levels if necessary.

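Points 1 and 2 can be sketched in plaintext Python. The sketch defines a float-coefficient polynomial with addition, multiplication and composition. It uses them to build the ReLU polynomial from a sign approximation. An `approximates` field stands in for the metadata from point 4. `FloatPoly`, `relu_from_sign`, and the cubic sign approximation are hypothetical illustrations, not HEIR's PolynomialAttr or any proposed op.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class FloatPoly:
    """Hypothetical float-coefficient polynomial; coeffs[i] multiplies x**i."""
    coeffs: tuple
    approximates: Optional[str] = None  # metadata: which function this stands in for

    def __add__(self, other: FloatPoly) -> FloatPoly:
        n = max(len(self.coeffs), len(other.coeffs))
        a = self.coeffs + (0.0,) * (n - len(self.coeffs))
        b = other.coeffs + (0.0,) * (n - len(other.coeffs))
        return FloatPoly(tuple(x + y for x, y in zip(a, b)))

    def __mul__(self, other: FloatPoly) -> FloatPoly:
        out = [0.0] * (len(self.coeffs) + len(other.coeffs) - 1)
        for i, a in enumerate(self.coeffs):
            for j, b in enumerate(other.coeffs):
                out[i + j] += a * b
        return FloatPoly(tuple(out))

    def scale(self, s: float) -> FloatPoly:
        return FloatPoly(tuple(s * c for c in self.coeffs))

    def compose(self, inner: FloatPoly) -> FloatPoly:
        """Return self(inner(x)) by running Horner's rule on polynomials."""
        result = FloatPoly((self.coeffs[-1],))
        for c in reversed(self.coeffs[:-1]):
            result = result * inner + FloatPoly((c,))
        return result

    def __call__(self, x: float) -> float:
        acc = 0.0
        for c in reversed(self.coeffs):
            acc = acc * x + c
        return acc


X = FloatPoly((0.0, 1.0))


def relu_from_sign(sign_poly: FloatPoly) -> FloatPoly:
    """max(0, x) = (x + x * sign(x)) / 2, built with polynomial arithmetic."""
    p = (X + X * sign_poly).scale(0.5)
    return FloatPoly(p.coeffs, approximates="relu")


# Hypothetical checked-in sign approximation: (3x - x^3)/2, composed once (degree 9).
sign_f1 = FloatPoly((0.0, 1.5, 0.0, -0.5), approximates="sign")
sign_deg9 = FloatPoly(sign_f1.compose(sign_f1).coeffs, approximates="sign")
relu_deg10 = relu_from_sign(sign_deg9)

if __name__ == "__main__":
    print(len(relu_deg10.coeffs) - 1, relu_deg10.approximates, relu_deg10(0.5))
```

Expanding a composite into monomial form, as `compose` does here, is what point 2 suggests avoiding for deep compositions. The degree grows as 3**depth, whereas evaluating the composite level by level keeps the multiplication count linear in the depth.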
j2kun commented 2 months ago

These folks do something slightly different, which is more holistic with respect to NN training: https://github.com/EfficientFHE/SmartPAF

They pick low-degree polynomial approximations and then apply a variety of fine-tuning techniques to adapt the network to the replaced operations. This seems out of scope for the compiler, since it would require including training data.

j2kun commented 1 month ago

I added an upstream RFC for the polynomial approximation pass https://discourse.llvm.org/t/rfc-a-polynomial-approximation-pass/79301