status-im / nim-stint

Stack-based arbitrary-precision integers - Fast and portable with natural syntax for resource-restricted devices.
Apache License 2.0

Implementing fast modular exponentiation - a guide #126

Open mratsim opened 1 year ago

mratsim commented 1 year ago

The goal of this issue is to provide a guideline on how to fix https://github.com/status-im/nimbus-eth1/issues/1584.

Also pinging @treeform and @guzba about implementing fast RSA (modexp is the bottleneck), and @dlesnoff for nim/bigint.

Here is a write-up on how to implement fast modular exponentiation.

Recommended textbook for implementer:

We assume 64-bit words.

Vocabulary

In textbooks, you'll encounter the term "reduction", which should be read as "modular reduction" and simply means taking the remainder of a Euclidean division.

The assembly myth

First, let's dispel a myth: that assembly is the key to speed. Assembly does help, but that view misses the big picture.

Some benchmarks on Constantine modular exponentiation on BN254 prime field by a 254-bit integer:

So with Clang, the difference between assembly and no assembly is "only" 30%. And don't use GCC for bigints: without assembly, its output is just bad.

The big picture

Modular exponentiation is implemented with an algorithm called double-and-add (in the multiplicative setting, square-and-multiply), which performs as many modular squarings as there are bits in the exponent and as many modular multiplications as there are set bits in the exponent.

For random exponents, about 50% of the bits are set. For a 256-bit exponent, that's 256 squarings plus about 128 multiplications, i.e. about 384 modular multiplications/squarings.

Naively, each modular multiplication is a 256-bit × 256-bit → 512-bit multiplication followed by a reduction modulo a 256-bit number.

The bottleneck

Let's take Agner Fog's instruction tables (https://www.agner.org/optimize/instruction_tables.pdf) and look at the speed of the DIV instruction, which is needed to compute a modulo.

x86 became extremely optimized for bigints starting with Broadwell, which introduced ADCX and ADOX (MULX was introduced in Haswell, Broadwell's predecessor).

On 64-bit inputs, DIV takes 36 cycles and has a latency of up to 95 cycles (i.e. anything that depends on its result may have to wait up to 95 cycles before proceeding).

In comparison, additions and shifts take just 1 cycle, so anything that uses division starts at a heavy disadvantage.

Note: even with that disadvantage, hardware division is still faster than bit-by-bit division as done here, an algorithm Stint selects when the operands differ in length by at most 8 bits: https://github.com/status-im/nim-stint/blob/94fc521ee0f1e113d09ceeaa3568d4d7a6c0b67d/stint/private/uint_div.nim#L206-L230

Back-of-the-napkin performance estimate:

A 256-bit modular reduction needs four 64-bit DIVs (plus other work), so at roughly 100 cycles each we already pay about 400 cycles. We need one reduction per operation, so 384 × 400 = 153,600 cycles. That's over 5x more costly than my slow benchmark of GCC without assembly (admittedly on 254-bit instead of 256-bit numbers).

And there is a lot more work besides the divisions; see https://github.com/status-im/nim-stint/blob/94fc521ee0f1e113d09ceeaa3568d4d7a6c0b67d/stint/private/uint_div.nim#L131-L169

In other words, all the example implementations on Wikipedia are really slow: https://en.wikipedia.org/wiki/Modular_arithmetic#Example_implementations

How to avoid division

There are 2 main techniques to avoid costly divisions:

Barrett reduction

https://en.wikipedia.org/wiki/Barrett_reduction

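Instead of computing a*b mod m with a division, you precompute μ = ⌊(2⁶⁴)ᵏ/m⌋ once. The quotient of a*b by m is then estimated as q = a*b*μ shifted right by k words (i.e. divided by (2⁶⁴)ᵏ), and the remainder is a*b − q*m, corrected with at most a couple of subtractions of m. This is called Barrett reduction and is interesting when μ can be reused many times. k is chosen so that the division by m has an inconsequential rounding error.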
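A minimal single-word sketch of the idea, in Nim like the benchmarks in this thread. To keep every intermediate in a uint64 it assumes 2 ≤ m < 2¹⁶ and an input x < 2³² (such as a product of two residues), and it shifts by 32 bits instead of by k 64-bit words; `barrettPrecompute` and `barrettReduce` are hypothetical names, not Stint or Constantine APIs.

```nim
# Single-word Barrett reduction sketch (hypothetical helper names).
# Assumes 2 <= m < 2^16 and x < 2^32, so x * mu always fits in a uint64.

proc barrettPrecompute(m: uint64): uint64 =
  ## mu = floor(2^32 / m), computed once per modulus and then reused
  (1'u64 shl 32) div m

proc barrettReduce(x, m, mu: uint64): uint64 =
  ## x mod m using a multiplication and a shift instead of a division
  # Estimated quotient: q = floor(x * mu / 2^32), never larger than floor(x / m)
  let q = (x * mu) shr 32
  # Because q never overestimates, x - q*m cannot underflow
  result = x - q * m
  # With x < 2^32, q is at most 1 too small, so this loop runs at most once
  while result >= m:
    result -= m
```

The only division left is in `barrettPrecompute`, paid once per modulus; each reduction afterwards costs two multiplications, a shift and a conditional subtraction.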

Montgomery reduction

Montgomery reduction uses a similar approach to Barrett reduction but with lower complexity, at the price of having to "transport"/convert the numbers being reduced into the "Montgomery domain".

In practice we do all computations on a' = aR (mod m), with R = (2⁶⁴)^numWords, where numWords = 4 for 256-bit numbers (so R = 2²⁵⁶).

Once we have numbers in the Montgomery domain, there is an operation called Montgomery modular multiplication (montMul) that computes montMul(x, y) = xyR⁻¹ (mod m), so products stay in the domain: montMul(a', b') = montMul(aR, bR) = abR (mod m).

montMul is the fastest modular multiplication algorithm that works on almost all moduli.

Why almost?

Well, it only works on odd moduli, because R is a power of 2 and must be invertible modulo m. Conveniently, all primes besides 2 are odd.

So if the modulus is odd, we can compute a Montgomery modular exponentiation instead of a plain modular exponentiation.

Converting to and from the Montgomery domain only requires a montMul by R² (mod m) or by 1, respectively: montMul(a, R²) = aR and montMul(aR, 1) = a (mod m).
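A single-word sketch of montMul, the two conversions and a Montgomery exponentiation, in Nim. Assumptions, for illustration only: one 32-bit word so R = 2³², an odd modulus 3 ≤ m < 2³¹ so that t + u·m fits in a uint64, and hypothetical proc names (`negInv32`, `montMul`, `toMont`, `fromMont`, `montPowMod`). R² mod m is computed with a plain `mod` here since it is a one-time precomputation.

```nim
# Single-word Montgomery arithmetic sketch: R = 2^32, odd modulus 3 <= m < 2^31.

proc negInv32(m: uint32): uint32 =
  ## -1/m mod 2^32 via Newton iteration (m must be odd)
  var inv = m                        # m*m == 1 (mod 8), so inv is correct to 3 bits
  for _ in 0 ..< 4:                  # correct bits: 3 -> 6 -> 12 -> 24 -> 48
    inv = inv * (2'u32 - m * inv)    # uint32 arithmetic wraps modulo 2^32
  result = 0'u32 - inv

proc montMul(a, b, m: uint64, mInv: uint32): uint64 =
  ## Montgomery multiplication: a*b*R^-1 mod m, for a, b < m
  let t = a * b                                  # < m^2 < 2^62
  let u = uint32(t and 0xFFFF_FFFF'u64) * mInv   # makes t + u*m divisible by 2^32
  result = (t + uint64(u) * m) shr 32            # exact division by R is a shift
  if result >= m:                                # result < 2m: one subtraction suffices
    result -= m

proc toMont(a, m, r2: uint64, mInv: uint32): uint64 =
  ## a -> aR mod m, since montMul(a, R^2) = a*R^2*R^-1
  montMul(a mod m, r2, m, mInv)

proc fromMont(aR, m: uint64, mInv: uint32): uint64 =
  ## aR -> a mod m, since montMul(aR, 1) = aR*R^-1
  montMul(aR, 1, m, mInv)

proc montPowMod(a, e, m: uint64): uint64 =
  ## a^e mod m, with every multiplication done in the Montgomery domain
  let mInv = negInv32(uint32(m))
  let rModM = (1'u64 shl 32) mod m              # one-time precomputation
  let r2 = rModM * rModM mod m                  # R^2 mod m
  let aR = toMont(a, m, r2, mInv)
  var acc = toMont(1, m, r2, mInv)              # 1 in Montgomery form (= R mod m)
  for i in countdown(63, 0):                    # left-to-right square-and-multiply
    acc = montMul(acc, acc, m, mInv)
    if ((e shr i) and 1) == 1:
      acc = montMul(acc, aR, m, mInv)
  result = fromMont(acc, m, mInv)
```

Only the two conversions touch the normal domain; everything inside the loop is montMul, which never divides by m.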

Reconciling Montgomery and even moduli

One issue is that in the Ethereum Virtual Machine, modexp can receive any modulus, not just odd ones.

Thankfully, we can invoke the Chinese Remainder Theorem (CRT) (https://en.wikipedia.org/wiki/Chinese_remainder_theorem), which states that if the modulus m = a * b with a and b coprime, you can compute mod a and mod b separately and then recombine the results mod m.

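So if you have an even modulus, you split it into a = 2ᵏ, a power of 2, and b, an odd number. Odd numbers are coprime with powers of 2, so the CRT always applies: the mod b part can use Montgomery, and the mod 2ᵏ part is cheap since reducing modulo 2ᵏ just keeps the low k bits.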
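A single-word sketch of that split in Nim, assuming 0 < m < 2³² so all products fit in a uint64. `powModSmall` is a hypothetical plain square-and-multiply standing in for the Montgomery exponentiation a real implementation would run on the odd part. The odd factor (b above) is called q here to avoid clashing with the base a; the partial results x₁ = aᵉ mod q and x₂ = aᵉ mod 2ᵏ are recombined with Garner's formula x = x₁ + q·((x₂ − x₁)·q⁻¹ mod 2ᵏ), which only needs the inverse of q modulo 2ᵏ.

```nim
import std/bitops

# Even-modulus exponentiation via the CRT (single word, 0 < m < 2^32).

proc powModSmall(a, e, m: uint64): uint64 =
  ## Plain right-to-left square-and-multiply; requires m < 2^32
  result = 1 mod m
  var base = a mod m
  var k = e
  while k > 0:
    if (k and 1) == 1:
      result = result * base mod m
    base = base * base mod m
    k = k shr 1

proc invModPow2(q: uint64): uint64 =
  ## q^-1 mod 2^64 for odd q (Newton iteration); mask it to get q^-1 mod 2^k
  result = q                              # correct to 3 bits since q*q == 1 (mod 8)
  for _ in 0 ..< 5:                       # 3 -> 6 -> 12 -> 24 -> 48 -> 96 bits
    result = result * (2'u64 - q * result)

proc powModCrt(a, e, m: uint64): uint64 =
  ## a^e mod m for any 0 < m < 2^32, with m = 2^k * q and q odd
  let k = countTrailingZeroBits(m)
  let q = m shr k                         # odd part: Montgomery-friendly
  let mask = (1'u64 shl k) - 1            # reducing mod 2^k is just masking
  let x1 = powModSmall(a, e, q)           # real code: Montgomery exponentiation
  let x2 = powModSmall(a, e, 1'u64 shl k) # real code: wrapping multiply + mask
  # Garner: x = x1 + q * ((x2 - x1) * q^-1 mod 2^k). Wrapping uint64 arithmetic
  # is fine here because only the low k bits of h are kept.
  let h = ((x2 - x1) * invModPow2(q)) and mask
  result = x1 + q * h                     # < q + q*(2^k - 1) = m
```

When m is already odd, k = 0, the mask is 0 and the result is just x₁, so the same entry point handles both cases.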

Engineering & Implementation

Now that we have the theory, let's look at engineering problems and reference code.

Fast Montgomery multiplication

There are many ways to implement Montgomery multiplication on bigints; they were categorized in Acar's thesis: https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf

On modern architectures with MULX/ADCX/ADOX, CIOS (Coarsely Integrated Operand Scanning) is the fastest. Without MULX/ADCX/ADOX (so potentially on everything besides x86), FIPS (Finely Integrated Product Scanning) is the fastest. Why? AFAIK it has better data movement and fewer carries to save/restore.
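For reference, a portable CIOS sketch in Nim. Assumptions: little-endian 32-bit limbs in a `seq[uint32]`, so that a word product plus two carries fits in a uint64 (with 64-bit limbs you need a 128-bit product, which is what MULX provides), an odd modulus M, inputs a, b < M, and hypothetical names `montMulCios` and `negInvWord`. Each outer iteration multiplies by one word of b and immediately reduces by one word, which is the "coarse" integration of multiplication and reduction.

```nim
# CIOS Montgomery multiplication sketch with 32-bit limbs (hypothetical names).
type Limbs = seq[uint32]   # little-endian: limb 0 is least significant

const WordMask = 0xFFFF_FFFF'u64

proc negInvWord(m0: uint32): uint32 =
  ## -1/m0 mod 2^32 via Newton iteration (m0 must be odd)
  var inv = m0
  for _ in 0 ..< 4:
    inv = inv * (2'u32 - m0 * inv)
  0'u32 - inv

proc montMulCios(a, b, M: Limbs, m0ninv: uint32): Limbs =
  ## Returns a*b*R^-1 mod M with R = 2^(32*n); requires a, b < M and M odd
  let n = M.len
  var t = newSeq[uint32](n + 2)
  for i in 0 ..< n:
    # Multiplication step: t += a * b[i]
    var c = 0'u64
    for j in 0 ..< n:
      let s = uint64(t[j]) + uint64(a[j]) * uint64(b[i]) + c
      t[j] = uint32(s and WordMask)
      c = s shr 32
    var s = uint64(t[n]) + c
    t[n] = uint32(s and WordMask)
    t[n + 1] = uint32(s shr 32)
    # Reduction step: u makes t + u*M divisible by 2^32, then shift down one word
    let u = t[0] * m0ninv
    s = uint64(t[0]) + uint64(u) * uint64(M[0])   # low 32 bits are now zero
    c = s shr 32
    for j in 1 ..< n:
      s = uint64(t[j]) + uint64(u) * uint64(M[j]) + c
      t[j - 1] = uint32(s and WordMask)
      c = s shr 32
    s = uint64(t[n]) + c
    t[n - 1] = uint32(s and WordMask)
    t[n] = t[n + 1] + uint32(s shr 32)
  # Final conditional subtraction: t < 2M, so subtract M once if t >= M
  result = t[0 ..< n]
  var geq = t[n] != 0
  if not geq:
    geq = true                             # t == M also needs the subtraction
    for j in countdown(n - 1, 0):
      if result[j] != M[j]:
        geq = result[j] > M[j]
        break
  if geq:
    var borrow = 0'u64
    for j in 0 ..< n:
      let d = uint64(result[j]) - uint64(M[j]) - borrow   # wraps on underflow
      result[j] = uint32(d and WordMask)
      borrow = d shr 63                                   # 1 iff it wrapped
```

A FIPS variant reorganizes the same arithmetic by product column instead of by word of b; the result a·b·R⁻¹ mod M is identical.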

The no-assembly algorithm for both is available here:

Fast large multiplication

One issue with multiplication is that it's O(n²) in the number of words, since each word of one multiplicand must be multiplied by each word of the other.

The Karatsuba algorithm has a complexity of about O(n¹˙⁵⁸) (exactly O(n^log₂3)), with a constant-factor overhead that only becomes negligible at around 8~12 words (so 512-768 bits, to be measured): https://en.wikipedia.org/wiki/Karatsuba_algorithm
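A toy one-level Karatsuba in Nim to show where the saving comes from: a 32-bit × 32-bit product computed from 16-bit halves with three multiplications instead of four. The 16-bit "limbs" and the name `karatsuba32` are illustrative assumptions chosen so that every intermediate fits in a uint64; a bigint implementation would recurse on halves of a limb array.

```nim
# Toy one-level Karatsuba: 32-bit x 32-bit -> 64-bit from 16-bit halves.

proc karatsuba32(x, y: uint32): uint64 =
  ## x*y using 3 multiplications of (roughly) half-size operands instead of 4
  let x1 = uint64(x shr 16)
  let x0 = uint64(x and 0xFFFF'u32)
  let y1 = uint64(y shr 16)
  let y0 = uint64(y and 0xFFFF'u32)
  let z2 = x1 * y1                            # high * high
  let z0 = x0 * y0                            # low * low
  # (x1 + x0)*(y1 + y0) = z2 + (x1*y0 + x0*y1) + z0, so the middle term costs one product
  let z1 = (x1 + x0) * (y1 + y0) - z2 - z0
  result = (z2 shl 32) + (z1 shl 16) + z0
```

The extra additions and subtractions are the constant-factor overhead mentioned above, which is why Karatsuba only wins once the operands span enough words.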

Fast exponentiation

Now we can look into exponentiation. The basic algorithm is square-and-multiply https://en.wikipedia.org/wiki/Exponentiation_by_squaring

You scan the bits of the exponent, either left-to-right (MSB-to-LSB) or right-to-left (LSB-to-MSB); both variants are possible. At each bit you always square, and if the bit is set you also multiply (by the base when scanning left-to-right, or by the running square when scanning right-to-left).
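Both scanning directions, sketched in Nim on single-word operands. Assumptions: m < 2³² so products fit in a uint64, and `mulMod` is a hypothetical stand-in for the Montgomery multiplication a bigint implementation would use; all names are illustrative.

```nim
# Square-and-multiply sketches on single-word operands (m < 2^32).

proc mulMod(a, b, m: uint64): uint64 =
  ## Stand-in for montMul; requires a, b < m < 2^32 so a*b fits in a uint64
  a * b mod m

proc powModLtr(a, e, m: uint64): uint64 =
  ## Left-to-right (MSB to LSB): always square, multiply by the base on set bits
  let base = a mod m
  result = 1 mod m
  for i in countdown(63, 0):            # leading zero bits just square 1, harmlessly
    result = mulMod(result, result, m)
    if ((e shr i) and 1) == 1:
      result = mulMod(result, base, m)

proc powModRtl(a, e, m: uint64): uint64 =
  ## Right-to-left (LSB to MSB): multiply by the running square on set bits
  var sq = a mod m
  var k = e
  result = 1 mod m
  while k > 0:
    if (k and 1) == 1:
      result = mulMod(result, sq, m)
    sq = mulMod(sq, sq, m)
    k = k shr 1
```

Left-to-right always multiplies by the same fixed base, which is what lets the window method replace it with a lookup into precomputed powers of that base.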

However, you can precompute the powers corresponding to bit patterns such as 1101, 1001 or 0111 (i.e. a¹³, a⁹, a⁷) and, instead of doing 2 or 3 multiplications per 4 bits, do only 1 multiplication every 4 squarings.

That's called the window method.
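A minimal fixed-window sketch in Nim (variable-time, single-word operands with m < 2³², window size w dividing 64; `powModFixedWindow` is a hypothetical name). It precomputes a⁰ … a^(2ʷ − 1), then consumes the exponent w bits at a time from the top: w squarings followed by at most one table multiplication per window.

```nim
# Fixed-window modular exponentiation sketch (variable-time, hypothetical name).
# Assumes m < 2^32 so products fit in a uint64, and w divides 64.

proc powModFixedWindow(a, e, m: uint64, w = 4): uint64 =
  doAssert w >= 1 and 64 mod w == 0
  # Precompute table[i] = a^i mod m for every w-bit pattern i
  var table = newSeq[uint64](1 shl w)
  table[0] = 1 mod m
  let base = a mod m
  for i in 1 ..< table.len:
    table[i] = table[i - 1] * base mod m
  result = 1 mod m
  let mask = (1'u64 shl w) - 1
  var shift = 64 - w
  while shift >= 0:
    # w squarings shift the accumulated exponent left by one window ...
    for _ in 0 ..< w:
      result = result * result mod m
    # ... then a single multiplication by the precomputed power
    let win = (e shr shift) and mask
    if win != 0:                    # variable-time shortcut: skip multiplying by a^0
      result = result * table[int(win)] mod m
    shift -= w
```

A constant-time version would instead always multiply, reading the table entry with a constant-time selection rather than an index derived from secret bits.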

An example of the fixed-window method is available in Constantine: https://github.com/mratsim/constantine/blob/1c5341f/constantine/math/arithmetic/limbs_montgomery.nim#L614-L836. As there is no constant-time requirement in Stint, it can be simplified; in Constantine I needed to ensure that I could use the value of secret bits to select windows (RSA, for example, uses modular exponentiation with a secret exponent) without revealing what those secret bits are.

Sliding window

An extra optimization is to use windows of variable size, called sliding windows. I don't have an implementation of that (it cannot be made constant-time), but Wikipedia has pseudocode: https://en.wikipedia.org/wiki/Exponentiation_by_squaring#Sliding-window_method
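A variable-time Nim sketch following the structure of that pseudocode (single-word operands with m < 2³², 1 ≤ w ≤ 8; `powModSliding` is a hypothetical name). Only the odd powers a¹, a³, …, a^(2ʷ − 1) are precomputed, half the table of a fixed window of the same size, and each window is chosen to end on a set bit, so runs of zeros cost squarings only.

```nim
import std/bitops

# Sliding-window modular exponentiation sketch (variable-time, single word, m < 2^32).

proc powModSliding(a, e, m: uint64, w = 4): uint64 =
  ## a^e mod m; windows have at most w bits (1 <= w <= 8) and always end on a set bit
  result = 1 mod m
  if e == 0:
    return
  # Odd powers only: tbl[k] = a^(2k+1) mod m
  let base = a mod m
  let base2 = base * base mod m
  var tbl = newSeq[uint64](1 shl (w - 1))
  tbl[0] = base
  for k in 1 ..< tbl.len:
    tbl[k] = tbl[k - 1] * base2 mod m
  var i = 63 - countLeadingZeroBits(e)   # position of the most significant set bit
  while i >= 0:
    if ((e shr i) and 1) == 0:
      result = result * result mod m     # zero bit: square only, no multiplication
      dec i
    else:
      # Lowest set bit j such that bits i..j span at most w bits
      var j = max(i - w + 1, 0)
      while ((e shr j) and 1) == 0:
        inc j
      let win = (e shr j) and ((1'u64 shl (i - j + 1)) - 1)   # odd window value
      for _ in 0 .. (i - j):
        result = result * result mod m   # one squaring per bit in the window
      result = result * tbl[int(win shr 1)] mod m              # a^win = tbl[(win-1)/2]
      i = j - 1
```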

NAF/signed recoding

In your research you might come across NAF (non-adjacent form) or signed recoding. This only applies to elliptic curves: signed digits require the inverse of the base, and the inversion 1/a (mod m) is not cheap in modular arithmetic, whereas -P is cheap on an elliptic curve.

Montgomery domain, conversion and constants

Computing R (mod M) and R² (mod M) can be done like this: https://github.com/mratsim/constantine/blob/1c5341f/constantine/math/config/precompute.nim#L307-L350
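One division-free way to obtain both constants, sketched in Nim (not necessarily what the linked Constantine code does): start from 1 and double modulo M, with one conditional subtraction per doubling. With n 64-bit limbs, 64·n doublings give R mod M and 128·n give R² mod M. The name `powerOfTwoModM` and the little-endian `seq[uint64]` layout are assumptions.

```nim
# R mod M and R^2 mod M by repeated modular doubling (little-endian 64-bit limbs).

proc powerOfTwoModM(M: seq[uint64], k: int): seq[uint64] =
  ## 2^k mod M, assuming M > 1; uses only shifts, comparisons and subtractions
  let n = M.len
  result = newSeq[uint64](n)
  result[0] = 1                           # 2^0 = 1, already reduced since M > 1
  for _ in 0 ..< k:
    # Double: shift left by one bit across limbs, keeping the bit that falls off the top
    var carry = 0'u64
    for j in 0 ..< n:
      let top = result[j] shr 63
      result[j] = (result[j] shl 1) or carry
      carry = top
    # r < M before doubling, so 2r < 2M and one subtraction of M suffices
    var geq = carry == 1
    if not geq:
      geq = true                          # r == M also needs the subtraction
      for j in countdown(n - 1, 0):
        if result[j] != M[j]:
          geq = result[j] > M[j]
          break
    if geq:
      var borrow = 0'u64
      for j in 0 ..< n:
        let d1 = result[j] - M[j]
        let b1 = result[j] < M[j]
        let d2 = d1 - borrow
        let b2 = d1 < borrow
        result[j] = d2
        borrow = if b1 or b2: 1'u64 else: 0'u64
```

For a 4-limb (256-bit) modulus that is `powerOfTwoModM(M, 256)` for R and `powerOfTwoModM(M, 512)` for R²; the loop is slow compared to a montMul but only runs once per modulus.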

You will also need 1/M0 (mod 2⁶⁴), with M0 being the first (least significant) limb of your modulus M.
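It can be computed without any division using Newton's iteration x ← x·(2 − M0·x), which doubles the number of correct low bits at each step; since M0 is odd, x = M0 is already correct modulo 2³. A Nim sketch with hypothetical names, where wrapping uint64 arithmetic performs the reduction mod 2⁶⁴ for free. CIOS-style reduction usually consumes the negated value −1/M0 (mod 2⁶⁴), hence the second helper.

```nim
# Inverse of the lowest modulus limb modulo 2^64 (hypothetical helper names).

proc invModWord(m0: uint64): uint64 =
  ## 1/m0 mod 2^64 for odd m0
  doAssert (m0 and 1) == 1, "Montgomery needs an odd modulus"
  result = m0                                # m0*m0 == 1 (mod 8): 3 correct bits
  for _ in 0 ..< 5:                          # 3 -> 6 -> 12 -> 24 -> 48 -> 96 >= 64
    result = result * (2'u64 - m0 * result)  # uint64 wraps, i.e. works mod 2^64

proc negInvModWord(m0: uint64): uint64 =
  ## -1/m0 mod 2^64, the form a CIOS reduction step multiplies by
  0'u64 - invModWord(m0)
```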

Wrapup

The implementation steps should be:

and use Clang

mratsim commented 1 year ago

I have a local implementation based on Constantine (not fuzzed yet, only basic testing).

Performance improvement over Stint: from 53.7x to 82.9x on 256-bit inputs.

bench code

import
  ../constantine/math/arithmetic,
  ../constantine/math/io/[io_bigints, io_fields],
  ../constantine/math_arbitrary_precision/arithmetic/[bigints_views, limbs_views, limbs_montgomery, limbs_mod2k],
  ../constantine/math/config/[type_bigint, curves, precompute],
  ../constantine/platforms/[abstractions, codecs],
  ../helpers/prng_unsafe,
  std/[times, monotimes, strformat]

import stint

# let M = Mod(BN254_Snarks)
const bits = 256
const expBits = bits # Stint only supports same size args

var rng: RngState
rng.seed(1234)

for i in 0 ..< 5:
  echo "\ni: ", i
  # -------------------------
  let M = rng.random_long01Seq(BigInt[bits])
  let a = rng.random_long01Seq(BigInt[bits])

  var exponent = newSeq[byte](expBits div 8)
  for i in 0 ..< expBits div 8:
    exponent[i] = byte rng.next()

  # -------------------------

  let aHex = a.toHex()
  let eHex = exponent.toHex()
  let mHex = M.toHex()

  echo "  base:     ", a.toHex()
  echo "  exponent: ", exponent.toHex()
  echo "  modulus:  ", M.toHex()

  # -------------------------

  var elapsedCtt, elapsedStint: int64

  block:
    var r: BigInt[bits]
    let start = getMonotime()
    r.limbs.powMod_vartime(a.limbs, exponent, M.limbs, window = 4)
    let stop = getMonotime()

    elapsedCtt = inNanoseconds(stop-start)

    echo "  r Constantine:       ", r.toHex()
    echo "  elapsed Constantine: ", elapsedCtt, " ns"

  # -------------------------

  block:
    let aa = Stuint[bits].fromHex(aHex)
    let ee = Stuint[expBits].fromHex(eHex)
    let mm = Stuint[bits].fromHex(mHex)

    var r: Stuint[bits]
    let start = getMonotime()
    r = powmod(aa, ee, mm)
    let stop = getMonotime()

    elapsedStint = inNanoseconds(stop-start)

    echo "  r stint:             ", r.toHex()
    echo "  elapsed Stint:       ", elapsedStint, " ns"

  echo &"\n  ratio Stint/Constantine: {float64(elapsedStint)/float64(elapsedCtt):.3f}x"
  echo "---------------------------------------------------------"
mratsim commented 1 year ago

I redid the benchmark against GMP: Constantine is within about ±10% of GMP, sometimes slower and sometimes faster (without assembly!).

bench

import
  ../constantine/math/arithmetic,
  ../constantine/math/io/[io_bigints, io_fields],
  ../constantine/math_arbitrary_precision/arithmetic/[bigints_views, limbs_views, limbs_montgomery, limbs_mod2k],
  ../constantine/math/config/[type_bigint, curves, precompute],
  ../constantine/platforms/[abstractions, codecs],
  ../helpers/prng_unsafe,
  std/[times, monotimes, strformat]

import stint, gmp

const # https://gmplib.org/manual/Integer-Import-and-Export.html
  GMP_WordLittleEndian = -1'i32
  GMP_WordNativeEndian = 0'i32
  GMP_WordBigEndian = 1'i32

  GMP_MostSignificantWordFirst = 1'i32
  GMP_LeastSignificantWordFirst = -1'i32

# let M = Mod(BN254_Snarks)
const bits = 256
const expBits = bits # Stint only supports same size args

var rng: RngState
rng.seed(1234)

for i in 0 ..< 5:
  echo "i: ", i
  # -------------------------
  let M = rng.random_long01Seq(BigInt[bits])
  let a = rng.random_long01Seq(BigInt[bits])

  var exponent = newSeq[byte](expBits div 8)
  for i in 0 ..< expBits div 8:
    exponent[i] = byte rng.next()

  # -------------------------

  let aHex = a.toHex()
  let eHex = exponent.toHex()
  let mHex = M.toHex()

  echo "  base:     ", a.toHex()
  echo "  exponent: ", exponent.toHex()
  echo "  modulus:  ", M.toHex()

  # -------------------------

  var elapsedCtt, elapsedStint, elapsedGMP: int64

  block:
    var r: BigInt[bits]
    let start = getMonotime()
    r.limbs.powMod_vartime(a.limbs, exponent, M.limbs, window = 4)
    let stop = getMonotime()

    elapsedCtt = inNanoseconds(stop-start)

    echo "  r Constantine:       ", r.toHex()
    echo "  elapsed Constantine: ", elapsedCtt, " ns"

  # -------------------------

  block:
    let aa = Stuint[bits].fromHex(aHex)
    let ee = Stuint[expBits].fromHex(eHex)
    let mm = Stuint[bits].fromHex(mHex)

    var r: Stuint[bits]
    let start = getMonotime()
    r = powmod(aa, ee, mm)
    let stop = getMonotime()

    elapsedStint = inNanoseconds(stop-start)

    echo "  r stint:             ", r.toHex()
    echo "  elapsed Stint:       ", elapsedStint, " ns"

  block:
    var aa, ee, mm, rr: mpz_t
    mpz_init(aa)
    mpz_init(ee)
    mpz_init(mm)
    mpz_init(rr)

    aa.mpz_import(a.limbs.len, GMP_LeastSignificantWordFirst, sizeof(SecretWord), GMP_WordNativeEndian, 0, a.limbs[0].unsafeAddr)
    let e = BigInt[expBits].unmarshal(exponent, bigEndian)
    ee.mpz_import(e.limbs.len, GMP_LeastSignificantWordFirst, sizeof(SecretWord), GMP_WordNativeEndian, 0, e.limbs[0].unsafeAddr)
    mm.mpz_import(M.limbs.len, GMP_LeastSignificantWordFirst, sizeof(SecretWord), GMP_WordNativeEndian, 0, M.limbs[0].unsafeAddr)

    let start = getMonotime()
    rr.mpz_powm(aa, ee, mm)
    let stop = getMonotime()

    elapsedGMP = inNanoSeconds(stop-start)

    var r: BigInt[bits]
    var rWritten: csize
    discard r.limbs[0].addr.mpz_export(rWritten.addr, GMP_LeastSignificantWordFirst, sizeof(SecretWord), GMP_WordNativeEndian, 0, rr)

    echo "  r GMP:               ", r.toHex()
    echo "  elapsed GMP:         ", elapsedGMP, " ns"

  echo &"\n  ratio Stint/Constantine: {float64(elapsedStint)/float64(elapsedCtt):.3f}x"
  echo &"  ratio GMP/Constantine: {float64(elapsedGMP)/float64(elapsedCtt):.3f}x"
  echo "---------------------------------------------------------"