GiacomoPope / kyber-py

A pure python implementation of ML-KEM (FIPS 203) and CRYSTALS-Kyber
MIT License

Centered Binomial Distribution Sampling is slow #44

Closed: GiacomoPope closed this 1 month ago

GiacomoPope commented 1 month ago

Maybe there is a more Pythonic way to do the CBD sampling. It seems too slow, and it is the bottleneck for both key generation and encapsulation. I think the main culprit is the bytes_to_bits() conversion that is currently needed to turn the input into a list of bits.
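
For context, a bytes-to-bits expansion can be written with or without string formatting. The two functions below are hypothetical illustrations (bytes_to_bits_format and bytes_to_bits_shift are not kyber-py's code). Both assume the little-endian bit order of the spec's BytesToBits, where bit j of byte i lands at index 8 * i + j. They return the same list, but the second never builds a string per byte.

```python
def bytes_to_bits_format(input_bytes):
    """Hypothetical string-based expansion.

    format(byte, "08b") is MSB-first, so each 8-character string is reversed
    to get little-endian bit order within the byte.
    """
    return [int(c) for byte in input_bytes for c in format(byte, "08b")[::-1]]


def bytes_to_bits_shift(input_bytes):
    """Hypothetical integer-only expansion: bit j of byte i goes to index 8 * i + j."""
    return [(byte >> j) & 1 for byte in input_bytes for j in range(8)]
```

Either version still produces 8 list entries per byte, so the downstream sum() calls in cbd() stay the same.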

Keygen

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   205000    0.148    0.000    0.348    0.000 {built-in method builtins.sum}
      400    0.108    0.000    0.508    0.001 polynomials.py:56(cbd)
      400    0.079    0.000    0.080    0.000 polynomials.py:161(to_ntt)
   409600    0.069    0.000    0.069    0.000 polynomials.py:69(<genexpr>)
   409600    0.064    0.000    0.064    0.000 polynomials.py:68(<genexpr>)
      400    0.062    0.000    0.062    0.000 utils.py:11(bitstring_to_bytes)
   179200    0.060    0.000    0.060    0.000 {built-in method builtins.format}
      400    0.058    0.000    0.119    0.000 utils.py:1(bytes_to_bits)

Encaps

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   333100    0.275    0.000    0.637    0.000 {built-in method builtins.sum}
      500    0.141    0.000    0.575    0.001 polynomials.py:56(cbd)
      800    0.122    0.000    0.249    0.000 utils.py:1(bytes_to_bits)
   716800    0.116    0.000    0.116    0.000 polynomials.py:93(<genexpr>)
   233600    0.079    0.000    0.079    0.000 {built-in method builtins.format}
      300    0.074    0.000    0.075    0.000 polynomials.py:198(from_ntt)
   435200    0.073    0.000    0.073    0.000 polynomials.py:69(<genexpr>)
   435200    0.069    0.000    0.069    0.000 polynomials.py:68(<genexpr>)
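
The column layout of the tables above matches Python's cProfile/pstats output, sorted by tottime. The issue does not say exactly how they were produced, so this is only a sketch of one way to generate a similar report. expand_bits and workload are hypothetical stand-ins for the key generation or encapsulation call being measured, and the 192-byte buffer (64 * eta with eta = 3) is an arbitrary choice.

```python
import cProfile
import io
import os
import pstats


def expand_bits(data):
    # Hypothetical stand-in for the code being profiled.
    return [(byte >> j) & 1 for byte in data for j in range(8)]


def workload(rounds=200):
    # Hypothetical driver: expand a fresh 192-byte buffer on each round.
    for _ in range(rounds):
        expand_bits(os.urandom(192))


profiler = cProfile.Profile()
profiler.enable()
workload()
profiler.disable()

# Render the report into a string; keep the 8 rows with the largest own time.
buffer = io.StringIO()
stats = pstats.Stats(profiler, stream=buffer)
stats.sort_stats("tottime").print_stats(8)
report = buffer.getvalue()
print(report)
```
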
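
Current implementation of cbd() in polynomials.py:

    def cbd(self, input_bytes, eta, is_ntt=False):
        """
        Algorithm 2 (Centered Binomial Distribution)
        https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf

        Expects a byte array of length eta * deg / 4, where deg = 256 is the
        polynomial degree. For Kyber, this is 64 * eta bytes.
        """
        assert 64 * eta == len(input_bytes)
        coefficients = [0 for _ in range(256)]
        # Expand the input into a flat list of bits (bytes_to_bits from utils.py)
        list_of_bits = bytes_to_bits(input_bytes)
        for i in range(256):
            # Each coefficient consumes 2 * eta bits: the first eta bits sum to a,
            # the next eta bits sum to b, and the coefficient is (a - b) mod q.
            a = sum(list_of_bits[2 * i * eta + j] for j in range(eta))
            b = sum(list_of_bits[2 * i * eta + eta + j] for j in range(eta))
            coefficients[i] = (a - b) % 3329
        return self(coefficients, is_ntt=is_ntt)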
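
Written out, the loop in cbd() computes each coefficient as below, with beta = BytesToBits(B) and q = 3329. This assumes bytes_to_bits uses the spec's little-endian bit order within each byte. The worked example takes eta = 2, where each input byte feeds exactly two coefficients.

```latex
% B has 64*eta bytes, \beta = \mathrm{BytesToBits}(B), q = 3329
f_i = \left( \sum_{j=0}^{\eta-1} \beta_{2i\eta + j} \;-\; \sum_{j=0}^{\eta-1} \beta_{2i\eta + \eta + j} \right) \bmod q,
\qquad 0 \le i < 256

% Worked example: \eta = 2, B_0 = \mathtt{0x0D} = 0b00001101,
% so (\beta_0, \dots, \beta_7) = (1, 0, 1, 1, 0, 0, 0, 0)
f_0 = \bigl((1 + 0) - (1 + 1)\bigr) \bmod 3329 = -1 \bmod 3329 = 3328,
\qquad
f_1 = \bigl((0 + 0) - (0 + 0)\bigr) \bmod 3329 = 0
```
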
GiacomoPope commented 1 month ago

Swapping the generator expressions for slices in #45 gave some improvement. The obvious next step would be to work with the bytes directly rather than first expanding them into a list of bits, which should be faster.
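
Below is a minimal sketch of both ideas as standalone functions rather than methods on the polynomial ring. cbd_slices is a hypothetical version of the slice change. cbd_from_int is one possible reading of "work with the bytes directly": it treats the whole input as a single little-endian integer (via int.from_bytes) and masks off eta bits at a time, so no bit list is ever built. Neither is necessarily what #45 or #76 actually did. Both assume little-endian bit order within each byte and q = 3329.

```python
import os

Q = 3329  # Kyber / ML-KEM modulus


def bytes_to_bits(data):
    # Little-endian bit order within each byte: bit j of byte i -> index 8 * i + j.
    return [(byte >> j) & 1 for byte in data for j in range(8)]


def cbd_slices(input_bytes, eta):
    """Hypothetical: same arithmetic as cbd() above, summing slices instead of generators."""
    assert len(input_bytes) == 64 * eta
    bits = bytes_to_bits(input_bytes)
    coeffs = []
    for i in range(256):
        start = 2 * i * eta
        a = sum(bits[start:start + eta])
        b = sum(bits[start + eta:start + 2 * eta])
        coeffs.append((a - b) % Q)
    return coeffs


def cbd_from_int(input_bytes, eta):
    """Hypothetical: read the input as one little-endian integer, no bit list."""
    assert len(input_bytes) == 64 * eta
    # Bit k of x equals bit (k % 8) of byte k // 8, i.e. the same order as bytes_to_bits.
    x = int.from_bytes(input_bytes, "little")
    mask = (1 << eta) - 1
    coeffs = []
    for i in range(256):
        shift = 2 * i * eta
        a = bin((x >> shift) & mask).count("1")          # popcount of the first eta bits
        b = bin((x >> (shift + eta)) & mask).count("1")  # popcount of the next eta bits
        coeffs.append((a - b) % Q)
    return coeffs


# Cross-check the two versions on random inputs for both Kyber eta values.
for eta in (2, 3):
    seed = os.urandom(64 * eta)
    assert cbd_slices(seed, eta) == cbd_from_int(seed, eta)
```

On Python 3.10+, value.bit_count() can replace bin(value).count("1"). Keeping a random-input equality check like the loop at the end is a cheap regression test when swapping one implementation for the other.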

GiacomoPope commented 1 month ago

I'm happier with the performance now, after #76.