google / heir

A compiler for homomorphic encryption
Apache License 2.0

Implement initial ReLU/sign function via polynomial approximation #658

Open j2kun opened 2 months ago

j2kun commented 2 months ago

Given that ReLU(x) = x (0.5 + 0.5 sgn(x)), this reduces to approximating the sign function, and this paper appears to have the state of the art:
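
The identity above can be sketched in a few lines of Python, separately from the paper's minimax construction. This is an illustrative assumption and not the paper's method. The sign function is approximated by composing the odd polynomial f(x) = (3x - x^3)/2 with itself, a simple choice used in the homomorphic-comparison literature. Inputs are assumed to be normalized to [-1, 1]. The names approx_sign and approx_relu are hypothetical.

```python
def sign_step(x: float) -> float:
    """One application of f(x) = (3x - x^3) / 2 (assumed building block).
    f maps (0, 1] into (0, 1] and pushes values toward 1; f is odd."""
    return (3.0 * x - x ** 3) / 2.0


def approx_sign(x: float, iterations: int = 10) -> float:
    """Composite polynomial f(f(...f(x))) approximating sgn(x) on [-1, 1].
    The composite has degree 3**iterations but is evaluated step by step."""
    for _ in range(iterations):
        x = sign_step(x)
    return x


def approx_relu(x: float, iterations: int = 10) -> float:
    """ReLU(x) = x * (0.5 + 0.5 * sgn(x)), with sgn replaced by approx_sign."""
    return x * (0.5 + 0.5 * approx_sign(x, iterations))


for x in (-0.5, -0.1, 0.1, 0.5):
    print(x, approx_relu(x))
```

Each application of f costs two nonscalar multiplications (x*x, then *x). Ten iterations therefore give a degree-3^10 composite at multiplicative depth 20. The sign approximation is weakest near 0, where f is nearly linear. Multiplying by x keeps the ReLU error below |x|/2 on [-1, 1].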

Also note

j2kun commented 2 months ago

Also cf. and which use these approximations.

j2kun commented 2 months ago

The paper linked above, 2020/834, does not release source code, but there is a reference implementation in Lattigo:

j2kun commented 2 months ago

The paper is a precursor to, but also seems to explain more of the motivation behind the composite polynomials.

j2kun commented 2 months ago

An example of generating a well-fitting polynomial using lolremez:

Another tool:
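
The basic idea of fitting a polynomial can be sketched with the Python standard library. This sketch does not reproduce the output of either tool, and it uses a different technique. lolremez runs the Remez exchange algorithm to find a minimax fit. This sketch uses Chebyshev interpolation instead, which is simpler and whose error is within a logarithmic factor of the minimax error. The fit is expressed in the Chebyshev basis and evaluated with Clenshaw's recurrence, which also bears on the basis question raised later in the thread. The function names are hypothetical.

```python
import math
from typing import Callable, List


def chebyshev_coefficients(f: Callable[[float], float], degree: int) -> List[float]:
    """Coefficients c_0..c_degree, in the Chebyshev basis T_0..T_degree, of the
    polynomial interpolating f at the Chebyshev points of the first kind on [-1, 1].
    (Interpolation, not a Remez/minimax fit.)"""
    n = degree + 1
    values = [f(math.cos(math.pi * (k + 0.5) / n)) for k in range(n)]
    coeffs = []
    for j in range(n):
        # Discrete orthogonality of T_j at these nodes gives each coefficient.
        s = sum(v * math.cos(math.pi * j * (k + 0.5) / n) for k, v in enumerate(values))
        coeffs.append(2.0 * s / n)
    coeffs[0] /= 2.0
    return coeffs


def chebyshev_eval(coeffs: List[float], x: float) -> float:
    """Evaluate sum_j coeffs[j] * T_j(x) with Clenshaw's recurrence."""
    b1 = b2 = 0.0
    for c in reversed(coeffs[1:]):
        b1, b2 = 2.0 * x * b1 - b2 + c, b1
    return coeffs[0] + x * b1 - b2


# Smooth stand-in for sgn(x) (an assumption for illustration).
tanh_fit = chebyshev_coefficients(lambda x: math.tanh(4 * x), 61)
```

The usage line fits tanh(4x) as a smooth stand-in for sgn(x). Fitting sgn itself this way converges slowly because of the jump at 0, where Gibbs oscillations persist near the discontinuity.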

j2kun commented 2 months ago

Outline sketch:

Various improvements based on more recent research that would be worth splitting into separate tickets.

j2kun commented 2 months ago

Thanks to Seonhong Min for sending me, which shows that BFV evaluates polynomial approximations using a fixed-point approximation rather than a floating-point one. There is also some extra complexity here: evaluating a fixed-point polynomial approximation requires a rounding step, though not always (see Sec. 2.5).
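
A plaintext-only Python model of fixed-point polynomial evaluation makes the rounding step concrete. This is an illustration under stated assumptions, not BFV. Values are plain integers scaled by a hypothetical SCALE = 2^16, and nothing is encrypted. horner_fixed rescales with rounding after every multiplication, so the scale stays fixed. horner_exact skips rounding, so the scale, and with it the integer size, grows by a factor of SCALE per multiplication.

```python
from typing import List, Tuple

SCALE_BITS = 16          # hypothetical fixed-point precision
SCALE = 1 << SCALE_BITS  # a real value v is stored as round(v * SCALE)


def to_fixed(x: float) -> int:
    return round(x * SCALE)


def rescale(v: int) -> int:
    """Divide a product at scale SCALE**2 back to scale SCALE, rounding to
    nearest (ties toward +inf). Python's >> on ints is floor division."""
    return (v + SCALE // 2) >> SCALE_BITS


def horner_fixed(coeffs: List[float], x: float) -> float:
    """Evaluate sum(coeffs[i] * x**i) in fixed point, rounding after every
    multiplication so every intermediate stays at scale SCALE."""
    xf = to_fixed(x)
    acc = to_fixed(coeffs[-1])
    for c in reversed(coeffs[:-1]):
        acc = rescale(acc * xf) + to_fixed(c)
    return acc / SCALE


def horner_exact(coeffs: List[float], x: float) -> Tuple[int, int]:
    """Same evaluation without rounding. Returns (numerator, scale); the scale
    gains a factor of SCALE per multiplication, so the integers keep growing."""
    xf = to_fixed(x)
    acc = to_fixed(coeffs[-1])
    scale = SCALE
    for c in reversed(coeffs[:-1]):
        # Lift the constant c to the product's scale before adding it.
        acc = acc * xf + to_fixed(c) * scale
        scale *= SCALE
    return acc, scale


SIGN_STEP = [0.0, 1.5, 0.0, -0.5]  # (3x - x^3) / 2, lowest degree first
```

On a BFV ciphertext, the rescale-with-rounding in horner_fixed is not a cheap native operation. horner_exact avoids it, but it would need a plaintext modulus large enough to hold the growing scale.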


Maokami commented 2 months ago

I was also interested in this issue a while back, so I know of a paper worth sharing: it's a follow-up paper by the authors of, and there's an implementation: In the repo, you'll find that they've hardcoded the coefficients of polynomial approximations of the sign function for alpha = 4 to 14!

Maokami commented 2 months ago

> The paper linked above, 2020/834, does not release source code, but there is a reference implementation in Lattigo:

Oh, I didn't know there was an implementation in Lattigo! In that case, these hardcoded coefficients might not be necessary.

j2kun commented 2 months ago

After starting an implementation in (with a fixed approximation polynomial) and discussing in the HEIR meeting today, I have a few kinks to work out:

  1. I want to separate the task of choosing a polynomial approximation from the optimizations around evaluating it. This implies:
    1. I need a floating-point representation of a polynomial in the IR, but PolynomialAttr currently only supports integer coefficients
    2. I need a new op, say poly_ext.eval whose operands are the polynomial to apply and its input
    3. The approximation itself is for sign, but these operations are actually applied to max(0, x) = (x + x * sign(x)) / 2, which means we should support some internal polynomial arithmetic to construct these from the checked-in approximation. We meant to do this to support constant folding in the polynomial dialect, but never got around to it.
  2. (1) has a twist: many of the more advanced polynomial approximations are represented not literally but implicitly, as a composition of lower-degree polynomials. This implies I will need a polynomial.compose op, or else an attribute that supports composite sub-polynomials, and I cannot limit (1.ii) above to a single static polynomial. I think I will start with a single static polynomial but try to avoid making it difficult to upgrade to a composite polynomial.
  3. The approximate polynomial itself has a few quirks, because its coefficients further need to be encoded in a scheme-specific fashion. For CKKS this is relatively straightforward, but introduces additional error. For BGV/BFV this seems much harder, in part because the encodings are fixed-point and hence require rounding during polynomial evaluation, but rounding itself is hard (see above). There is also a question about which basis the polynomial is expressed in, cf. for more on this
  4. The above points expose a problem with "lowering a ReLU": at the tosa level we don't yet know what scheme will be chosen, so the choice of polynomial approximation can't be scheme-aware or encoding-aware. I think the right solution here will be to include some extra metadata on the polynomial to express what function is being approximated, so that we can re-approximate it at lower levels if necessary.
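
Points 1.iii and 2 can be sketched with a small float-coefficient polynomial type. This is a hypothetical Python illustration, not HEIR's PolynomialAttr or the proposed poly_ext.eval / polynomial.compose ops. FloatPoly, relu_from_sign, and eval_composite are invented names. SIGN_STEP = (3x - x^3)/2 is a placeholder for a real checked-in sign approximation. The sketch derives the ReLU polynomial from the sign polynomial via max(0, x) = (x + x * sign(x)) / 2. It then contrasts expanding a composition into one dense polynomial with keeping it as a list of stages.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass(frozen=True)
class FloatPoly:
    """Hypothetical dense polynomial with float coefficients, lowest degree first."""
    coeffs: Tuple[float, ...]

    def __add__(self, other: "FloatPoly") -> "FloatPoly":
        n = max(len(self.coeffs), len(other.coeffs))
        a = self.coeffs + (0.0,) * (n - len(self.coeffs))
        b = other.coeffs + (0.0,) * (n - len(other.coeffs))
        return FloatPoly(tuple(x + y for x, y in zip(a, b)))

    def __mul__(self, other: "FloatPoly") -> "FloatPoly":
        out = [0.0] * (len(self.coeffs) + len(other.coeffs) - 1)
        for i, a in enumerate(self.coeffs):
            for j, b in enumerate(other.coeffs):
                out[i + j] += a * b
        return FloatPoly(tuple(out))

    def scale(self, s: float) -> "FloatPoly":
        return FloatPoly(tuple(s * c for c in self.coeffs))

    def compose(self, inner: "FloatPoly") -> "FloatPoly":
        """Expand self(inner(x)) into one dense polynomial (Horner's rule)."""
        result = FloatPoly((self.coeffs[-1],))
        for c in reversed(self.coeffs[:-1]):
            result = result * inner + FloatPoly((c,))
        return result

    def __call__(self, x: float) -> float:
        acc = 0.0
        for c in reversed(self.coeffs):
            acc = acc * x + c
        return acc


X = FloatPoly((0.0, 1.0))
SIGN_STEP = FloatPoly((0.0, 1.5, 0.0, -0.5))  # placeholder: (3x - x^3) / 2


def relu_from_sign(sign: FloatPoly) -> FloatPoly:
    """max(0, x) = (x + x * sign(x)) / 2, built with polynomial arithmetic."""
    return (X + X * sign).scale(0.5)


def eval_composite(stages: List[FloatPoly], x: float) -> float:
    """Evaluate stages[-1](...stages[0](x)) without expanding the composition."""
    for p in stages:
        x = p(x)
    return x


sign2 = SIGN_STEP.compose(SIGN_STEP)  # expanded: degree 9
relu2 = relu_from_sign(sign2)         # degree 10
```

Expanding SIGN_STEP.compose(SIGN_STEP) gives a degree-9 polynomial, and each further composition multiplies the degree by 3. The staged form keeps each step at degree 3, which is the representation trade-off behind point 2.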

j2kun commented 2 months ago

These folks do something slightly different, which is more holistic in re NN training:

They pick small-degree polynomial approximations and then fine-tune the network in various ways to adapt it to the replaced operations. This seems out of scope for the compiler, since it would require including training data.

j2kun commented 1 month ago

I added an upstream RFC for the polynomial approximation pass.