mratsim opened 1 year ago
I have a local implementation based on Constantine (not fuzzed yet, only basic testing).
Speedup of 53.7x to 82.9x over Stint on 256-bit inputs.
Bench code:
```nim
import
  ../constantine/math/arithmetic,
  ../constantine/math/io/[io_bigints, io_fields],
  ../constantine/math_arbitrary_precision/arithmetic/[bigints_views, limbs_views, limbs_montgomery, limbs_mod2k],
  ../constantine/math/config/[type_bigint, curves, precompute],
  ../constantine/platforms/[abstractions, codecs],
  ../helpers/prng_unsafe,
  std/[times, monotimes, strformat]

import stint

# let M = Mod(BN254_Snarks)
const bits = 256
const expBits = bits # Stint only supports same size args

var rng: RngState
rng.seed(1234)

for i in 0 ..< 5:
  echo "\ni: ", i
  # -------------------------
  let M = rng.random_long01Seq(BigInt[bits])
  let a = rng.random_long01Seq(BigInt[bits])

  var exponent = newSeq[byte](expBits div 8)
  for j in 0 ..< expBits div 8:
    exponent[j] = byte rng.next()

  # -------------------------
  let aHex = a.toHex()
  let eHex = exponent.toHex()
  let mHex = M.toHex()

  echo " base: ", aHex
  echo " exponent: ", eHex
  echo " modulus: ", mHex

  # -------------------------
  var elapsedCtt, elapsedStint: int64

  block:
    var r: BigInt[bits]
    let start = getMonotime()
    r.limbs.powMod_vartime(a.limbs, exponent, M.limbs, window = 4)
    let stop = getMonotime()

    elapsedCtt = inNanoseconds(stop-start)

    echo " r Constantine: ", r.toHex()
    echo " elapsed Constantine: ", elapsedCtt, " ns"

  # -------------------------

  block:
    let aa = Stuint[bits].fromHex(aHex)
    let ee = Stuint[expBits].fromHex(eHex)
    let mm = Stuint[bits].fromHex(mHex)

    var r: Stuint[bits]
    let start = getMonotime()
    r = powmod(aa, ee, mm)
    let stop = getMonotime()

    elapsedStint = inNanoseconds(stop-start)

    echo " r stint: ", r.toHex()
    echo " elapsed Stint: ", elapsedStint, " ns"

  echo &"\n ratio Stint/Constantine: {float64(elapsedStint)/float64(elapsedCtt):.3f}x"
  echo "---------------------------------------------------------"
```
Compiled with `-d:danger`. GCC is very bad at multiprecision arithmetic.

Redid the bench vs GMP: Constantine is within +/- 10% of GMP, sometimes slower, sometimes faster (without assembly!).

Bench:
```nim
import
  ../constantine/math/arithmetic,
  ../constantine/math/io/[io_bigints, io_fields],
  ../constantine/math_arbitrary_precision/arithmetic/[bigints_views, limbs_views, limbs_montgomery, limbs_mod2k],
  ../constantine/math/config/[type_bigint, curves, precompute],
  ../constantine/platforms/[abstractions, codecs],
  ../helpers/prng_unsafe,
  std/[times, monotimes, strformat]

import stint, gmp

const # https://gmplib.org/manual/Integer-Import-and-Export.html
  GMP_WordLittleEndian = -1'i32
  GMP_WordNativeEndian = 0'i32
  GMP_WordBigEndian = 1'i32

  GMP_MostSignificantWordFirst = 1'i32
  GMP_LeastSignificantWordFirst = -1'i32

# let M = Mod(BN254_Snarks)
const bits = 256
const expBits = bits # Stint only supports same size args

var rng: RngState
rng.seed(1234)

for i in 0 ..< 5:
  echo "i: ", i
  # -------------------------
  let M = rng.random_long01Seq(BigInt[bits])
  let a = rng.random_long01Seq(BigInt[bits])

  var exponent = newSeq[byte](expBits div 8)
  for j in 0 ..< expBits div 8:
    exponent[j] = byte rng.next()

  # -------------------------
  let aHex = a.toHex()
  let eHex = exponent.toHex()
  let mHex = M.toHex()

  echo " base: ", aHex
  echo " exponent: ", eHex
  echo " modulus: ", mHex

  # -------------------------
  var elapsedCtt, elapsedStint, elapsedGMP: int64

  block:
    var r: BigInt[bits]
    let start = getMonotime()
    r.limbs.powMod_vartime(a.limbs, exponent, M.limbs, window = 4)
    let stop = getMonotime()

    elapsedCtt = inNanoseconds(stop-start)

    echo " r Constantine: ", r.toHex()
    echo " elapsed Constantine: ", elapsedCtt, " ns"

  # -------------------------

  block:
    let aa = Stuint[bits].fromHex(aHex)
    let ee = Stuint[expBits].fromHex(eHex)
    let mm = Stuint[bits].fromHex(mHex)

    var r: Stuint[bits]
    let start = getMonotime()
    r = powmod(aa, ee, mm)
    let stop = getMonotime()

    elapsedStint = inNanoseconds(stop-start)

    echo " r stint: ", r.toHex()
    echo " elapsed Stint: ", elapsedStint, " ns"

  # -------------------------

  block:
    var aa, ee, mm, rr: mpz_t
    mpz_init(aa)
    mpz_init(ee)
    mpz_init(mm)
    mpz_init(rr)

    aa.mpz_import(a.limbs.len, GMP_LeastSignificantWordFirst, sizeof(SecretWord), GMP_WordNativeEndian, 0, a.limbs[0].unsafeAddr)
    let e = BigInt[expBits].unmarshal(exponent, bigEndian)
    ee.mpz_import(e.limbs.len, GMP_LeastSignificantWordFirst, sizeof(SecretWord), GMP_WordNativeEndian, 0, e.limbs[0].unsafeAddr)
    mm.mpz_import(M.limbs.len, GMP_LeastSignificantWordFirst, sizeof(SecretWord), GMP_WordNativeEndian, 0, M.limbs[0].unsafeAddr)

    let start = getMonotime()
    rr.mpz_powm(aa, ee, mm)
    let stop = getMonotime()

    elapsedGMP = inNanoSeconds(stop-start)

    var r: BigInt[bits]
    var rWritten: csize
    discard r.limbs[0].addr.mpz_export(rWritten.addr, GMP_LeastSignificantWordFirst, sizeof(SecretWord), GMP_WordNativeEndian, 0, rr)

    echo " r GMP: ", r.toHex()
    echo " elapsed GMP: ", elapsedGMP, " ns"

  echo &"\n ratio Stint/Constantine: {float64(elapsedStint)/float64(elapsedCtt):.3f}x"
  echo &" ratio GMP/Constantine: {float64(elapsedGMP)/float64(elapsedCtt):.3f}x"
  echo "---------------------------------------------------------"
```
The goal of this issue is to provide a guideline on how to fix https://github.com/status-im/nimbus-eth1/issues/1584.
Also pinging @treeform and @guzba on how to implement fast RSA (modexp is the bottleneck), and @dlesnoff for nim/bigint.
Here is a write-up on how to implement fast modular exponentiation.
Recommended textbook for implementers:
We assume 64-bit words.
Vocabulary
In textbooks, you'll encounter the term "reduction", which should be read as "modular reduction" and simply means taking the remainder after Euclidean division.
The assembly myth
First of all, let's dispel a myth: that assembly is the key to speed. It does help, but that view misses the big picture.
Some benchmarks of Constantine's modular exponentiation in the BN254 prime field, with a 254-bit exponent:
So the difference between assembly and no assembly is "only" 30% with Clang. And don't use GCC for bigints: without assembly, it's just bad.
The big picture
Modular exponentiation is implemented with an algorithm called double-and-add (for exponentiation, square-and-multiply), which does as many modular squarings as there are bits in the exponent and as many modular multiplications as there are set bits in the exponent.
For random exponents, about 50% of the bits are set. For a 256-bit exponent, that's 256 squarings + ~128 multiplications, i.e. about 384 modular multiplications/squarings.
Naively, each modular multiplication is a 256-bit × 256-bit → 512-bit multiplication followed by a reduction modulo a 256-bit number.
The bottleneck
Let's take Agner Fog's instruction tables (https://www.agner.org/optimize/instruction_tables.pdf) and have a look at the speed of the `DIV` instruction, which is necessary to compute a modulo.

x86 became extremely optimized for bigints starting with Broadwell, which introduced ADCX and ADOX (MULX was introduced in Haswell, Broadwell's predecessor).
On 64-bit inputs, DIV takes 36 cycles and has a latency of up to 95 cycles (i.e. anything that depends on its result may wait up to 95 cycles before proceeding).
In comparison, additions and shifts take just 1 cycle. So any algorithm that uses division starts with a heavy disadvantage.
Note: even with that disadvantage, hardware division is still faster than bit-by-bit division like here (that algorithm is chosen when the operands' lengths differ by at most 8 bits): https://github.com/status-im/nim-stint/blob/94fc521ee0f1e113d09ceeaa3568d4d7a6c0b67d/stint/private/uint_div.nim#L206-L230
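For a feel of why bit-by-bit division is slow, here is a minimal single-word sketch of shift-and-subtract division in Nim. `divModBitByBit` is a hypothetical helper, not Stint's code, and it assumes `d < 2^63` so the shifted remainder cannot overflow. Each iteration produces one quotient bit with a shift, a compare and a conditional subtraction; on multi-word operands each of those steps is itself a loop over limbs.

```nim
proc divModBitByBit(n, d: uint64): tuple[q, r: uint64] =
  ## Restoring (shift-and-subtract) division, one quotient bit per iteration.
  ## Toy assumption: d < 2^63 so that `r shl 1` cannot overflow.
  doAssert d != 0 and d < (1'u64 shl 63)
  var q, r = 0'u64
  for i in countdown(63, 0):
    # Bring down the next bit of the dividend.
    r = (r shl 1) or ((n shr i) and 1'u64)
    # One compare and one conditional subtraction per bit.
    if r >= d:
      r -= d
      q = q or (1'u64 shl i)
  result = (q, r)
```

For example, `divModBitByBit(100, 7)` gives a quotient of 14 and a remainder of 2, after 64 iterations.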
Back-of-the-napkin performance estimate:

A 256-bit modular reduction needs four 64-bit DIVs (plus other work), so at up to ~100 cycles of latency each we already have a cost of about 400 cycles. We need that for each of the 384 operations, so 384 × 400 = 153,600 cycles. That's over 5x more costly than my slow benchmark of GCC without assembly (admittedly on 254-bit instead of 256-bit inputs).

And there is a lot more work besides the divisions, see https://github.com/status-im/nim-stint/blob/94fc521ee0f1e113d09ceeaa3568d4d7a6c0b67d/stint/private/uint_div.nim#L131-L169

In other words, all the example implementations on Wikipedia are really slow: https://en.wikipedia.org/wiki/Modular_arithmetic#Example_implementations
How to avoid division
There are 2 main techniques to avoid costly divisions:
Barrett reduction
https://en.wikipedia.org/wiki/Barrett_reduction
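A toy Barrett sketch in Nim, assuming a single word and a modulus with exactly `K = 20` bits so that `x * mu` fits in a `uint64` (real implementations work on whole 64-bit limbs, as described next). `Barrett`, `initBarrett`, `reduce` and `mulMod` are hypothetical names. The only division is the one-time precomputation of `mu`; each reduction then costs a multiplication, a shift, a subtraction and a final correction.

```nim
const K = 20  # toy assumption: the modulus has exactly K bits

type Barrett = object
  m, mu: uint64   # mu = floor(2^(2K) / m), precomputed once per modulus

proc initBarrett(m: uint64): Barrett =
  doAssert (m shr (K - 1)) == 1, "m must have exactly K bits"
  # The only division, done once and reused for every reduction.
  result = Barrett(m: m, mu: (1'u64 shl (2 * K)) div m)

proc reduce(b: Barrett, x: uint64): uint64 =
  ## x mod m for x < m², using a multiplication and a shift instead of a division.
  doAssert x < b.m * b.m
  let q = (x * b.mu) shr (2 * K)   # estimate of x div m, never too large
  result = x - q * b.m
  # The estimate is off by at most 1 here; a loop keeps the sketch obviously correct.
  while result >= b.m:
    result -= b.m

proc mulMod(b: Barrett, x, y: uint64): uint64 =
  ## x*y mod m for x, y < m.
  b.reduce(x * y)
```

Usage: `let br = initBarrett(1_000_003)` once, then `br.mulMod(x, y)` as many times as needed, which is exactly the "reused many times" case.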
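Instead of computing `a*b mod m` with a division, you compute `a*b*⌊(2⁶⁴)ᵏ/m⌋` and then shift right by k words (i.e. divide by (2⁶⁴)ᵏ). This gives an estimate q of the quotient ⌊a*b/m⌋, and the remainder is `a*b - q*m` followed by a small correction (subtracting m at most a couple of times). This is called Barrett reduction and is interesting when ⌊(2⁶⁴)ᵏ/m⌋ can be reused many times. k is chosen so that the division by m has an inconsequential rounding error.

Montgomery reduction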
Montgomery reduction uses an approach similar to Barrett reduction, but with lower complexity, at the price of having to "transport"/convert the numbers being reduced into the "Montgomery domain".
In practice we do all computations on `a' = aR (mod m)` with `R = (2⁶⁴)^numWords (mod m)`, where numWords = 4 for 256-bit numbers.

Once we have numbers in the Montgomery domain, there is an operation called Montgomery modular multiplication (montMul), defined as `montMul(x, y) = xyR⁻¹ (mod m)`, so that the product stays in the Montgomery domain:

`montMul(a', b') = montMul(aR, bR) = abR (mod m)`
`montMul` is the fastest modular multiplication algorithm that works on almost all moduli. Why almost? Well, it only works on odd moduli (m must be invertible modulo 2⁶⁴), and all primes besides 2 are odd.

Anyway, if we have an odd modulus, we can use Montgomery modular exponentiation instead of plain modular exponentiation.
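A minimal single-word montMul sketch in Nim, under toy assumptions: we pretend a word is 32 bits (so R = 2³²) and require an odd modulus m < 2³¹, which keeps every intermediate product inside a `uint64`. The names (`negInvModR`, `montMul`, `toMont`, `fromMont`) are hypothetical, and R² mod m is passed in precomputed.

```nim
const WordBits = 32
const WordMask = (1'u64 shl WordBits) - 1

proc negInvModR(m: uint64): uint64 =
  ## -1/m mod R (R = 2^32) for odd m, via Newton's iteration.
  doAssert (m and 1'u64) == 1'u64
  var inv = m                       # correct to 3 bits since m*m ≡ 1 (mod 8)
  for _ in 0 ..< 4:                 # each step doubles the correct bits: 3→6→12→24→48
    inv = inv * (2'u64 - m * inv)
  result = (0'u64 - inv) and WordMask

proc montMul(a, b, m, mNegInv: uint64): uint64 =
  ## a*b/R mod m, for a, b < m < 2^31.
  let t = a * b                                      # t < m² < 2^62
  let q = ((t and WordMask) * mNegInv) and WordMask  # makes t + q*m ≡ 0 (mod R)
  result = (t + q * m) shr WordBits                  # exact division by R, result < 2m
  if result >= m:
    result -= m

proc toMont(a, m, mNegInv, r2: uint64): uint64 =
  montMul(a, r2, m, mNegInv)        # a·R²/R = aR (mod m)

proc fromMont(aR, m, mNegInv: uint64): uint64 =
  montMul(aR, 1, m, mNegInv)        # aR·1/R = a (mod m)
```

Note that `montMul` itself never divides: the `shr` is exact because `q` was chosen to clear the low 32 bits.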
Converting to and from the Montgomery domain only requires a `montMul` by R² (mod m) or by 1, respectively.

Reconciling Montgomery and even moduli
One issue is that in the Ethereum Virtual Machine, modexp can receive any modulus, not just odd ones.
Thankfully, we can invoke the Chinese Remainder Theorem (CRT) (https://en.wikipedia.org/wiki/Chinese_remainder_theorem), which states that if your modulus is `m = a * b` with a and b coprime, you can compute the result `mod a` and `mod b` separately, and it gives you a way to recombine them into the result `mod m`.
So if you have an even modulus, you split it into a = 2ᵏ, a power of 2, and b, an odd number (which can use Montgomery). Odd numbers are coprime with powers of 2, so the CRT always applies. The power-of-2 part is cheap, since reducing modulo 2ᵏ is just a mask: `x mod 2ᵏ == x and (2ᵏ - 1)`.

Engineering & Implementation
Now that we have the theory, let's look at engineering problems and reference code.
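Starting with the even-modulus split from the theory part, here is a hedged Nim sketch: m = 2ᵏ·q with q odd, one exponentiation mod q, one mod 2ᵏ (just masking), then a Garner-style CRT recombination `x = x1 + q * ((x2 - x1) * q⁻¹ mod 2ᵏ)`. Toy assumption: m < 2³² so residue products fit in `uint64`; the odd part uses plain `mod` as a stand-in for Montgomery arithmetic, and all proc names are hypothetical.

```nim
proc powModPlain(a, e, m: uint64): uint64 =
  ## Right-to-left square-and-multiply with `mod` (stand-in for the Montgomery path).
  result = 1'u64 mod m
  var base = a mod m
  var e = e
  while e > 0'u64:
    if (e and 1'u64) == 1'u64:
      result = (result * base) mod m
    base = (base * base) mod m
    e = e shr 1

proc powModPow2(a, e: uint64, k: int): uint64 =
  ## a^e mod 2^k: reduction is a mask, and wrapping uint64 math is exact mod 2^k.
  let mask = (1'u64 shl k) - 1
  result = 1'u64 and mask
  var base = a and mask
  var e = e
  while e > 0'u64:
    if (e and 1'u64) == 1'u64:
      result = (result * base) and mask
    base = (base * base) and mask
    e = e shr 1

proc invModPow2(q: uint64): uint64 =
  ## 1/q mod 2^64 for odd q (Newton's iteration); mask it to get 1/q mod 2^k.
  result = q
  for _ in 0 ..< 5:   # 3 → 6 → 12 → 24 → 48 → 96 correct bits
    result = result * (2'u64 - q * result)

proc powModAny(a, e, m: uint64): uint64 =
  doAssert m > 0 and m < (1'u64 shl 32)
  # Split m = 2^k * q with q odd; gcd(2^k, q) = 1 so the CRT applies.
  var k = 0
  var q = m
  while (q and 1'u64) == 0'u64:
    q = q shr 1
    inc k
  let x1 = powModPlain(a, e, q)     # residue mod the odd part
  let x2 = powModPow2(a, e, k)      # residue mod 2^k
  # Garner: x ≡ x1 (mod q) and x ≡ x2 (mod 2^k), with 0 <= x < m.
  # Wrapping subtraction/multiplication are exact modulo 2^k.
  let mask = (1'u64 shl k) - 1
  let h = ((x2 - x1) * invModPow2(q)) and mask
  result = x1 + q * h
```

The recombination needs no division at all: q⁻¹ mod 2ᵏ comes from the same Newton iteration used for Montgomery constants.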
Fast Montgomery multiplication
There are many ways to multiply bigints; they were categorized in Acar's thesis: https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf
On modern architectures with MULX/ADCX/ADOX, CIOS (Coarsely Integrated Operand Scanning) is the fastest. Without MULX/ADCX/ADOX (and so potentially on everything besides x86), FIPS (Finely Integrated Product Scanning) is the fastest. Why? AFAIK the data movement is better and there are fewer carries to save/restore.
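A no-assembly CIOS sketch in Nim, for illustration only: limbs are 32-bit values stored in `uint64` (little-endian) so that `limb * limb + carry + carry` never overflows, whereas real code uses 64-bit limbs with a double-word product. `montMulCIOS` and `negInvLimb` are hypothetical names. Each outer iteration adds `a * b[i]`, then adds `q * m` to zero the lowest limb and shifts down by one limb.

```nim
const LimbBits = 32
const LimbMask = (1'u64 shl LimbBits) - 1

proc negInvLimb(m0: uint64): uint64 =
  ## -1/m0 mod 2^32 for an odd limb m0, via Newton's iteration.
  var inv = m0                      # correct to 3 bits since m0*m0 ≡ 1 (mod 8)
  for _ in 0 ..< 4:                 # 3 → 6 → 12 → 24 → 48 correct bits
    inv = inv * (2'u64 - m0 * inv)
  result = (0'u64 - inv) and LimbMask

proc montMulCIOS(a, b, m: seq[uint64], mNegInv: uint64): seq[uint64] =
  ## a*b/R mod m with R = 2^(32*n), n = m.len; requires odd m and a, b < m.
  let n = m.len
  var t = newSeq[uint64](n + 2)
  for i in 0 ..< n:
    # Multiplication step: t += a * b[i]
    var carry = 0'u64
    for j in 0 ..< n:
      let s = t[j] + a[j] * b[i] + carry
      t[j] = s and LimbMask
      carry = s shr LimbBits
    let top = t[n] + carry
    t[n] = top and LimbMask
    t[n + 1] = top shr LimbBits
    # Reduction step: q*m makes the lowest limb 0, then shift t down by one limb
    let q = (t[0] * mNegInv) and LimbMask
    carry = (t[0] + q * m[0]) shr LimbBits
    for j in 1 ..< n:
      let s = t[j] + q * m[j] + carry
      t[j - 1] = s and LimbMask
      carry = s shr LimbBits
    let top2 = t[n] + carry
    t[n - 1] = top2 and LimbMask
    t[n] = t[n + 1] + (top2 shr LimbBits)
    t[n + 1] = 0
  # Now t < 2m: subtract m once if t >= m (compare from the most significant limb)
  var geq = t[n] != 0
  if not geq:
    geq = true                      # equal counts as >=
    for j in countdown(n - 1, 0):
      if t[j] != m[j]:
        geq = t[j] > m[j]
        break
  result = t[0 ..< n]
  if geq:
    var borrow = 0'u64
    for j in 0 ..< n:
      let d = result[j] - m[j] - borrow   # wraps mod 2^64 on underflow
      borrow = d shr 63                   # 1 if this limb went negative
      result[j] = d and LimbMask
```

The interleaving is what makes CIOS "coarsely integrated": multiplication and reduction alternate per word of `b`, so `t` never grows beyond n + 2 limbs.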
The no-assembly algorithm for both is available here:
Fast large multiplication
One issue with multiplication is that it's O(n²) in the number of words, since each word of one multiplicand must be multiplied with each word of the other.
The Karatsuba algorithm has a complexity of O(n^(log₂ 3)) ≈ O(n¹˙⁵⁸), with a constant-factor overhead that becomes negligible at around 8~12 words (so 512-768 bits, to be measured), https://en.wikipedia.org/wiki/Karatsuba_algorithm
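One level of Karatsuba on a toy two-"limb" number, in Nim: 32-bit inputs split into 16-bit halves so everything fits in `uint64`. It replaces the two cross products with one extra multiplication plus additions and subtractions. `karatsuba32` is a hypothetical name; a real implementation would recurse on limb arrays and only switch to Karatsuba above the size threshold mentioned above.

```nim
proc karatsuba32(x, y: uint64): uint64 =
  ## x*y for x, y < 2^32 using 3 half-size multiplications instead of 4.
  doAssert x < (1'u64 shl 32) and y < (1'u64 shl 32)
  const B = 16
  const Mask = (1'u64 shl B) - 1
  let (x1, x0) = (x shr B, x and Mask)
  let (y1, y0) = (y shr B, y and Mask)
  let z2 = x1 * y1                               # high halves
  let z0 = x0 * y0                               # low halves
  # The 3rd multiplication yields both cross products: x1*y0 + x0*y1
  let z1 = (x1 + x0) * (y1 + y0) - z2 - z0
  result = (z2 shl (2 * B)) + (z1 shl B) + z0
```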
Fast exponentiation
Now we can look into exponentiation. The basic algorithm is square-and-multiply https://en.wikipedia.org/wiki/Exponentiation_by_squaring
You scan the bits of the exponent (either left-to-right, i.e. MSB-to-LSB, or right-to-left, i.e. LSB-to-MSB; both variants are possible); you always square, and then you multiply by something (what depends on the scanning direction) if the bit is set.
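A left-to-right (MSB-to-LSB) square-and-multiply sketch in Nim. To keep it short it uses plain `mod` on a toy modulus m < 2³² instead of montMul; `powModLtr` is a hypothetical name.

```nim
proc powModLtr(a, e, m: uint64): uint64 =
  ## a^e mod m, scanning exponent bits from MSB to LSB.
  doAssert m > 0 and m < (1'u64 shl 32)
  let base = a mod m
  result = 1'u64 mod m
  for i in countdown(63, 0):
    result = (result * result) mod m          # always square (leading zeros just square 1)
    if ((e shr i) and 1'u64) == 1'u64:
      result = (result * base) mod m          # multiply by the base if the bit is set
```

In the left-to-right variant the multiplier is always the base itself; in the right-to-left variant it is the running power base^(2^i).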
However, you can precompute the powers corresponding to bit patterns, for example 1101, 1001, 0111, and instead of doing 2 or 3 multiplications, ensure that you only do 1 every 4 squarings.
That's called the `window` method. An example fixed-window implementation is available in Constantine: https://github.com/mratsim/constantine/blob/1c5341f/constantine/math/arithmetic/limbs_montgomery.nim#L614-L836
As there is no constant-time requirement in Stint, it can be simplified; I needed to ensure I could use the values of secret bits for windows (for example, RSA uses modular exponentiation with a secret exponent) without revealing what the secret bits are.
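A variable-time fixed-window sketch in Nim (window of 4 bits, toy modulus m < 2³², plain `mod` instead of montMul, hypothetical name `powModWindow4`). It precomputes base⁰..base¹⁵, then per 4-bit exponent digit does 4 squarings and at most 1 multiplication. Since Stint has no constant-time requirement, the table is indexed directly by the digit and zero digits skip the multiplication, unlike a constant-time implementation.

```nim
proc powModWindow4(a, e, m: uint64): uint64 =
  ## a^e mod m with a fixed 4-bit window, MSB first. Not constant-time.
  doAssert m > 0 and m < (1'u64 shl 32)
  var table: array[16, uint64]          # table[d] = a^d mod m
  table[0] = 1'u64 mod m
  for i in 1 ..< 16:
    table[i] = (table[i - 1] * (a mod m)) mod m
  result = 1'u64 mod m
  for w in countdown(15, 0):            # 16 windows of 4 bits cover the 64-bit exponent
    for _ in 0 ..< 4:
      result = (result * result) mod m
    let digit = (e shr (4 * w)) and 0xF'u64
    if digit != 0:
      result = (result * table[int(digit)]) mod m
```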
Sliding window
An extra optimization is to use windows of variable size, called sliding windows. I don't have an implementation of that (it cannot be made constant-time), but Wikipedia has pseudocode: https://en.wikipedia.org/wiki/Exponentiation_by_squaring#Sliding-window_method
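A sliding-window sketch in Nim, loosely following the Wikipedia pseudocode: zero bits cost a single squaring, and each window of up to `W = 4` bits ends on a set bit, so only odd powers need to be precomputed. Toy assumptions as before (m < 2³², plain `mod` instead of montMul, hypothetical `powModSliding`); it is variable-time by design.

```nim
const W = 4  # maximum window width (toy choice)

proc powModSliding(a, e, m: uint64): uint64 =
  ## a^e mod m with a sliding window, MSB first. Not constant-time.
  doAssert m > 0 and m < (1'u64 shl 32)
  let base = a mod m
  let base2 = (base * base) mod m
  # Only odd powers are needed: table[k] = base^(2k+1) for 2k+1 < 2^W
  var table: array[1 shl (W - 1), uint64]
  table[0] = base
  for k in 1 ..< table.len:
    table[k] = (table[k - 1] * base2) mod m
  result = 1'u64 mod m
  var i = 63
  while i >= 0:
    if ((e shr i) and 1'u64) == 0'u64:
      result = (result * result) mod m          # zero bit: square only
      dec i
    else:
      # Longest window e[i..l] of at most W bits that ends on a set bit.
      var l = max(i - W + 1, 0)
      while ((e shr l) and 1'u64) == 0'u64:
        inc l
      let width = i - l + 1
      for _ in 0 ..< width:
        result = (result * result) mod m
      let value = (e shr l) and ((1'u64 shl width) - 1)   # odd by construction
      result = (result * table[int(value shr 1)]) mod m
      i = l - 1
```

Compared with the fixed window, the table is half the size (8 entries for W = 4) and runs of zero bits never trigger a multiplication.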
NAF/signed recoding
In your research you might come across NAF (non-adjacent form) or signed recoding. This only applies to elliptic curves: negative digits require the inverse of the base, and inversion 1/a (mod m) is not cheap in modular arithmetic, whereas negation -P is cheap on an elliptic curve.
Montgomery domain, conversion and constants
Computing R and R² can be done like this: https://github.com/mratsim/constantine/blob/1c5341f/constantine/math/config/precompute.nim#L307-L350
You will also need 1/M0 (mod 2⁶⁴), with M0 being the first (least significant) limb of your modulus M (or its negation -1/M0, depending on how montMul is written).
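A hedged sketch of these precomputations in Nim for a one-limb modulus (R = 2⁶⁴, odd m < 2⁶³), using hypothetical helpers rather than Constantine's API: `pow2ModM` gets 2ᵏ mod m (so R and R² for k = 64 and 128) with only shifts and conditional subtractions, and `invModWord` computes 1/M0 mod 2⁶⁴ with Newton's iteration, where each step doubles the number of correct low bits.

```nim
proc pow2ModM(k: int, m: uint64): uint64 =
  ## 2^k mod m by doubling 1 modulo m, k times (no division instruction).
  doAssert m > 1 and m < (1'u64 shl 63)
  result = 1
  for _ in 0 ..< k:
    result = result shl 1          # < 2m < 2^64, so no overflow
    if result >= m:
      result -= m

proc invModWord(m0: uint64): uint64 =
  ## 1/m0 mod 2^64 for odd m0; negate it (0 - x) to get -1/m0.
  doAssert (m0 and 1'u64) == 1'u64
  result = m0                      # m0*m0 ≡ 1 (mod 8): already correct to 3 bits
  for _ in 0 ..< 5:                # Newton: 3 → 6 → 12 → 24 → 48 → 96 correct bits
    result = result * (2'u64 - m0 * result)
```

For a multi-limb modulus the same doubling idea works on limb arrays (with R = 2^(64·numWords)), only M0 is needed for the inverse.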
Wrapup
The implementation steps should be:

- enter the Montgomery domain: `montMul(a, R²) = a' = aR (mod m)`
- leave it: `montMul(a', 1) = a (mod m)`

And use Clang.
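Putting the steps together, a toy end-to-end Montgomery exponentiation in Nim for an odd modulus 1 < m < 2³¹ with R = 2³² (one 32-bit "word", so products fit in `uint64`). It is variable-time, uses hypothetical names, and does not handle even moduli (those would go through the CRT split).

```nim
const WordBits = 32
const WordMask = (1'u64 shl WordBits) - 1

proc montMul(a, b, m, mNegInv: uint64): uint64 =
  ## a*b/R mod m for a, b < m.
  let t = a * b
  let q = ((t and WordMask) * mNegInv) and WordMask
  result = (t + q * m) shr WordBits
  if result >= m:
    result -= m

proc montPowMod(a, e, m: uint64): uint64 =
  doAssert (m and 1'u64) == 1'u64 and m > 1 and m < (1'u64 shl 31)
  # Precompute -1/m mod R (Newton) and R² mod m (doublings, no division).
  var inv = m
  for _ in 0 ..< 4:
    inv = inv * (2'u64 - m * inv)
  let mNegInv = (0'u64 - inv) and WordMask
  var r2 = 1'u64
  for _ in 0 ..< 2 * WordBits:
    r2 = r2 shl 1
    if r2 >= m:
      r2 -= m
  # 1. Into the Montgomery domain: montMul(a, R²) = aR (mod m)
  let aR = montMul(a mod m, r2, m, mNegInv)
  var accR = montMul(1, r2, m, mNegInv)        # 1·R (mod m)
  # 2. Square-and-multiply entirely in the Montgomery domain (MSB to LSB)
  for i in countdown(63, 0):
    accR = montMul(accR, accR, m, mNegInv)
    if ((e shr i) and 1'u64) == 1'u64:
      accR = montMul(accR, aR, m, mNegInv)
  # 3. Back out: montMul(a', 1) = a (mod m)
  result = montMul(accR, 1, m, mNegInv)
```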