crystal-lang / crystal

The Crystal Programming Language
https://crystal-lang.org
Apache License 2.0

Add self.powmod(e, m) to do efficient modular exponentiation (self ** e) % m #12772

Open jzakiya opened 1 year ago

jzakiya commented 1 year ago

It would be nice to have an efficient, fast implementation that works for all Int types.

It would be really nice if this could be n.powmod(e, m), but I will settle for powmod(n, e, m). Either way, it should compute (n ** e) % m and return a value in 0..m-1 of the same type as m (typeof(m)).

The standard algorithm (square-and-multiply):

def powmod(b, e, m)
  return 0 if m == 1
  r = 1; b = b % m
  while e > 0
    r = (r * b) % m if e.odd?
    b = (b * b) % m
    e >>= 1
  end
  r
end
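As an illustration of what the loop computes (the inputs here are a made-up example, not values from the thread), this is a hand trace of 4^13 mod 497. The exponent 13 is 1101 in binary, so the result is the product of the repeated squares of b that correspond to its set bits, with every intermediate value reduced mod 497:

```latex
\begin{aligned}
13 &= 1101_2 = 1 + 4 + 8 \\
4^{2} &\equiv 16, \quad 4^{4} \equiv 256, \quad 4^{8} \equiv 256^{2} = 65536 \equiv 429 \pmod{497} \\
4^{13} &= 4^{1} \cdot 4^{4} \cdot 4^{8} \equiv (4 \cdot 256) \cdot 429 \equiv 30 \cdot 429 = 12870 \equiv 445 \pmod{497}
\end{aligned}
```

The loop needs only four iterations, one per bit of e, rather than twelve successive multiplications by 4.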
HertzDevil commented 1 year ago

Related: #7516, #8612 (these only deal with BigInt)

jzakiya commented 1 year ago

I've run some initial benchmarks for powmod, and the algorithm above works. Ideally you want to use the smallest Int type possible to maximize arithmetic speed. So if a BigInt value (b.to_big_i) fits into a UInt32 or UInt64, you'd want to convert it internally for speed and, if the modulus is a BigInt, convert the answer back to BigInt.
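A minimal sketch of that narrowing idea, under assumptions of our own: the hypothetical `powmod_big` below takes BigInt arguments and, when the modulus is at most UInt32::MAX and the exponent fits in a UInt64, runs the loop on UInt64 values. That cannot overflow, because every residue is below 2^32 and so every product of two residues is below 2^64. Only the final result is converted back with `to_big_i`; larger moduli fall back to plain BigInt arithmetic. The method name and the UInt32 cutoff are illustrative choices, not an existing Crystal API.

```crystal
require "big"

# Hypothetical: (b ** e) % m for BigInt arguments, narrowing to primitive
# integers when the modulus is small enough.
def powmod_big(b : BigInt, e : BigInt, m : BigInt) : BigInt
  raise ArgumentError.new("modulus must be positive") unless m > 0
  raise ArgumentError.new("exponent must be non-negative") if e < 0
  return BigInt.new(0) if m == 1

  if m <= UInt32::MAX && e <= UInt64::MAX
    # Fast path: residues are < 2**32, so every product fits in a UInt64.
    mod = m.to_u64
    base = (b % m).to_u64 # BigInt#% is floored, so this lies in 0...m
    exponent = e.to_u64
    result = 1_u64
    while exponent > 0
      result = (result * base) % mod if exponent.odd?
      base = (base * base) % mod
      exponent >>= 1
    end
    result.to_big_i # hand back a BigInt, the same type as m
  else
    # Slow path: stay in BigInt, which cannot overflow.
    big_base = b % m
    big_exp = e
    big_result = BigInt.new(1)
    while big_exp > 0
      big_result = (big_result * big_base) % m if big_exp.odd?
      big_base = (big_base * big_base) % m
      big_exp >>= 1
    end
    big_result
  end
end
```

The cutoff could be raised to UInt64 moduli by doing the multiplications in UInt128, at the cost of slower 128-bit division.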

At the time of writing (ATOW), https://github.com/crystal-lang/crystal/pull/12773 is still not merged; once it is, arithmetic between all Int types should be possible.

Meanwhile, this code works with the values from https://github.com/crystal-lang/crystal/issues/8612:

require "big"

def powmod(b, e, m)
  return typeof(m).zero if m == 1
  r = typeof(m).new(1); b %= m
  while e > 0
    r = (r &* b) % m if e.odd?
    e >>= 1
    b = (b &* b) % m
  end
  r
end

b = "53583115773616729421957814870755484980404298242901134400501331255090818409243".to_big_i
e = "28948022309329048855892746252171976963317496166410141009864396001977208667916".to_big_i
m = "115792089237316195423570985008687907853269984665640564039457584007908834671663".to_big_i

puts "#{(r = powmod(b, e, m))}|#{r.class}"
#=> 75711134420273723792089656449854389054866833762486990555172221523628676983696|BigInt

Again, ideally I'd like this to be a method on Ints so I can do n.powmod(e, m) and chain methods on it.

jzakiya commented 1 year ago

Here are some benchmark examples.

require "big"
require "benchmark"

module PowMod
  def powmod(e, m)
    return typeof(m).zero if m == 1
    r = typeof(m).new(1); b = self % m
    while e > 0
      r = (r &* b) % m if e.odd?
      e >>= 1
      b = (b &* b) % m
    end
    r
  end

  def powmodx(e, m)
    return typeof(m).zero if m == 1
    e = e.to_u32 if e < UInt32::MAX
    r = typeof(m).new(1); b = self % m
    while e > 0
      r = (r &* b) % m if e.odd?
      e >>= 1
      b = (b &* b) % m
    end
    r
  end
end

struct Int; include PowMod end

def benchtest(n, e, m)
  Benchmark.ips do |b|
    b.report("powmod  ") { n.powmod  e, m }
    b.report("powmodx ") { n.powmodx e, m }
    #puts
  end
end

def powmodtest(n, e, m)
  puts "n = #{n}|#{n.class}, e = #{e}|#{e.class}, m = #{m}|#{m.class}"
  puts "powmod  = #{x = n.powmod(e, m)}|#{x.class}"
  puts "powmodx = #{x = n.powmodx(e, m)}|#{x.class}"
end

puts
n, e, m = 5, 23223, 46447
powmodtest n, e, m
benchtest n, e, m

puts
n, e, m = 5, 23223.to_big_i, 46447
powmodtest n, e, m
benchtest n, e, m

puts
n, e, m = 12u64, 321, 872321u64
powmodtest n, e, m
benchtest n, e, m
puts
n, e, m = 12, 321.to_big_i, 872321
powmodtest n, e, m
benchtest n, e, m
puts
n, e, m = 12.to_big_i, 321.to_big_i, 872321.to_big_i
powmodtest n, e, m
benchtest n, e, m
puts
n, e, m = 912u64, 6321, 1872321u64
powmodtest n, e, m
benchtest n, e, m
puts
n, e, m = 9125u64, 6543, 643327u64 # works when m = 643327 >= U|Int64
powmodtest n, e, m
benchtest n, e, m
puts
n, e, m = 9125u64, 6543, 327u128
powmodtest n, e, m
benchtest n, e, m
puts
n, e, m = 4842, 45720, 156u32
powmodtest n, e, m
benchtest n, e, m
puts
n, e, m = 4842, 45720, 1
powmodtest n, e, m
benchtest n, e, m
puts
n, e, m = 4842, 45720, 2
powmodtest n, e, m
benchtest n, e, m

Results.

n = 5|Int32, e = 23223|Int32, m = 46447|Int32
powmod  = 46446|Int32
powmodx = 46446|Int32
powmod    24.42M ( 40.96ns) (± 0.27%)  0.0B/op   1.00× slower
powmodx   24.48M ( 40.84ns) (± 0.18%)  0.0B/op        fastest

n = 5|Int32, e = 23223|BigInt, m = 46447|Int32
powmod  = 46446|Int32
powmodx = 46446|Int32
powmod     2.32M (430.30ns) (± 0.85%)  481B/op   4.25× slower
powmodx    9.87M (101.32ns) (± 0.14%)  0.0B/op        fastest

n = 12|UInt64, e = 321|Int32, m = 872321|UInt64
powmod  = 705784|UInt64
powmodx = 705784|UInt64
powmod   850.61M (  1.18ns) (± 2.11%)  0.0B/op        fastest
powmodx  737.29M (  1.36ns) (± 1.77%)  0.0B/op   1.15× slower

n = 12|Int32, e = 321|BigInt, m = 872321|Int32
powmod  = 337673|Int32
powmodx = 337673|Int32
powmod     4.19M (238.57ns) (± 0.36%)  288B/op   3.74× slower
powmodx   15.69M ( 63.74ns) (± 0.21%)  0.0B/op        fastest

n = 12|BigInt, e = 321|BigInt, m = 872321|BigInt
powmod  = 705784|BigInt
powmodx = 705784|BigInt
powmod     1.25M (798.76ns) (± 0.19%)  897B/op   1.27× slower
powmodx    1.59M (628.17ns) (± 0.22%)  609B/op        fastest

n = 912|UInt64, e = 6321|Int32, m = 1872321|UInt64
powmod  = 215148|UInt64
powmodx = 215148|UInt64
powmod   864.83M (  1.16ns) (± 1.80%)  0.0B/op        fastest
powmodx  737.41M (  1.36ns) (± 1.65%)  0.0B/op   1.17× slower

n = 9125|UInt64, e = 6543|Int32, m = 643327|UInt64
powmod  = 570959|UInt64
powmodx = 570959|UInt64
powmod   873.84M (  1.14ns) (± 1.81%)  0.0B/op        fastest
powmodx  737.05M (  1.36ns) (± 1.64%)  0.0B/op   1.19× slower

n = 9125|UInt64, e = 6543|Int32, m = 327|UInt128
powmod  = 263|UInt128
powmodx = 263|UInt128
powmod   907.78M (  1.10ns) (± 2.48%)  0.0B/op        fastest
powmodx  852.35M (  1.17ns) (± 3.38%)  0.0B/op   1.07× slower

n = 4842|Int32, e = 45720|Int32, m = 156|UInt32
powmod  = 144|UInt32
powmodx = 144|UInt32
powmod    35.50M ( 28.17ns) (± 7.41%)  0.0B/op   1.02× slower
powmodx   36.32M ( 27.53ns) (± 6.53%)  0.0B/op        fastest

n = 4842|Int32, e = 45720|Int32, m = 1|Int32
powmod  = 0|Int32
powmodx = 0|Int32
powmod   634.97M (  1.57ns) (± 1.68%)  0.0B/op   1.00× slower
powmodx  635.39M (  1.57ns) (± 1.68%)  0.0B/op        fastest

n = 4842|Int32, e = 45720|Int32, m = 2|Int32
powmod  = 0|Int32
powmodx = 0|Int32
powmod    26.96M ( 37.09ns) (± 0.40%)  0.0B/op        fastest
powmodx   26.56M ( 37.64ns) (± 0.18%)  0.0B/op   1.01× slower
HertzDevil commented 1 year ago

The code given in https://github.com/crystal-lang/crystal/issues/12772#issuecomment-1327772578 does not work for general primitive integers due to the issue pointed out in https://github.com/crystal-lang/crystal/issues/13244#issuecomment-1492498676.
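The overflow in question is easy to reproduce with the modulus from the first benchmark, m = 46447. A residue can be as large as 46446, and 46446 × 46446 = 2,157,230,916 exceeds Int32::MAX (2,147,483,647), so `&*` silently wraps before `%` is applied. A minimal demonstration (the variable names are illustrative):

```crystal
m = 46447
x = 46446 # largest residue mod m (it is -1 mod m); an Int32 literal

wrapped = (x &* x) % m              # Int32 product wraps past Int32::MAX
widened = (x.to_i64 * x.to_i64) % m # Int64 product is exact

puts wrapped # some value other than 1
puts widened # => 1, because (-1) * (-1) == 1 (mod m)
```

With plain `*` instead of `&*`, Crystal would raise an OverflowError here instead of returning a wrong remainder. One fix is to widen the operands before multiplying, as the Int64 line does.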

GMP also supports modular inverses by using a negative exponent. If we want that for primitive integers as well, we will need additional algorithms (such as #9848).

jzakiya commented 1 year ago

FYI, Ruby has two ways to do modular exponentiation in its standard library.

This Stack Overflow post shows both: https://stackoverflow.com/questions/14785329/efficient-way-to-power-and-mod-in-ruby

And here are the docs for the pow method: https://ruby-doc.org/3.2.2/Integer.html#method-i-pow

I benchmarked both in Ruby versions of the Miller-Rabin (MR) primality test: the pow version is faster across the board, and you don't need to require OpenSSL to use it. Note that pow has two forms, a.pow(b) and a.pow(b, m); its source code may be instructive. With the recent 3.2.2 release, Ruby has gotten seriously faster!
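Since Miller-Rabin is the motivating workload, here is a sketch of where powmod sits inside it. It assumes 32-bit inputs, for which the witness set {2, 7, 61} is known to make the test deterministic (it is correct for every n < 4,759,123,141), and it keeps all residues below 2^32 so that products fit in a UInt64. The function names are hypothetical; this is neither Ruby's nor Crystal's code.

```crystal
# Hypothetical helper: (b ** e) % m for moduli below 2**32. Residues then
# fit in 32 bits, so every product fits in a UInt64 without overflowing.
def powmod_small(b : UInt64, e : UInt64, m : UInt64) : UInt64
  result = 1_u64 % m
  b %= m
  while e > 0
    result = (result * b) % m if e.odd?
    b = (b * b) % m
    e >>= 1
  end
  result
end

# Deterministic Miller-Rabin for 32-bit n, using witnesses 2, 7 and 61.
def prime_u32?(n : UInt32) : Bool
  return false if n < 2
  # Trial division by a few small primes (including the witnesses).
  {2_u32, 3_u32, 5_u32, 7_u32, 61_u32}.each do |p|
    return n == p if n % p == 0
  end

  nn = n.to_u64
  # Write n - 1 = d * 2**s with d odd.
  d = nn - 1
  s = 0
  while d.even?
    d >>= 1
    s += 1
  end

  {2_u64, 7_u64, 61_u64}.each do |a|
    x = powmod_small(a, d, nn)
    next if x == 1 || x == nn - 1
    # Square up to s - 1 more times, looking for n - 1.
    witness = true
    (s - 1).times do
      x = (x * x) % nn
      if x == nn - 1
        witness = false
        break
      end
    end
    return false if witness # a proves that n is composite
  end
  true
end
```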

So whatever API you choose for this functionality, it really needs to be a standard method, because modular exponentiation is ubiquitous in many fields of numerical programming.

HertzDevil commented 1 year ago

This is Ruby's implementation of the 2-argument #pow. A lot of the code paths invoke Ruby's own arbitrary-precision integer functions, even when the values are small enough. If primitive integers with double the size of long are available (DLONG) then #pow will use them to avoid overflows, similar to what I did in that comment. Ruby also doesn't support modular inverses.
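The same double-width trick can be sketched in Crystal, assuming UInt64 operands and using UInt128 as the "DLONG" type. Because both factors are below m ≤ UInt64::MAX, each product is below 2^128 and the reduction is exact. The name `powmod_wide` is hypothetical; for UInt128 moduli there is no wider primitive type to widen into, so a different approach (such as BigInt) would be needed there.

```crystal
# Hypothetical sketch: (b ** e) % m for UInt64 operands, widening each
# multiplication to UInt128 (the "double-width" type) so it cannot overflow.
def powmod_wide(b : UInt64, e : UInt64, m : UInt64) : UInt64
  raise ArgumentError.new("modulus must be positive") if m == 0
  mod = m.to_u128
  base = b.to_u128 % mod
  result = 1_u128 % mod # 0 when m == 1
  while e > 0
    # Both factors are < m <= UInt64::MAX, so each product is < 2**128.
    result = (result * base) % mod if e.odd?
    base = (base * base) % mod
    e >>= 1
  end
  result.to_u64 # result < m, so narrowing back is safe
end
```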

jzakiya commented 1 year ago

Here's the Crystal implementation of modinv that I use in my prime sieve code. I took it from Rosetta Code: https://rosettacode.org/wiki/Modular_inverse

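# Compute the modular inverse a^-1 modulo m, i.e. a * (a^-1) mod m == 1.
# Assumes a0 > 0, m0 > 1 and gcd(a0, m0) == 1; if the gcd is larger, the
# loop eventually raises DivisionByZeroError.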
def modinv(a0, m0)
  return 1 if m0 == 1
  a, m = a0, m0
  x0, inv = 0, 1
  while a > 1
    inv -= (a // m) * x0
    a, m = m, a % m
    x0, inv = inv, x0
  end
  inv += m0 if inv < 0
  inv
end


jzakiya commented 1 year ago

I ran Ruby vs Crystal tests yesterday and Ruby 3.2.2 slaughters Crystal 1.7.3.

https://forum.crystal-lang.org/t/ruby-out-performs-crystal-significantly-on-this-numerical-algorithm/5538