kzrnm / ac-library-csharp

42 stars 5 forks source link

ModInt を実装 #46

Closed terry-u16 closed 4 years ago

terry-u16 commented 4 years ago

ModInt を実装しました。 #12

C++版に合わせ、コンパイル時に mod が決定するもの(StaticModInt<T>)と実行時にmodが決定するもの(DynamicModInt<T>)を作成しています。 DynamicModInt<T>はコンパイル時定数を用いた最適化ができない代わりに、乗算時にBarrett reductionを用いて剰余演算の高速化がなされています。

StaticModInt<T>については、 #12 の通りIStaticModインターフェースを実装した構造体を型引数として受け取ることで、コンパイル時定数を利用した剰余演算の最適化を可能としています。 よく使われる mod として、Mod1000000007およびMod998244353をデフォルトで用意しています。

DynamicModInt<T>の型引数の仕様については少々迷ったのですが、C++版に寄せて空のIDynamicModIDインターフェースを実装した構造体をIDとして渡し、mod別に型識別されるようにしています。 こちらもModID0, ModID1, ModID2をデフォルトで用意しています。

インターフェース定義等

```C#
/// <summary>
/// コンパイル時に決定する mod を表します。
/// </summary>
/// <example>
/// <code>
/// public readonly struct Mod1000000009 : IStaticMod
/// {
///     public uint Mod => 1000000009;
///     public bool IsPrime => true;
/// }
/// </code>
/// </example>
public interface IStaticMod
{
    /// <summary>
    /// mod を取得します。
    /// </summary>
    uint Mod { get; }

    /// <summary>
    /// mod が素数であるか識別します。
    /// </summary>
    bool IsPrime { get; }
}

public readonly struct Mod1000000007 : IStaticMod
{
    public uint Mod => 1000000007;
    public bool IsPrime => true;
}

public readonly struct Mod998244353 : IStaticMod
{
    public uint Mod => 998244353;
    public bool IsPrime => true;
}

/// <summary>
/// 実行時に決定する mod の ID を表します。
/// </summary>
/// <example>
/// <code>
/// public readonly struct ModID123 : IDynamicModID { }
/// </code>
/// </example>
public interface IDynamicModID { }

public readonly struct ModID0 : IDynamicModID { }
public readonly struct ModID1 : IDynamicModID { }
public readonly struct ModID2 : IDynamicModID { }
```

動作確認は以下で行いました。ケースバイケースだとは思いますが、StaticModInt<T>の方が若干速いようです(.NET Core Runtime 起動のオーバーヘッドを差し引くと3割程度高速?)。

- StaticModInt<T>版(212ms): https://atcoder.jp/contests/exawizards2019/submissions/16742366
- DynamicModInt<T>版(264ms): https://atcoder.jp/contests/exawizards2019/submissions/16742379

また雑にですが、各種演算についてテストケースを作成し、テストが通ることを確認しています。

