0xPolygonZero / plonky2

Apache License 2.0

Add `Field::shifted_powers` and some iterator niceties #1599

Closed: gio256 closed this 4 weeks ago

gio256 commented 4 weeks ago

Adds the Field::shifted_powers method from Plonky3 along with a few additions to the Powers iterator implementation.
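For readers unfamiliar with the two methods, here is a minimal, self-contained sketch of the idea. It assumes a toy model of the Goldilocks field (p = 2^64 - 2^32 + 1) with elements stored as plain `u64`s, and it assumes `shifted_powers(start)` yields start, start*b, start*b^2, ... while `powers()` is the special case start = 1. The `Fp` and `Powers` types are hypothetical simplifications, not plonky2's or Plonky3's actual definitions.

```rust
/// Goldilocks prime used by plonky2: p = 2^64 - 2^32 + 1.
const P: u64 = 0xFFFF_FFFF_0000_0001;

/// Toy field element (hypothetical; not plonky2's `GoldilocksField`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct Fp(u64);

impl Fp {
    /// Modular multiplication via a 128-bit intermediate product.
    fn mul(self, rhs: Fp) -> Fp {
        Fp(((self.0 as u128 * rhs.0 as u128) % P as u128) as u64)
    }

    /// Yields 1, b, b^2, b^3, ...
    fn powers(self) -> Powers {
        self.shifted_powers(Fp(1))
    }

    /// Yields start, start*b, start*b^2, ...
    fn shifted_powers(self, start: Fp) -> Powers {
        Powers { base: self, current: start }
    }
}

/// Infinite iterator over successive powers of `base`, scaled by a start value.
struct Powers {
    base: Fp,
    current: Fp,
}

impl Iterator for Powers {
    type Item = Fp;

    fn next(&mut self) -> Option<Fp> {
        let result = self.current;
        self.current = self.current.mul(self.base);
        Some(result)
    }
}
```

In this model, `Fp(2).shifted_powers(Fp(3))` yields 3, 6, 12, 24, ..., so a caller can start a power sequence at an arbitrary offset without building the offset by hand.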

Please pick and choose what's useful.

<Powers<F> as Iterator>::nth in particular might require some tweaking:

gio256 commented 4 weeks ago

The reason for the PR as a whole was to hopefully make it easier to replace as many instances of F::from_canonical_usize(1 << n) as possible (for safety and readability reasons). Originally, I thought nth would be useful for bit shifts, for example.
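As a side illustration of the `F::from_canonical_usize(1 << n)` concern, the sketch below computes 2^n by taking the `n`th element of a powers iterator, so no integer `1 << n` is ever formed. It assumes the same kind of toy Goldilocks model as above; `Powers`, `powers` and `two_to_the` are hypothetical names, not plonky2 APIs.

```rust
/// Goldilocks prime used by plonky2: p = 2^64 - 2^32 + 1.
const P: u64 = 0xFFFF_FFFF_0000_0001;

/// Toy powers iterator over plain `u64` residues (hypothetical, simplified).
struct Powers {
    base: u64,
    current: u64,
}

impl Iterator for Powers {
    type Item = u64;

    fn next(&mut self) -> Option<u64> {
        let result = self.current;
        // Reduce modulo p after every multiplication.
        self.current = ((self.current as u128 * self.base as u128) % P as u128) as u64;
        Some(result)
    }
}

fn powers(base: u64) -> Powers {
    Powers { base: base % P, current: 1 }
}

/// 2^n as a field element, computed without any integer `1 << n`,
/// so it is well defined for every `n`, including n >= 64.
fn two_to_the(n: usize) -> u64 {
    powers(2).nth(n).expect("the powers iterator never ends")
}
```

For n < 64 this agrees with `1u64 << n`; for n = 64 it returns 2^32 - 1, which is 2^64 reduced modulo p, whereas a 64-bit integer shift by 64 would overflow.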

```rust
// n is the (hypothetical) number of bits to shift by.
// f_n is a flag indicating whether this is really the n to shift by.
// All bits in the output with index < n are 0.
// m in {n..32} is the LE index of a possibly nonzero bit in the output.
// We want to line up the nth bit of the output with the 0th bit of the input,
// so we take `output_bits[m] = input_bits[m - n]`.
let shift_left: P = lv
    .shift_by_indices
    .into_iter()
    .enumerate()
    .flat_map(|(n, f_n)| {
        (n..32)
            .zip(P::Scalar::TWO.powers().skip(n))
            .map(move |(m, base)| f_n * lv.input_bits[m - n] * base)
    })
    .sum();
```
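To see why the sum above computes a left shift, here is an integer model of it. It assumes, as the comments describe, that `shift_by_indices` is a one-hot selector and `input_bits` are the little-endian bits of a 32-bit word. Plain `u64` arithmetic stands in for the packed field `P`, `std::iter::successors` plays the role of `TWO.powers()`, and the function names are hypothetical.

```rust
/// Integer model of the `shift_left` sum (hypothetical helper, not plonky2 code).
/// `shift_by_indices` is a one-hot selector for the shift amount `n`, and
/// `input_bits` holds the little-endian bits of a 32-bit word.
fn shift_left_model(input_bits: &[u64; 32], shift_by_indices: &[u64; 32]) -> u64 {
    shift_by_indices
        .iter()
        .enumerate()
        .flat_map(move |(n, &f_n)| {
            // 2^0, 2^1, 2^2, ... standing in for `TWO.powers()`.
            let powers = std::iter::successors(Some(1u64), |&x| Some(x * 2));
            // Output bit m (for m in n..32) takes input bit m - n, weighted by 2^m.
            (n..32)
                .zip(powers.skip(n))
                .map(move |(m, base)| f_n * input_bits[m - n] * base)
        })
        .sum()
}

/// Little-endian bit decomposition of a 32-bit word.
fn to_le_bits(x: u32) -> [u64; 32] {
    let mut bits = [0u64; 32];
    for (i, bit) in bits.iter_mut().enumerate() {
        *bit = u64::from((x >> i) & 1);
    }
    bits
}

/// One-hot selector with a 1 at index `n` (requires n < 32).
fn one_hot(n: usize) -> [u64; 32] {
    let mut flags = [0u64; 32];
    flags[n] = 1;
    flags
}
```

For example, shifting 5 by 3 gives 40, and bits pushed past index 31 are dropped, so shifting 0x8000_0001 left by 1 gives 2.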

But this is more efficiently done with Field::shifted_powers anyway. nth is used under the hood by iterator adapters like skip and step_by, so if those seem useful in the context of powers(), then maybe it's worth keeping. Otherwise, I'm happy to remove it.
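Because `skip` and `step_by` delegate to `nth`, overriding `nth` on a powers iterator lets them jump ahead with O(log n) multiplications instead of n. The sketch below shows one way to do this in a toy Goldilocks model; `exp` is a hypothetical square-and-multiply helper (a real field type would use its own exponentiation, such as plonky2's `Field::exp_u64`), and this is not necessarily how the PR's `nth` is written.

```rust
/// Goldilocks prime used by plonky2: p = 2^64 - 2^32 + 1.
const P: u64 = 0xFFFF_FFFF_0000_0001;

/// Modular multiplication via a 128-bit intermediate product.
fn mul(a: u64, b: u64) -> u64 {
    ((a as u128 * b as u128) % P as u128) as u64
}

/// Square-and-multiply: base^e mod p in O(log e) multiplications (hypothetical helper).
fn exp(mut base: u64, mut e: u64) -> u64 {
    let mut acc = 1;
    while e > 0 {
        if e & 1 == 1 {
            acc = mul(acc, base);
        }
        base = mul(base, base);
        e >>= 1;
    }
    acc
}

/// Toy powers iterator (hypothetical, simplified): start, start*b, start*b^2, ...
struct Powers {
    base: u64,
    current: u64,
}

impl Iterator for Powers {
    type Item = u64;

    fn next(&mut self) -> Option<u64> {
        let result = self.current;
        self.current = mul(self.current, self.base);
        Some(result)
    }

    /// Jump straight to start*b^n instead of calling `next` n + 1 times.
    /// Afterwards the iterator continues with start*b^(n+1), matching the default `nth`.
    fn nth(&mut self, n: usize) -> Option<u64> {
        self.current = mul(self.current, exp(self.base, n as u64));
        self.next()
    }
}
```

The override keeps the iterator's state consistent with the default behavior, so `Powers { base: 3, current: 1 }.skip(4)` starts at 81 after a single exponentiation, and `step_by(2)` yields 1, 9, 81, 729, ...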

gio256 commented 4 weeks ago

Thinking about this more, I gave a misleading example: using Field::shifted_powers in the code above doesn't actually accomplish anything.

```rust
let shift_left: P = lv
    .shift_by_indices
    .into_iter()
    .enumerate()
    .flat_map(|(n, f_n)| {
        (n..32)
            .zip(P::Scalar::TWO.shifted_powers(P::Scalar::from_canonical_u32(1 << n)))
            .map(move |(m, base)| f_n * lv.input_bits[m - n] * base)
    })
    .sum();
```
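The two formulations can be compared in a toy Goldilocks model. Both produce the weights 2^n, 2^(n+1), ..., but only the `shifted_powers` version has to build 2^n with an integer shift, and that shift must be checked. `Powers`, `shifted_powers` and the `weights_via_*` helpers are hypothetical, illustrative names.

```rust
/// Goldilocks prime used by plonky2: p = 2^64 - 2^32 + 1.
const P: u64 = 0xFFFF_FFFF_0000_0001;

/// Toy powers iterator (hypothetical, simplified): start, start*b, start*b^2, ...
struct Powers {
    base: u64,
    current: u64,
}

impl Iterator for Powers {
    type Item = u64;

    fn next(&mut self) -> Option<u64> {
        let result = self.current;
        self.current = ((self.current as u128 * self.base as u128) % P as u128) as u64;
        Some(result)
    }
}

fn shifted_powers(base: u64, start: u64) -> Powers {
    Powers { base, current: start }
}

fn powers(base: u64) -> Powers {
    shifted_powers(base, 1)
}

/// Weights 2^n, 2^(n+1), ... by skipping: no integer shift is involved.
fn weights_via_skip(n: usize) -> impl Iterator<Item = u64> {
    powers(2).skip(n)
}

/// Weights 2^n, 2^(n+1), ... via `shifted_powers`: the start value needs an
/// integer `1 << n`, so the shift amount has to be checked (here `n < 32`).
fn weights_via_shift(n: u32) -> Option<impl Iterator<Item = u64>> {
    let start = 1u32.checked_shl(n)?;
    Some(shifted_powers(2, u64::from(start)))
}
```

Both helpers agree for every n in 0..32, while `weights_via_shift(32)` returns `None` because `1u32 << 32` is out of range.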

It's a small thing, but this still requires verifying that at no point in the loop does 1 << n exceed u32::MAX or the order of the field. For that reason, I think the version using Iterator::skip is slightly better if the overhead is negligible. So, I'll update my (soft) pitch for adding nth to the Iterator implementation:

matthiasgoergens commented 4 weeks ago

Nice, I miss shifted_powers!

gio256 commented 4 weeks ago

I merged upstream (#1601) so the tests would run, but let me know if I should squash.

Otherwise, this should be ready to go. Thanks for the feedback.