GaloisInc / cryptol-specs

A central repository for specifications of cryptographic algorithms in Cryptol
BSD 3-Clause "New" or "Revised" License

Consolidate and spec-ify NTT implementations #163

Open marsella opened 3 weeks ago

marsella commented 3 weeks ago

We have several versions of NTT floating around: in the ML-KEM implementation, in several of the Dilithium versions, and standalone in Common/ntt.

None of them closely resembles the versions written in the specs (ML-KEM and ML-DSA being the ones we'd most like to emulate). The spec versions use densely nested loops whose bounds depend on the enclosing loops' variables, and they update sequence elements in place at non-consecutive indices. That style is hard to implement faithfully in Cryptol, where sequences are immutable values. See more discussion in #156.
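For reference, one way to sidestep the dependent nested loops is to flatten them. FIPS 203 Algorithm 9 performs exactly 7 × 128 = 896 butterflies (Steps 8-10): 7 layers (`len` = 128, 64, ..., 2), each with 128 butterflies, and each butterfly's `len`, `start`, `j`, and counter `i` can be computed from a single running index. The sketch below is a hypothetical, standalone illustration: the names `butterfly`, `NTTFlat`, and `NTTFlatConst` are made up, `q = 3329` and `zeta = 17` are the FIPS 203 parameters, and `BitRev7` here is plain 7-bit reversal rather than the repo's definition. Because it uses value-level indices with `updates`, it needs no numeric constraint guards, but it would still need to be proven equivalent to the existing versions.

```cryptol
// FIPS 203 parameters: q = 3329, zeta = 17.
type q = 3329
type Tq = [256](Z q)

zeta : Z q
zeta = 17

// 7-bit bit reversal (assumed to match the spec's BitRev7).
BitRev7 : [7] -> [7]
BitRev7 = reverse

// Butterfly number n (0 <= n < 896) of Algorithm 9. The three loops are
// flattened into one counter: layer k has 2^^k blocks of len = 128 >> k
// butterflies each, so every layer has 128 butterflies.
butterfly : Tq -> Integer -> Tq
butterfly f n = updates f [j + len, j] [f @ j - t, f @ j + t] // Steps 9-10
  where
    k   = n / 128                // layer index
    m   = n % 128                // butterfly index within the layer
    len = 2 ^^ (7 - k)           // 128, 64, ..., 2
    b   = m / len                // block index within the layer
    j   = b * 2 * len + m % len  // start = b * 2 * len
    i   = 2 ^^ k + b             // the spec's counter `i` for this block
    z   = zeta ^^ BitRev7 (fromInteger i)  // Step 5
    t   = z * f @ (j + len)                // Step 8

// Apply all 896 butterflies in the same order as the spec's loops.
NTTFlat : Tq -> Tq
NTTFlat f = foldl butterfly f ([0 .. 895] : [896]Integer)

// Sanity check: a constant polynomial c maps to [c, 0, c, 0, ...].
NTTFlatConst : Z q -> Bit
property NTTFlatConst c = NTTFlat ([c] # zero) == join (repeat [c, 0])
```

In the REPL, `:check NTTFlatConst` exercises the sketch on random constants; an equivalence proof against the repo's NTT would be a separate `:prove` target.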

The versions we have match other NTT reference implementations (I don't have sources, but I believe there's a Python library that matches the ML-KEM one and a C reference implementation for Dilithium).

There are also some recursion-based fast versions that are more efficient than the naive versions in the specs. Most specs are explicit that "any algorithm that's mathematically equivalent" is fine to use, but we need to make sure that any version we adopt is proven equivalent.
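On the "mathematically equivalent" point: FIPS 203 also describes the NTT output in closed form, with f̂[2i] and f̂[2i+1] given as sums over the even and odd coefficients of f, weighted by powers of ζ^(2·BitRev7(i)+1). A direct-evaluation definition is slow (O(n²) multiplications) but stays close to the math, so it could serve as a common target: each fast or recursive version would get a property of the form `NTT f == NTTMath f`. This is a standalone sketch assuming the FIPS 203 parameters (`q = 3329`, `zeta = 17`); `NTTMath`, `evalPoly`, and `NTTMathOne` are illustrative names, not existing functions in the repo.

```cryptol
// FIPS 203 parameters: q = 3329, zeta = 17 (a primitive 256th root of unity mod q).
type q = 3329
type Tq = [256](Z q)

zeta : Z q
zeta = 17

// 7-bit bit reversal (assumed to match the spec's BitRev7).
BitRev7 : [7] -> [7]
BitRev7 = reverse

// Direct evaluation of the NTT representation, for i in [0, 127]:
//   f_hat[2i]   = sum_{j=0}^{127} f[2j]   * zeta^((2 * BitRev7(i) + 1) * j)
//   f_hat[2i+1] = sum_{j=0}^{127} f[2j+1] * zeta^((2 * BitRev7(i) + 1) * j)
NTTMath : Tq -> Tq
NTTMath f = join [ [evalPoly evens i, evalPoly odds i] | i <- [0 .. 127] ]
  where
    // Split f into even- and odd-indexed coefficients.
    halves = transpose (groupBy`{2} f)
    evens  = head halves
    odds   = last halves

// sum_j g[j] * zeta^(e * j) with e = 2 * BitRev7(i) + 1.
// Exponents are reduced mod 256, which is sound because zeta^^256 == 1.
evalPoly : [128](Z q) -> [7] -> Z q
evalPoly g i = sum [ c * zeta ^^ ((e * j) % 256) | c <- g | j <- [0 .. 127] ]
  where
    e = 2 * toInteger (BitRev7 i) + 1

// Sanity check on a concrete input: the constant 1 maps to [1, 0, 1, 0, ...].
NTTMathOne : Bit
property NTTMathOne = NTTMath ([1] # zero) == join (repeat [1, 0])

// To check an implementation against this definition (`NTT` is a
// hypothetical name for the implementation under test):
//   property NTTMatchesMath f = NTT f == NTTMath f
```

Proving such a property over all 256-coefficient inputs may be expensive; `:check` on random inputs is a cheaper first step.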

marsella commented 3 weeks ago

Here's the old recursive attempt at something spec-adherent:

````cryptol
/**
 * ```repl
 * :prove NaiveNTTsMatch
 * ```
 */
property NaiveNTTsMatch f = NaiveNTT' f == NaiveNTT f

private
    /**
     * Naive version of NTT, implemented using recursion instead of loops.
     * [FIPS-203] Algorithm 9.
     *
     * Note that this implementation is spread out across multiple functions
     * to support the use of numeric constraint guards.
     */
    NaiveNTT' : Rq -> Tq
    NaiveNTT' f = state.f_hat where
        // Steps 1-2. Initialize `f_hat`, `i`.
        state0 = { z = 0, i = 1, f_hat = f }
        // Step 3. Initialize `len` and evaluate the body of the loop.
        state = len_loop`{len = 128} state0

    type State = { z : Z q, i : [8], f_hat : Tq }

    // Steps 3-13.
    len_loop : {len} (len <= 128) => State -> State
    len_loop state
        // Step 3: Stop if we're at the end of the loop.
        | len < 2 => state
        // Otherwise, we're in a valid loop iteration.
        | len >= 2 => state'' where
            // Evaluate the body of the loop...
            state' = start_loop`{len, 0} state
            // ...then start the next iteration.
            state'' = len_loop`{len / 2} state'

    // Steps 4-12.
    start_loop : {len, start} (fin len, fin start) => State -> State
    start_loop state
        // Step 4: Stop if we're at the end of the loop.
        | start >= 256 => state
        // Otherwise, we're in a valid loop iteration.
        | start < 256 => state''' where
            // Step 5.
            z = zeta ^^ (BitRev7 state.i)
            // Step 6.
            i = state.i + 1
            // Save the changes from Steps 5-6.
            state' = { z = z, i = i, f_hat = state.f_hat }
            // Steps 7-11. Evaluate the `j`-loop.
            state'' = j_loop`{len, start, start} state'
            // Start the next iteration of the `start` loop.
            state''' = start_loop`{len, start + 2 * len} state''

    // Steps 7-11.
    j_loop : {len, start, j} (start <= j, j <= start + len)
        => State -> State
    j_loop state
        // Step 7: Stop if we're at the end of the loop.
        | (j == start + len) => state
        // This case is impossible to reach; `j + len` will always be a valid
        // index into `f_hat`. It's not possible to infer that from the type
        // constraints we have now, so it's stated explicitly.
        | (j + len >= 256) => state
        // Otherwise, we're in a valid loop iteration.
        | (j + len < 256, j < start + len) => state'' where
            // Step 8.
            t = state.z * state.f_hat @ (`j + `len)
            // Step 9.
            f_hat' = set_f`{j + len} state.f_hat (state.f_hat @`j - t)
            // Step 10.
            f_hat'' = set_f`{j} f_hat' (f_hat' @`j + t)
            // Save the changes made in Steps 8-10.
            state' = {
                z = state.z,
                i = state.i,
                f_hat = f_hat''
            }
            // Start the next iteration of the loop.
            state'' = j_loop`{len, start, j + 1} state'

            // Helper function to set the `idx`th value of the polynomial.
            set_f : {idx} (idx <= 255) => Tq -> Z q -> Tq
            set_f poly val = take`{idx} poly # [val] # drop`{idx + 1} poly
````

marsella commented 2 weeks ago

Also note that the BitRev7 function doesn't match the ML-KEM spec exactly, so we should investigate whether there's a better way to use it (without running into overflow issues) or whether we should just document the difference better.
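On the BitRev7 point: FIPS 203 defines BitRev7 on 7-bit values, while the counter `i` in Algorithm 9 ends at 128 after its last use and so needs 8 bits. One hedged option is to keep `BitRev7` at exactly the spec's 7-bit type and adapt at the call site. The sketch below is hypothetical (`BitRev7From8` and the property names are made up, and it has not been compared with the repo's current definition); the last property documents the wraparound explicitly, since the adapter silently maps 128 to `BitRev7 0`.

```cryptol
// 7-bit bit reversal as FIPS 203 describes it: reverse the 7 bits of r.
BitRev7 : [7] -> [7]
BitRev7 = reverse

// Hypothetical adapter for a loop counter stored in 8 bits.
// It keeps only the low 7 bits, so callers must ensure i < 128.
BitRev7From8 : [8] -> [7]
BitRev7From8 i = BitRev7 (drop`{1} i)

// Bit reversal is its own inverse.
property BitRev7Involution r = BitRev7 (BitRev7 r) == r

// Concrete values: 1 = 0b0000001 maps to 64 = 0b1000000, and vice versa.
property BitRev7Examples = (BitRev7 1 == 64) && (BitRev7 64 == 1)

// The overflow hazard: an out-of-range counter of 128 wraps to BitRev7 0.
property BitRev7From8Wraps = BitRev7From8 128 == BitRev7 0
```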