テストケース ### `StaticModInt`版 ```C# using AtCoder; using Xunit; using ModInt = AtCoder.StaticModInt<AtCoder.Mod1000000007>; namespace ModIntTest { public class StaticModIntTest { const int Mod = 1000000007; const int Seed = 42; const int N = 10000000; [Fact] public void ConstructorTest() { var rand = new XorShift(Seed); ConstructorSubTest(0); ConstructorSubTest(-1); ConstructorSubTest(Mod); ConstructorSubTest(-Mod); for (int i = 0; i < N; i++) { ConstructorSubTest((long)rand.Next()); } } private static void ConstructorSubTest(long x) { var m = new ModInt(x); x %= Mod; if (x < 0) { x += Mod; } Assert.Equal(x, m.Value); } [Fact] public void RawTest() { var rand = new XorShift(Seed); RawSubTest(0); RawSubTest(Mod - 1); for (int i = 0; i < N; i++) { RawSubTest((int)(rand.Next() % Mod)); } } private static void RawSubTest(int x) { var m = ModInt.Raw(x); Assert.Equal(x, m.Value); } [Fact] public void IncrementTest() { const long init = Mod - N + 5; var m = new ModInt(init); for (int i = 0; i < N; i++) { var expected = (init + i) % Mod; Assert.Equal(expected, m++); } } [Fact] public void DecrementTest() { const long init = Mod + N - 5; var m = new ModInt(init); for (int i = 0; i < N; i++) { var expected = (init - i) % Mod; Assert.Equal(expected, m--); } } [Fact] public void AddTest() { var rand = new XorShift(Seed); const long max = 1L << 60; for (int i = 0; i < N; i++) { var a = (long)(rand.Next() % (max >> 1) - max); var b = (long)(rand.Next() % (max >> 1) - max); var ma = new ModInt(a); var mb = new ModInt(b); var expected = (a + b) % Mod; if (expected < 0) { expected += Mod; } var actual = ma + mb; Assert.Equal(expected, actual); } } [Fact] public void SubtractTest() { var rand = new XorShift(Seed); const long max = 1L << 60; for (int i = 0; i < N; i++) { var a = (long)(rand.Next() % (max >> 1)) - max; var b = (long)(rand.Next() % (max >> 1)) - max; var ma = new ModInt(a); var mb = new ModInt(b); var expected = (a - b) % Mod; if (expected < 0) { expected += Mod; } var actual = ma - mb; 
Assert.Equal(expected, actual); } } [Fact] public void MultiplicationTest() { var rand = new XorShift(Seed); const int max = 1 << 30; for (int i = 0; i < N; i++) { var a = (int)(rand.Next() % (max >> 1)) - max; var b = (int)(rand.Next() % (max >> 1)) - max; var ma = new ModInt(a); var mb = new ModInt(b); var expected = (long)a * b % Mod; if (expected < 0) { expected += Mod; } var actual = ma * mb; Assert.Equal(expected, actual); } } [Fact] public void DivisionTest() { var rand = new XorShift(Seed); for (int i = 0; i < N; i++) { var a = rand.Next(Mod); var b = rand.Next(Mod); var divided = new ModInt(a) / b; var actual = ((long)divided.Value * b) % Mod; Assert.Equal(a, actual); } } [Theory] [InlineData(1, 1)] [InlineData(2, 8)] [InlineData(4, 4)] [InlineData(7, 13)] [InlineData(8, 2)] [InlineData(11, 11)] [InlineData(13, 7)] [InlineData(14, 14)] public void DivisionNotPrimeTest(int input, int expected) { var m = StaticModInt<Mod15>.Raw(input); var actual = 1 / m; Assert.Equal(expected, actual); } struct Mod15 : IStaticMod { public uint Mod => 15; public bool IsPrime => false; } [Fact] public void NegateTest() { var rand = new XorShift(Seed); for (int i = 0; i < N; i++) { var m = rand.Next(Mod); var actual = +m + -m; Assert.Equal(0, actual); } } [Fact] public void PowTest() { var rand = new XorShift(Seed); Assert.Equal(1, ModInt.Raw(100).Pow(0)); Assert.Equal(100, ModInt.Raw(100).Pow(1)); for (int i = 0; i < N; i++) { var x = rand.Next(Mod); var n = rand.Next(Mod); var actual = ModInt.Raw(x).Pow(n); long expected = 1; while (n > 0) { if ((n & 1) > 0) { expected *= x; expected %= Mod; } x = (int)((long)x * x % Mod); n >>= 1; } Assert.Equal(expected, actual); } } [Fact] public void EqualTest() { var rand = new XorShift(Seed); for (int i = 0; i < N; i++) { var a = rand.Next(Mod); var b = a + (long)Mod * rand.Next(Mod); var ma = new ModInt(a); var mb = new ModInt(b); Assert.True(ma == mb); Assert.False(ma != mb); } } [Fact] public void NotEqualTest() { var rand = new
XorShift(Seed); for (int i = 0; i < N; i++) { var a = (long)(rand.Next() % long.MaxValue); var b = (long)(rand.Next() % long.MaxValue); if (a % Mod == b % Mod) { continue; } var ma = new ModInt(a); var mb = new ModInt(b); Assert.False(ma == mb); Assert.True(ma != mb); } } [Fact] public void ToStringTest() { var rand = new XorShift(Seed); for (int i = 0; i < N; i++) { var a = (long)(rand.Next() % long.MaxValue); var m = new ModInt(a); Assert.Equal((a % Mod).ToString(), m.ToString()); } } } } ``` ### `DynamicModInt`版 ほぼ`StaticModInt`版のコピペです。 ```C# using AtCoder; using Xunit; using ModInt = AtCoder.DynamicModInt<AtCoder.ModID0>; namespace ModIntTest { public class DynamicModIntTest { const int Mod = 1000000007; const int Seed = 42; const int N = 10000000; public DynamicModIntTest() { DynamicModInt<ModID0>.Mod = 1000000007; DynamicModInt<ModID1>.Mod = 15; } [Fact] public void ConstructorTest() { var rand = new XorShift(Seed); ConstructorSubTest(0); ConstructorSubTest(-1); ConstructorSubTest(Mod); ConstructorSubTest(-Mod); for (int i = 0; i < N; i++) { ConstructorSubTest((long)rand.Next()); } } private static void ConstructorSubTest(long x) { var m = new ModInt(x); x %= Mod; if (x < 0) { x += Mod; } Assert.Equal(x, m.Value); } [Fact] public void RawTest() { var rand = new XorShift(Seed); RawSubTest(0); RawSubTest(Mod - 1); for (int i = 0; i < N; i++) { RawSubTest((int)(rand.Next() % Mod)); } } private static void RawSubTest(int x) { var m = ModInt.Raw(x); Assert.Equal(x, m.Value); } [Fact] public void IncrementTest() { const long init = Mod - N + 5; var m = new ModInt(init); for (int i = 0; i < N; i++) { var expected = (init + i) % Mod; Assert.Equal(expected, m++); } } [Fact] public void DecrementTest() { const long init = Mod + N - 5; var m = new ModInt(init); for (int i = 0; i < N; i++) { var expected = (init - i) % Mod; Assert.Equal(expected, m--); } } [Fact] public void AddTest() { var rand = new XorShift(Seed); const long max = 1L << 60; for (int i = 0; i < N; i++) { var a = (long)(rand.Next() % 
(max >> 1) - max); var b = (long)(rand.Next() % (max >> 1) - max); var ma = new ModInt(a); var mb = new ModInt(b); var expected = (a + b) % Mod; if (expected < 0) { expected += Mod; } var actual = ma + mb; Assert.Equal(expected, actual); } } [Fact] public void SubtractTest() { var rand = new XorShift(Seed); const long max = 1L << 60; for (int i = 0; i < N; i++) { var a = (long)(rand.Next() % (max >> 1)) - max; var b = (long)(rand.Next() % (max >> 1)) - max; var ma = new ModInt(a); var mb = new ModInt(b); var expected = (a - b) % Mod; if (expected < 0) { expected += Mod; } var actual = ma - mb; Assert.Equal(expected, actual); } } [Fact] public void MultiplicationTest() { var rand = new XorShift(Seed); const int max = 1 << 30; for (int i = 0; i < N; i++) { var a = (int)(rand.Next() % (max >> 1)) - max; var b = (int)(rand.Next() % (max >> 1)) - max; var ma = new ModInt(a); var mb = new ModInt(b); var expected = (long)a * b % Mod; if (expected < 0) { expected += Mod; } var actual = ma * mb; Assert.Equal(expected, actual); } } [Fact] public void DivisionTest() { var rand = new XorShift(Seed); for (int i = 0; i < N; i++) { var a = rand.Next(Mod); var b = rand.Next(Mod); var divided = new ModInt(a) / b; var actual = ((long)divided.Value * b) % Mod; Assert.Equal(a, actual); } } [Theory] [InlineData(1, 1)] [InlineData(2, 8)] [InlineData(4, 4)] [InlineData(7, 13)] [InlineData(8, 2)] [InlineData(11, 11)] [InlineData(13, 7)] [InlineData(14, 14)] public void DivisionNotPrimeTest(int input, int expected) { var m = DynamicModInt<ModID1>.Raw(input); var actual = 1 / m; Assert.Equal(expected, actual); } [Fact] public void NegateTest() { var rand = new XorShift(Seed); for (int i = 0; i < N; i++) { var m = rand.Next(Mod); var actual = +m + -m; Assert.Equal(0, actual); } } [Fact] public void PowTest() { var rand = new XorShift(Seed); Assert.Equal(1, ModInt.Raw(100).Pow(0)); Assert.Equal(100, ModInt.Raw(100).Pow(1)); for (int i = 0; i < N; i++) { var x = rand.Next(Mod); var n = rand.Next(Mod); 
var actual = ModInt.Raw(x).Pow(n); long expected = 1; while (n > 0) { if ((n & 1) > 0) { expected *= x; expected %= Mod; } x = (int)((long)x * x % Mod); n >>= 1; } Assert.Equal(expected, actual); } } [Fact] public void EqualTest() { var rand = new XorShift(Seed); for (int i = 0; i < N; i++) { var a = rand.Next(Mod); var b = a + (long)Mod * rand.Next(Mod); var ma = new ModInt(a); var mb = new ModInt(b); Assert.True(ma == mb); Assert.False(ma != mb); } } [Fact] public void NotEqualTest() { var rand = new XorShift(Seed); for (int i = 0; i < N; i++) { var a = (long)(rand.Next() % long.MaxValue); var b = (long)(rand.Next() % long.MaxValue); if (a % Mod == b % Mod) { continue; } var ma = new ModInt(a); var mb = new ModInt(b); Assert.False(ma == mb); Assert.True(ma != mb); } } [Fact] public void ToStringTest() { var rand = new XorShift(Seed); for (int i = 0; i < N; i++) { var a = (long)(rand.Next() % long.MaxValue); var m = new ModInt(a); Assert.Equal((a % Mod).ToString(), m.ToString()); } } } } ```
key-moon commented 4 years ago

ありがとうございます🙇 今日も遅くなってしまって申し訳ありません。 確認しました。