0xPolygonMiden / miden-vm

STARK-based virtual machine
MIT License

Proposition to refactor the multiset checks across the codebase #1388

Open plafer opened 1 month ago

plafer commented 1 month ago

I'd like to suggest a refactoring of our various multiset checks; this will be necessary soon enough when we integrate them into LogUp-GKR anyways. The current abstractions we have couple a few distinct concerns:

Now, with the new LogUp-GKR proving system, we will need to rewrite all multiset checks that use the AuxColumnBuilder, since, as described, they are currently coupled to a running product column implementation.

The ideal multiset check abstractions would allow us to describe a multiset check instance once, and then swap between various compatible ways of proving it (e.g. running product column, running sum column, GKR, etc.). Ideally, our abstractions for building running sum/product columns would also be general enough to work out of the box with any "compatible" multiset check instance (more on that later). While we are thinking of switching all multiset checks over to LogUp-GKR, I still think it would be valuable to be able to quickly swap some multiset checks out of LogUp-GKR and prove them using a running product/sum column instead, for benchmarking and for A/B testing more generally.

Below, I will first describe how I think about the different "variants" of multiset checks, which will justify the set of abstractions that I describe next.

Background

The 2 main categories of multiset checks are, as identified in our docs,

  1. multiset checks without multiplicities (that we'll just call "multiset check")
  2. multiset checks with multiplicities

It is important to distinguish between the two, since for example, a running product column can be used to prove/instantiate multiset checks, but not multiset checks with multiplicities (at least not with general polynomials as multiplicities). We will use the following mathematical models for each:

  1. multiset checks

$$ \prod_{i=0}^{n-1} \prod_{j=0}^{k-1} a_{ij}(\psi) = \prod_{i=0}^{n-1} \prod_{j=0}^{k-1} b_{ij}(\psi) $$

  2. multiset checks with multiplicities

$$ \prod_{i=0}^{n-1} \prod_{j=0}^{k-1} a_{ij}(\psi)^{m_{a_{ij}}} = \prod_{i=0}^{n-1} \prod_{j=0}^{k-1} b_{ij}(\psi)^{m_{b_{ij}}} $$

Here, $n$ is the length of the trace, and $k$ is the maximum number of terms that can be generated in a single row. Also, $\psi = (\psi_0, \dots, \psi_{c-1})$ is a set of random elements that the multiset elements are allowed to use. In the current codebase, these are the alphas that are used e.g. to merge virtual table columns together. Finally, $m_{a_{ij}}$ and $m_{b_{ij}}$ are polynomials derived from the trace that act as the multiplicity for the corresponding $a_{ij}$ or $b_{ij}$, respectively. Those currently don't have access to $\psi$, but if that turns out to be useful, they certainly could.
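As a minimal sketch of the first equation, the snippet below checks a multiset equality with a grand product over a toy prime field. The modulus, the `merge` helper and the `psi_0 + psi_1 * v_0 + ...` merging rule are assumptions made for illustration; they are not the VM's actual field or code.

```rust
/// Toy prime modulus (2^31 - 1). This is an assumption for illustration only.
const P: u64 = 2_147_483_647;

fn add(a: u64, b: u64) -> u64 {
    (a % P + b % P) % P
}

fn mul(a: u64, b: u64) -> u64 {
    (a % P) * (b % P) % P
}

/// Merges a tuple of values into one element using the random `psi`:
/// psi_0 + psi_1 * v_0 + psi_2 * v_1 + ...
fn merge(values: &[u64], psi: &[u64]) -> u64 {
    values
        .iter()
        .zip(&psi[1..])
        .fold(psi[0], |acc, (v, p)| add(acc, mul(*v, *p)))
}

/// One side of the multiset check: the product of all merged elements.
fn grand_product(elements: &[Vec<u64>], psi: &[u64]) -> u64 {
    elements.iter().fold(1, |acc, e| mul(acc, merge(e, psi)))
}
```

Two sides that contain the same tuples in a different order give the same product. For a randomly drawn `psi`, two different multisets give the same product only with small probability.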

A very important point that is often overlooked is which rows $a_{ij}(\psi)$ and $b_{ij}(\psi)$ have access to. This ends up affecting what the underlying proving system (running product/sum/GKR) looks like. The 2 real options right now are:

  1. They only have access to row $i$.
  2. They have access to rows $i$ and $i+1$

Currently in the VM, the range checker is the only multiset check that only needs row $i$. All the other ones need rows $i$ and $i+1$. Note however that this is a property of the proving system, in that it doesn't only affect multiset checks, but any column built using it. To be concrete, the "s" auxiliary column in LogUp-GKR is built as a running sum (since it implements an inner product); this running sum, even if not a multiset check, will be built differently depending on whether "s" needs access to row $i$ only or to both rows $i$ and $i+1$.

Normally, you'd think that since 2 is more general than 1, every instance of 1 could be modeled as an instance of 2. But there's one important caveat: 2 is not a strict generalization of 1, since in 2, the last row is not allowed to contain multiset elements (they're defined in terms of 2 rows, and the last row has no successor). This is the case with the range checker: we build its running sum column the way you would for 2 (even though the range checker elements never use row $i+1$), since we ensure that the last row of the trace is a HALT operation, and hence cannot contain a range check request (see #1383). However, this isn't possible with LogUp-GKR's "s" column: there always needs to be a multiset element generated on the last row, and so we cannot reuse a running sum column built for 2.
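To make the caveat concrete, here is a small sketch of a running sum column built for option 2. The function name, the toy modulus and the closure-based `term` parameter are hypothetical; the point is only that a trace with `n` rows yields `n - 1` two-row frames, so the last row never contributes a term of its own.

```rust
/// Toy prime modulus (2^31 - 1). This is an assumption for illustration only.
const P: u64 = 2_147_483_647;

/// Builds a running sum column for option 2: the term for step `i` may read
/// rows `i` and `i + 1`. Only the frames (0, 1), ..., (n - 2, n - 1) exist.
fn running_sum_two_rows(
    trace: &[Vec<u64>],
    term: impl Fn(&[u64], &[u64]) -> u64,
) -> Vec<u64> {
    let mut column = vec![0u64; trace.len()];
    for i in 0..trace.len().saturating_sub(1) {
        let t = term(trace[i].as_slice(), trace[i + 1].as_slice()) % P;
        column[i + 1] = (column[i] + t) % P;
    }
    column
}
```

With a term that reads only the current row, the value stored in the last row of the trace is never added to the column. That is harmless for the range checker (the last row is a HALT), but it is not acceptable for a column that must emit an element on every row.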

There are a few ways to build a column for 1, but the one we've settled on for LogUp-GKR's "s" column is described in https://github.com/facebook/winterfell/issues/286. As for building a column for 2, this is exactly how the current AuxColumnBuilder does it.

UPDATE: it turns out that only needing access to row $i$ (and allowing multiset elements to be generated in the last row) is a little difficult to deal with. The only concrete use case that we have currently is the LogUp-GKR's "s" column, but we're probably going to get rid of it, and so we would have no real use case for option 1 anymore. As a result, we could simplify the proposed abstractions below by not making the Frame type generic.

Proposed abstractions

The key takeaways from the previous section are:

Multiset check abstractions

These are the types to encode multiset checks, along with a few examples.

/// Note: Contrary to our current `AuxColumnBuilder`, we model explicitly that each frame is allowed
/// to return multiple requests/responses.
///
/// Note: instead of returning `Vec<E>`, we would probably write into a buffer instead.
trait MultisetCheck<E: FieldElement, Frame> {
    /// Returned flags restricted to be 0 or 1
    fn get_flags_for_request(&self, frame: &Frame) -> Vec<E>;
    /// Returned flags restricted to be 0 or 1
    fn get_flags_for_response(&self, frame: &Frame) -> Vec<E>;

    // The inputs to a request function are `frame: &Frame` and `alphas: &[E]`, (where `alphas` is
    // `\psi` in the writeup).
    fn get_requests(&self) -> Vec<Box<dyn Fn(&Frame, &[E]) -> Vec<E>>>;
    // The inputs to a response function are `frame: &Frame` and `alphas: &[E]`
    fn get_responses(&self) -> Vec<Box<dyn Fn(&Frame, &[E]) -> Vec<E>>>;
}

trait MultisetCheckWithMultiplicities<E: FieldElement, Frame> {
    /// Returned multiplicities could be any field element
    fn get_multiplicities_for_request(&self, frame: &Frame) -> Vec<E>;
    /// Returned multiplicities could be any field element
    fn get_multiplicities_for_response(&self, frame: &Frame) -> Vec<E>;

    fn get_requests(&self) -> Vec<Box<dyn Fn(&Frame, &[E]) -> Vec<E>>>;
    fn get_responses(&self) -> Vec<Box<dyn Fn(&Frame, &[E]) -> Vec<E>>>;
}

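/// Given the same frame, we always know how to build a `MultisetCheckWithMultiplicities` from a
/// `MultisetCheck`: the flags become the multiplicities.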
impl<E: FieldElement, Frame, T> MultisetCheckWithMultiplicities<E, Frame> for T
where
    T: MultisetCheck<E, Frame>,
{
    // It is always better to model flags as multiplicities in LogUp
    fn get_multiplicities_for_request(&self, frame: &Frame) -> Vec<E> {
        self.get_flags_for_request(frame)
    }

    // It is always better to model flags as multiplicities in LogUp
    fn get_multiplicities_for_response(&self, frame: &Frame) -> Vec<E> {
        self.get_flags_for_response(frame)
    }

    fn get_requests(&self) -> Vec<Box<dyn Fn(&Frame, &[E]) -> Vec<E>>> {
        <Self as MultisetCheck<E, Frame>>::get_requests(self)
    }

    fn get_responses(&self) -> Vec<Box<dyn Fn(&Frame, &[E]) -> Vec<E>>> {
        <Self as MultisetCheck<E, Frame>>::get_responses(self)
    }
}

/// Represents one of our virtual tables. None of them need multiplicities, and all need rows `i`
/// and `i+1`, hence we set `Frame = EvaluationFrame`.
struct VTable {}

impl<E: FieldElement + 'static> MultisetCheck<E, EvaluationFrame<E>> for VTable {
    fn get_flags_for_request(&self, frame: &EvaluationFrame<E>) -> Vec<E> {
        todo!()
    }

    fn get_flags_for_response(&self, frame: &EvaluationFrame<E>) -> Vec<E> {
        // Here, for example, this would return the flags
        //   1. 1 if opcode is `JOIN`, 0 otherwise
        //   2. 1 if opcode is `SPLIT`, 0 otherwise
        // which correspond to the return vector of `get_responses()`
        todo!()
    }

    fn get_responses(&self) -> Vec<Box<dyn Fn(&EvaluationFrame<E>, &[E]) -> Vec<E>>> {
        vec![Box::new(response_when_join), Box::new(response_when_split)]
    }

    fn get_requests(&self) -> Vec<Box<dyn Fn(&EvaluationFrame<E>, &[E]) -> Vec<E>>> {
        todo!()
    }
}

fn response_when_join<E: FieldElement>(frame: &EvaluationFrame<E>, alphas: &[E]) -> Vec<E> {
    todo!()
}

fn response_when_split<E: FieldElement>(frame: &EvaluationFrame<E>, alphas: &[E]) -> Vec<E> {
    todo!()
}

/// Our range checker on the other hand does need multiplicities, so we implement
/// `MultisetCheckWithMultiplicities` directly. As discussed in the background section, it is legal
/// to use both `CurrentRow` and `EvaluationFrame` as frames. We only implement it with
/// `EvaluationFrame` in this example.
struct RangeChecker;

impl MultisetCheckWithMultiplicities<Felt, EvaluationFrame<Felt>> for RangeChecker {
    fn get_multiplicities_for_request(
        &self,
        frame: &EvaluationFrame<Felt>,
    ) -> Vec<Felt> {
        todo!()
    }
    fn get_multiplicities_for_response(
        &self,
        frame: &EvaluationFrame<Felt>,
    ) -> Vec<Felt> {
        todo!()
    }
    fn get_requests(
        &self,
    ) -> Vec<Box<dyn Fn(&EvaluationFrame<Felt>, &[Felt]) -> Vec<Felt>>> {
        todo!()
    }
    fn get_responses(
        &self,
    ) -> Vec<Box<dyn Fn(&EvaluationFrame<Felt>, &[Felt]) -> Vec<Felt>>> {
        todo!()
    }
}

Running product/sum column abstractions

These are the abstractions that build running product/sum columns from a multiset check.

/// Allows tuning how each running product/sum column is initialized.
trait FoldColumnInitializer<E: FieldElement> {
    fn init_requests(&self, _main_trace: &MainTrace, _alphas: &[E]) -> E;
    fn init_responses(&self, _main_trace: &MainTrace, _alphas: &[E]) -> E;
}

struct DefaultRunningProductInitialization;

impl<E: FieldElement> FoldColumnInitializer<E> for DefaultRunningProductInitialization {
    fn init_requests(&self, _main_trace: &MainTrace, _alphas: &[E]) -> E {
        E::ONE
    }

    fn init_responses(&self, _main_trace: &MainTrace, _alphas: &[E]) -> E {
        E::ONE
    }
}

struct DefaultRunningSumInitialization;

impl<E: FieldElement> FoldColumnInitializer<E> for DefaultRunningSumInitialization {
    fn init_requests(&self, _main_trace: &MainTrace, _alphas: &[E]) -> E {
        E::ZERO
    }

    fn init_responses(&self, _main_trace: &MainTrace, _alphas: &[E]) -> E {
        E::ZERO
    }
}

/// Running sum column "backend" for `MultisetCheckWithMultiplicities`.
///
/// Note: Because all `MultisetCheck` get a blanket implementation for
/// `MultisetCheckWithMultiplicities`, this also allows a running sum column to be built for all
/// `MultisetCheck`.
fn build_running_sum_column<E: FieldElement>(
    initializer: impl FoldColumnInitializer<E>,
    multiset_check: impl MultisetCheckWithMultiplicities<E, EvaluationFrame<E>>,
    main_trace: &MainTrace,
    alphas: &[E],
    logup_randomness: E,
) -> Vec<E> {
    // Similar to `AuxColumnBuilder::build_aux_column`, but taking into account multiplicities
    todo!()
}

/// Running product column "backend" for `MultisetCheck` (multiplicities are not supported here).
fn build_running_product_column<E: FieldElement>(
    initializer: impl FoldColumnInitializer<E>,
    multiset_check: impl MultisetCheck<E, EvaluationFrame<E>>,
    main_trace: &MainTrace,
    alphas: &[E],
) -> Vec<E> {
    // Similar to `AuxColumnBuilder::build_aux_column`
    todo!()
}
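As a rough illustration of what the running product backend computes, the sketch below folds per-frame response and request values into a single column over a toy prime field. The modulus, the `inv` helper and the flattened `responses`/`requests` slices are assumptions for the example; the real builder would derive these values from a `MultisetCheck` and the main trace.

```rust
/// Toy prime modulus (2^31 - 1). This is an assumption for illustration only.
const P: u64 = 2_147_483_647;

fn mul(a: u64, b: u64) -> u64 {
    (a % P) * (b % P) % P
}

/// Modular inverse via Fermat's little theorem: a^(P - 2) mod P, for a != 0.
fn inv(a: u64) -> u64 {
    let (mut base, mut exp, mut acc) = (a % P, P - 2, 1u64);
    while exp > 0 {
        if exp & 1 == 1 {
            acc = mul(acc, base);
        }
        base = mul(base, base);
        exp >>= 1;
    }
    acc
}

/// `responses[i]` and `requests[i]` hold the product of all merged elements
/// emitted by frame `i` (1 if the frame emits nothing).
/// The column satisfies p[0] = init and p[i + 1] = p[i] * responses[i] / requests[i].
fn build_running_product(init: u64, responses: &[u64], requests: &[u64]) -> Vec<u64> {
    let mut column = Vec::with_capacity(responses.len() + 1);
    column.push(init);
    for (resp, req) in responses.iter().zip(requests) {
        let prev = *column.last().unwrap();
        column.push(mul(mul(prev, *resp), inv(*req)));
    }
    column
}
```

When every response is eventually matched by a request, the column returns to its initial value (here `1`, matching `DefaultRunningProductInitialization`), which is what the boundary constraint on the last row would check.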

LogUp-GKR integration

Below is an idea of how the current LogUp-GKR backend would use the new multiset check abstractions, although we probably will end up doing something slightly different since LogUp-GKR is more in flux. This serves as a small example of how it would work.

/// Modification of the current function with the same name.
pub fn evaluate_fractions_at_main_trace_query<const NUM_RESPS: usize, const NUM_REQS: usize, E>(
    multiset_checks: Vec<
        impl MultisetCheckWithMultiplicities<E, EvaluationFrame<E>>,
    >,
    main_trace: &MainTrace,
    row_idx: usize,
    log_up_randomness: &[E],
) -> [[E; NUM_WIRES_PER_TRACE_ROW]; 2]
where
    E: FieldElement,
{
    // Assuming we keep this interface, this would simply query the multiset checks and return the
    // [multiplicities, requests & responses]
    todo!()
}

bobbinth commented 1 month ago

Thank you for such a detailed write up! The general direction of this makes a lot of sense to me - but I do have a few questions/comments.

First, the functions in the proposed interface return vectors in many cases. Given that these will be called for every row in an execution trace, I've been trying to stay away from returning vectors as it could result in extra allocations. I haven't actually measured the impact of returning vectors, so, to be honest, I'm not sure if it is negligible or quite significant. But before we commit to using vectors in the interface, it would be good to understand whether we'd see a noticeable degradation in performance. Alternatively, we could refactor the interfaces to avoid allocating vectors on every call.

Also related to the above, I wonder what the impact of using Box<dyn Fn(&EvaluationFrame<E>, &[E]) -> Vec<E>> would be as I think the compiler would not be able to inline calls to these functions.

For LogUp-GKR, would we actually need this structure? If building of the Lagrange kernel and s columns (assuming we still need it) moves to Winterfell, I think we would not need to build any auxiliary columns here. The only thing that would be needed are the changes to the Air interfaces discussed in https://github.com/0xPolygonMiden/miden-vm/issues/1386#issuecomment-2227899609. So, I think things like MultisetCheck and MultisetCheckWithMultiplicities and their implementations should probably live in the air crate of the VM.

I also wonder if there is value in separating MultisetCheck and MultisetCheckWithMultiplicities. I understand that they are conceptually different, but it seems like their interfaces and treatments are basically the same. Maybe we could collapse them into a single interface?

Also, when building request/response flags and values, we may need access to periodic column values - so, it might make sense to add them to the input parameters.

plafer commented 1 month ago

First, the functions in the proposed interface return vectors in many cases. Given that these will be called for every row in an execution trace, I've been trying to stay away from returning vectors as it could result in extra allocations.

Indeed, I was only returning Vec<E> for illustration purposes; the MultisetCheck trait has a note about that, but maybe that was too subtle a place to put it. In any case, yes, if we find that it is indeed faster to write into a buffer, then we should modify the interface accordingly.

I wonder what the impact of using Box<dyn Fn(&EvaluationFrame<E>, &[E]) -> Vec<E>> would be as I think the compiler would not be able to inline calls to these functions.

Good point. We should then modify the interfaces to e.g. (again, not necessarily returning Vec):

trait MultisetCheck<E: FieldElement, Frame> {
    // ...

    /// The implementation takes care of evaluating all its "inner functions", 
    /// and only returning the result.
    fn get_requests(&self, frame: &Frame, alphas: &[E]) -> Vec<Vec<E>>;

    // ...
}
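For illustration, here is one way the two concerns above (per-row allocations and `Box<dyn Fn>`) could be addressed together: the trait writes into a caller-owned buffer, and the caller is generic over the check so calls are statically dispatched. The trait shape, `ToyCheck`, and the `u64` stand-in for field elements are all hypothetical; whether this is measurably faster would still need benchmarking.

```rust
/// Toy prime modulus (2^31 - 1). This is an assumption for illustration only.
const P: u64 = 2_147_483_647;

/// Toy version of the trait: elements are written into a caller-owned buffer.
trait MultisetCheck {
    fn fill_requests(&self, current: &[u64], next: &[u64], alphas: &[u64], out: &mut Vec<u64>);
}

/// Hypothetical check: emits one request whenever column 0 of the current row is 1.
struct ToyCheck;

impl MultisetCheck for ToyCheck {
    fn fill_requests(&self, current: &[u64], next: &[u64], alphas: &[u64], out: &mut Vec<u64>) {
        if current[0] == 1 {
            let merged = (alphas[0] % P + (alphas[1] % P) * (next[1] % P) % P) % P;
            out.push(merged);
        }
    }
}

/// Generic over `C`, so `fill_requests` is statically dispatched and can be inlined.
fn collect_requests<C: MultisetCheck>(check: &C, trace: &[Vec<u64>], alphas: &[u64]) -> Vec<u64> {
    let mut buf = Vec::with_capacity(4); // allocated once, reused for every frame
    let mut all = Vec::new();
    for frame in trace.windows(2) {
        buf.clear();
        check.fill_requests(&frame[0], &frame[1], alphas, &mut buf);
        all.extend_from_slice(&buf);
    }
    all
}
```

The output vector `all` exists only so the example has something to return; a real backend would consume `buf` directly for each frame.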

For LogUp-GKR, would we actually need this structure?

With LogUp-GKR, we would only need the MultisetCheck and MultisetCheckWithMultiplicities traits, which would be used by the evaluation of the input layer.

However, I think it's still worth implementing the running sum/product columns (although at a lower priority), specifically for A/B testing. I would like to get to a point where, for example, answering the question "would it be faster to use a running product column for this bus instead of adding it to LogUp-GKR?" could be answered in a matter of at most a few hours. If it turned out that it was, it would have repercussions on AirScript, but I'm punting that problem down the road for if/when we get there.

I also wonder if there is value in separating MultisetCheck and MultisetCheckWithMultiplicities.

The main value is to allow proving systems to only be implemented for the appropriate version of a multiset check. The biggest case we want to avoid is allowing a running product column to prove a MultisetCheckWithMultiplicities (for which it would not be possible to write constraints in general).

And in fact they are treated differently. For example, RangeChecker only implements MultisetCheckWithMultiplicities, and so would not be able to be proved by a running product column.

The decision to represent flags explicitly in the MultisetCheck was only a convenience for the user. When flags are explicitly represented, we know how to build a MultisetCheckWithMultiplicities from a MultisetCheck, since it is always better to put the flags in the multiplicity position; the blanket implementation does just that. So ultimately, the relationship we have is "every MultisetCheck is a MultisetCheckWithMultiplicities, but not every MultisetCheckWithMultiplicities is a MultisetCheck".
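The relationship can be sketched with stripped-down stand-ins for the two traits. Everything here is a simplification for illustration: elements are `u64`, a frame is a single row, and the two `*_backend` functions are hypothetical placeholders for the column builders.

```rust
/// Toy stand-in: a check whose elements are gated by 0/1 flags.
trait MultisetCheck {
    fn flags(&self, row: &[u64]) -> Vec<u64>;
}

/// Toy stand-in: a check whose elements carry arbitrary multiplicities.
trait MultisetCheckWithMultiplicities {
    fn multiplicities(&self, row: &[u64]) -> Vec<u64>;
}

/// Blanket implementation: every flag-based check is also a multiplicity-based check.
impl<T: MultisetCheck> MultisetCheckWithMultiplicities for T {
    fn multiplicities(&self, row: &[u64]) -> Vec<u64> {
        self.flags(row)
    }
}

struct VTable;

impl MultisetCheck for VTable {
    fn flags(&self, row: &[u64]) -> Vec<u64> {
        vec![(row[0] == 1) as u64]
    }
}

struct RangeChecker;

/// Implemented directly: `RangeChecker` is deliberately NOT a `MultisetCheck`.
impl MultisetCheckWithMultiplicities for RangeChecker {
    fn multiplicities(&self, row: &[u64]) -> Vec<u64> {
        vec![row[1]]
    }
}

/// Placeholder for a running sum backend: accepts both kinds of checks.
fn running_sum_backend(check: &impl MultisetCheckWithMultiplicities, row: &[u64]) -> u64 {
    check.multiplicities(row).iter().sum()
}

/// Placeholder for a running product backend: accepts flag-based checks only.
fn running_product_backend(check: &impl MultisetCheck, row: &[u64]) -> u64 {
    check.flags(row).iter().sum()
}
```

With these definitions, `running_product_backend(&RangeChecker, ...)` is rejected at compile time because `RangeChecker` does not implement `MultisetCheck`, which is the guarantee the separate traits are meant to provide.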

Also, when building request/response flags and values, we may need access to periodic column values - so, it might make sense to add them to the input parameters.

Ah yes, we should add those too.

bobbinth commented 1 month ago

I think it's still worth implementing the running sum/product columns (although at a lower priority), specifically for A/B testing. I would like to get to a point where, for example, answering the question "would it be faster to use a running product column for this bus instead of adding it to LogUp-GKR?" could be answered in a matter of at most a few hours.

Agreed!

The main value is to allow proving systems to only be implemented for the appropriate version of a multiset check. The biggest case we want to avoid is allowing a running product column to prove a MultisetCheckWithMultiplicities (for which it would not be possible to write constraints in general).

And in fact they are treated differently. For example, RangeChecker only implements MultisetCheckWithMultiplicities, and so would not be able to be proved by a running product column.

Makes sense. But I'm wondering, since there are no interface differences, whether MultisetCheckWithMultiplicities could be just a marker trait. That is, it could look something like this:

pub trait MultisetCheck<E: FieldElement> {
    ...
}

pub trait MultisetCheckWithMultiplicities<E: FieldElement>: MultisetCheck<E> {}

Also, thinking about this more, I'm wondering if we should combine requests/responses and their corresponding flags/multiplicities into a single struct. Maybe something like:

pub struct Request<E: FieldElement> {
    pub latch: E, // equivalent to numerator
    pub value: E, // equivalent to denominator
}

// Could also be combined with `Request` - but not sure under what name
pub struct Response<E: FieldElement> {
    pub latch: E, // equivalent to numerator
    pub value: E, // equivalent to denominator
}

And then MultisetCheck could look something like this:

pub trait MultisetCheck<F: FieldElement> {

    fn fill_requests<E: FieldElement<BaseField = F> + ExtensionOf<F>>(
        &self,
        frame: &EvaluationFrame<F>,
        periodic_values: &[F],
        result: &mut [Request<E>],
    );

    fn fill_responses<E: FieldElement<BaseField = F> + ExtensionOf<F>>(
        &self,
        frame: &EvaluationFrame<F>,
        periodic_values: &[F],
        result: &mut [Response<E>],
    );

}

We could also simplify this by keeping the inputs and outputs in the same field, and then handling the final adjustment by randomness separately (i.e., the returned denominator would be just $v$ instead of $\alpha + v$). In this case, we'd have something like this:

pub trait MultisetCheck<E: FieldElement> {

    fn fill_requests(
        &self,
        frame: &EvaluationFrame<E>,
        periodic_values: &[E],
        linear_comb_rands: &[E],
        result: &mut [Request<E>],
    );

    fn fill_responses(
        &self,
        frame: &EvaluationFrame<E>,
        periodic_values: &[E],
        linear_comb_rands: &[E],
        result: &mut [Response<E>],
    );

}
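To illustrate the "adjust by randomness separately" variant, the sketch below takes fractions whose denominators are the raw values $v$ and only adds $\alpha$ when summing. The `Fraction` struct (one type shared by requests and responses), the toy modulus and `logup_sum` are assumptions for the example, not part of the proposal.

```rust
/// Toy prime modulus (2^31 - 1). This is an assumption for illustration only.
const P: u64 = 2_147_483_647;

fn mul(a: u64, b: u64) -> u64 {
    (a % P) * (b % P) % P
}

/// Modular inverse via Fermat's little theorem: a^(P - 2) mod P, for a != 0.
fn inv(a: u64) -> u64 {
    let (mut base, mut exp, mut acc) = (a % P, P - 2, 1u64);
    while exp > 0 {
        if exp & 1 == 1 {
            acc = mul(acc, base);
        }
        base = mul(base, base);
        exp >>= 1;
    }
    acc
}

/// Hypothetical merged form of `Request`/`Response`, with `u64` as the element type.
#[derive(Clone, Copy)]
struct Fraction {
    latch: u64, // numerator: a flag or a multiplicity
    value: u64, // denominator before randomness is applied (just `v`)
}

/// Applies the randomness afterwards: the sum of latch / (alpha + value).
/// Assumes alpha + value is nonzero for every fraction.
fn logup_sum(fractions: &[Fraction], alpha: u64) -> u64 {
    fractions.iter().fold(0, |acc, f| {
        let denom = (alpha % P + f.value % P) % P;
        (acc + mul(f.latch, inv(denom))) % P
    })
}
```

A value requested twice with latch 1 contributes the same amount as a single response with latch 2, so the two sums agree exactly when the sides match as multisets with multiplicities.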