google / heir

A compiler for homomorphic encryption
https://heir.dev/
Apache License 2.0

Approximations of non-polynomial functions with polynomials #266

Open j2kun opened 9 months ago

j2kun commented 9 months ago

I've been reading through https://arxiv.org/abs/2311.08610 and wanted to highlight some of the polynomial approximation techniques that we could incorporate into HEIR when the time is right:

2.3 Polynomial Approximation

Polynomial networks are commonly obtained by approximating the non-polynomial functions of pretrained networks, e.g., (Lee et al., 2022d; Takabi et al., 2019; Mohassel and Zhang, 2017; Hesamifard et al., 2017; Lee et al., 2021), or by substituting ReLU during or after a dedicated training process (Baruch et al., 2022, 2023). The Remez algorithm (Remez, 1934; Pachón and Trefethen, 2009; Egidi et al., 2020) is commonly used for finding the optimal polynomial approximation of a function f(x) in a certain degree within a predefined range [a, b], assuming a uniform distribution of x. Alternatively, iterative methods such as the Newton–Raphson method (Raphson, 1702) offer another polynomial approximation approach. Specifically, (Panda, 2022) focused on approximating √(1/x) in the interval [a, b], by dividing the interval into sub-intervals and approximating over each via Newton–Raphson method, before aggregating the results with another polynomial. However, the input range to the non-polynomial layers [a, b] can be extremely large, which results in poor and no practical approximations. This paper employs polynomials for ReLU, as defined in (Lee et al., 2021), for layer normalization (inverse square root) from (Panda, 2022) and for GELU using polynomials derived from the Remez algorithm.
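
The Newton–Raphson route to 1/√x is HE-friendly because each step uses only additions and multiplications. Below is a minimal sketch under simplifying assumptions: one interval (0, upper] rather than Panda's sub-interval scheme, plain Python floats standing in for ciphertexts, and a hypothetical helper name `inv_sqrt_newton`. The starting guess 1/√upper is computed in the clear from the public bound; since it never exceeds 1/√x on the interval, the iteration climbs toward the answer from below without overshooting.

```python
import math


def inv_sqrt_newton(x: float, upper: float, iterations: int) -> float:
    """Approximate 1/sqrt(x) for x in (0, upper] using only +, -, * on x.

    Hypothetical helper for illustration; floats stand in for ciphertexts.
    """
    # Plaintext initial guess from the public upper bound: y0 <= 1/sqrt(x)
    # for every x <= upper, which keeps the iteration convergent.
    y = 1.0 / math.sqrt(upper)
    for _ in range(iterations):
        # Newton-Raphson step for g(y) = 1/y^2 - x:  y <- y * (3 - x*y^2) / 2
        y = 0.5 * y * (3.0 - x * y * y)
    return y
```

As written, each iteration costs three ciphertext multiplications and turns a polynomial of degree d in x into one of degree 3d + 1, i.e. degree (3^n − 1)/2 after n iterations. With `upper = 100`, the input x = 1 starts ten times too small and three iterations still leave the estimate below 0.4, which is one concrete view of why a very wide [a, b] makes these approximations expensive.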

j2kun commented 9 months ago

Looking at the Remez algorithm, if we want a port/implementation, we will need four components as dependencies:

j2kun commented 9 months ago

Reference implementation: https://github.com/samhocevar/lolremez
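
For a sense of what a port involves, here is a minimal pure-Python sketch of the classic Remez exchange. It is not a port of lolremez and its details are assumptions chosen for brevity: a monomial basis, Chebyshev extrema as the starting reference, bisection for the error's roots, dense sampling (rather than a precise search) to locate each new extremum, and a fixed iteration count. The names `remez`, `solve`, and `horner` are hypothetical. Each iteration solves a small dense linear system in which the levelled error E is one of the unknowns.

```python
import math
from typing import Callable, List, Tuple


def solve(A: List[List[float]], b: List[float]) -> List[float]:
    """Gaussian elimination with partial pivoting (A square and nonsingular)."""
    n = len(A)
    M = [row[:] + [rhs] for row, rhs in zip(A, b)]  # augmented copy
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[piv] = M[piv], M[col]
        for r in range(col + 1, n):
            factor = M[r][col] / M[col][col]
            for c in range(col, n + 1):
                M[r][c] -= factor * M[col][c]
    x = [0.0] * n
    for r in range(n - 1, -1, -1):  # back substitution
        x[r] = (M[r][n] - sum(M[r][c] * x[c] for c in range(r + 1, n))) / M[r][r]
    return x


def horner(coeffs: List[float], x: float) -> float:
    """Evaluate c_0 + c_1 x + ... + c_d x^d."""
    acc = 0.0
    for c in reversed(coeffs):
        acc = acc * x + c
    return acc


def remez(f: Callable[[float], float], a: float, b: float, degree: int,
          iterations: int = 10, samples: int = 200) -> Tuple[List[float], float]:
    """Return (coeffs, E): a near-minimax polynomial for f on [a, b] and the
    levelled error E from the last linear solve."""
    m = degree + 2  # number of reference points
    # Initial reference: Chebyshev extrema mapped to [a, b], ascending.
    ref = [(a + b) / 2 - (b - a) / 2 * math.cos(math.pi * i / (m - 1)) for i in range(m)]
    coeffs: List[float] = []
    E = 0.0
    for _ in range(iterations):
        # Solve p(x_i) + (-1)^i E = f(x_i) for c_0..c_degree and E.
        A = [[x ** j for j in range(degree + 1)] + [(-1.0) ** i] for i, x in enumerate(ref)]
        sol = solve(A, [f(x) for x in ref])
        coeffs, E = sol[:-1], sol[-1]

        def err(x: float) -> float:
            return f(x) - horner(coeffs, x)

        # err alternates in sign on the reference, so bisect once per gap.
        roots = []
        for lo, hi in zip(ref, ref[1:]):
            for _step in range(60):
                mid = (lo + hi) / 2
                if (err(lo) > 0) == (err(mid) > 0):
                    lo = mid
                else:
                    hi = mid
            roots.append((lo + hi) / 2)
        # Exchange: in each of the m pieces between roots, move the reference
        # point to the sampled location of largest |err|.
        edges = [a] + roots + [b]
        ref = [max((l + (r - l) * k / samples for k in range(samples + 1)),
                   key=lambda x: abs(err(x)))
               for l, r in zip(edges, edges[1:])]
    return coeffs, E
```

For the GELU case in the quoted paper, a call might look like `remez(lambda x: 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0))), -4.0, 4.0, 8)`; that interval and degree are illustrative, not values from the paper. The monomial basis keeps the sketch short but grows ill-conditioned at higher degrees, where a Chebyshev basis is the usual remedy.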

j2kun commented 9 months ago

Maks also pointed me to an LLVM-internal ConstraintSystem solver https://github.com/llvm/llvm-project/blob/2c875719c841ff13b9b250e6ea97fc3e0aca2070/llvm/include/llvm/Analysis/ConstraintSystem.h#L4

This isn't quite a linear system solver, but may come in handy.