0xPolygonZero / zk_evm

Apache License 2.0

Implement MPTs as linked lists #172

Open hratoanina opened 2 months ago

hratoanina commented 2 months ago

One of the goals of continuations was to allow multi-txn proving: proving all the transactions of a block at once to avoid hashing the state and the storage tries at the beginning and at the end of every transaction. The issue is that bundling all transactions together makes for a giant TrieDataSegment (hundreds of millions of cells), and that memory must be carried through all segments until the end; this is not viable.

One idea to greatly reduce this memory consumption is to instead represent all accessed addresses and storage cells with sorted linked lists (in the same way accessed lists are currently implemented). This has several advantages:

This would require heavy refactoring though, and the hardest issue would be hashing: we would need to be able to reconstruct the correct trie roots out of these linked lists, without compromising soundness. We have several ideas, and they all rely on a lot of non-deterministic prover inputs to virtually reconstruct the trie.

4l0n50 commented 2 months ago

Much, much less memory usage for the MPTs. You only store leaves (addresses and storage cells) during execution, and only need to figure out the MPT structure when hashing.

I think you will also need to store all the hashed nodes. Otherwise you can't enforce that the same nodes were used in both the initial and final mpt_hash computations.

This is a high level description of how to hash the sorted linked list "on the fly" for the initial trie:

- Let i = 1 and n = list[0].
- while i < list.len():
    - Guess the least common ancestor (LCA) between n and list[i]
    - Guess and store a merkle path from n to the LCA
    - Guess and store a merkle path from list[i] to LCA.children[the_right_nibble_of(list[i])]
    - Verify the merkle paths
    - n = LCA(n, list[i]) and i += 1
- Guess, store and verify a merkle path from n to the root (i.e. some node whose key is null).
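To make the loop above concrete, here is a minimal Python sketch of which LCA keys an honest prover would end up guessing. It assumes keys are tuples of nibbles, that the list is already sorted, and that the LCA of two leaves is the node at their longest common key prefix. `common_prefix` and `lca_walk` are hypothetical names, and no merkle paths are guessed or verified here.

```python
def common_prefix(a, b):
    """Longest common prefix of two nibble tuples (assumed to be the LCA's key)."""
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return a[:n]


def lca_walk(sorted_keys):
    """Replay the loop: return the (n, list[i], LCA) triples and the final n."""
    n = sorted_keys[0]
    steps = []
    for leaf in sorted_keys[1:]:
        lca = common_prefix(n, leaf)
        steps.append((n, leaf, lca))
        n = lca  # the next iteration continues from the LCA
    return steps, n


steps, top = lca_walk([(1, 2, 3), (1, 2, 5), (1, 4, 0)])
for n, leaf, lca in steps:
    print(n, leaf, "->", lca)
print("remaining path to the root starts at", top)
```

Note that the final `n` is not necessarily the root: if all keys share a prefix, the last step of the algorithm still has to climb from that prefix to the node with the null key.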

The merkle paths are stored in some contiguous memory cells, and each time you guess one it's appended at the end. For checking the merkle path between nodes n1, n2 with keys k1 and k2 s.t. k1 = k2 || s:

- Let h = KECCAK(rlp([hex_prefix(k1), n1.payload]))
- while s != null:
    - Guess a prefix p of s and check that s = p || s'
    - If |s'| > 1 (extension node): s = p and h = KECCAK(rlp([hex_prefix(s'), h]))
    - If |s'| = 1 (branch node):
        - read the next 15 hashed nodes h_i, for i != s', from the merkle proof segment
        - s = p and h = KECCAK(rlp([h_0, ..., h, ..., h_15, []])), with h placed at index s'

For the final trie it should be exactly the same, but you reuse all the merkle paths instead of guessing them.
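As a rough illustration of the path check, the Python sketch below climbs from a node hash up through branch and extension nodes. Several things here are assumptions for illustration only: `H` uses `hashlib.sha3_256` as a stand-in (Ethereum uses Keccak-256, which has different padding and gives different digests), nodes shorter than 32 bytes are hashed rather than embedded inline, and the names `climb`, `leaf_hash`, `extension_hash` and `branch_hash` are hypothetical.

```python
import hashlib


def H(data: bytes) -> bytes:
    # ASSUMPTION: stand-in hash; real MPTs use Keccak-256, not SHA3-256.
    return hashlib.sha3_256(data).digest()


def _length_prefix(n: int, offset: int) -> bytes:
    if n < 56:
        return bytes([offset + n])
    n_bytes = n.to_bytes((n.bit_length() + 7) // 8, "big")
    return bytes([offset + 55 + len(n_bytes)]) + n_bytes


def rlp(item) -> bytes:
    """Minimal RLP encoder for bytes and (nested) lists of bytes."""
    if isinstance(item, bytes):
        if len(item) == 1 and item[0] < 0x80:
            return item
        return _length_prefix(len(item), 0x80) + item
    body = b"".join(rlp(x) for x in item)
    return _length_prefix(len(body), 0xC0) + body


def hex_prefix(nibbles, is_leaf: bool) -> bytes:
    """Hex-prefix encoding of a nibble path (flag 2 marks a leaf)."""
    flag = 2 if is_leaf else 0
    if len(nibbles) % 2 == 1:
        padded = [flag + 1] + list(nibbles)
    else:
        padded = [flag, 0] + list(nibbles)
    return bytes(16 * hi + lo for hi, lo in zip(padded[0::2], padded[1::2]))


def leaf_hash(key_suffix, payload: bytes) -> bytes:
    return H(rlp([hex_prefix(key_suffix, True), payload]))


def extension_hash(path, child_hash: bytes) -> bytes:
    return H(rlp([hex_prefix(path, False), child_hash]))


def branch_hash(children) -> bytes:
    assert len(children) == 16
    return H(rlp(list(children) + [b""]))  # 17th item: empty value slot


def climb(h: bytes, steps) -> bytes:
    """Fold a node hash upwards along a guessed path.

    Each step is ("ext", path) or ("branch", nibble, siblings), where
    siblings holds the 15 other child hashes in index order.
    """
    for step in steps:
        if step[0] == "ext":
            h = extension_hash(step[1], h)
        else:
            _, nibble, siblings = step
            h = branch_hash(siblings[:nibble] + [h] + siblings[nibble:])
    return h
```

For the final trie, the same `steps` (the stored sibling hashes) would be replayed with the updated leaf hash instead of being guessed again.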

frisitano commented 2 months ago

This is a great idea. I wanted to throw out some thoughts on how this could be implemented. I will articulate the solution whose soundness is simplest to reason about, but there are likely optimisations that can be made (for example, omitting entries for keys associated with empty values).

We construct linked lists consisting of key - value pairs. One assumption that must hold is that keys in the list are unique, i.e. there cannot be two entries with the same key. There must be an entry in the list for every key that is read or written during execution, regardless of whether the value is empty or not.

Firstly, we need to authenticate the correctness of values and uniqueness of keys in these lists. This can be done by iterating through the lists and doing lookups against the state trie: we must assert that the value associated with each key corresponds to the value in the state trie. This could be done before the first transaction, or we could keep a copy of the initial state of these linked lists and perform the check after the last transaction.

As transactions are executed and state reads / writes are performed, they must always match an entry in these linked lists and read or modify it as required. After the final transaction we need to update the state trie to compute the new state root. We could naively insert every key - value pair from these linked lists; however, this could be quite inefficient, as there are likely many key - value pairs that haven't been modified. As an optimisation we introduce an additional state write tracking linked list, which keeps track of modified keys. In essence, every time a state write is performed we add the key (or index of the key?) to this tracking linked list. After the final transaction we iterate through this tracking linked list, look up the modified key - value pairs and insert them into the state trie. This will allow us to compute the final state root.
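A minimal Python sketch of the write-tracking idea follows. It is only an illustration under simplifying assumptions: a `dict` stands in for the authenticated key - value linked list, a plain list stands in for the tracking linked list, and `TrackedState` and its methods are hypothetical names.

```python
class TrackedState:
    """Key - value state with a separate list of modified keys."""

    def __init__(self, authenticated_entries):
        # Every key read or written during execution must already be present,
        # even if its value is empty.
        self.values = dict(authenticated_entries)
        self.modified = []   # tracking list, in first-write order
        self._seen = set()   # avoids adding the same key twice

    def read(self, key):
        # A KeyError here means an access outside the authenticated entries.
        return self.values[key]

    def write(self, key, value):
        if key not in self.values:
            raise KeyError(key)
        self.values[key] = value
        if key not in self._seen:
            self._seen.add(key)
            self.modified.append(key)

    def pending_updates(self):
        """Only these pairs need to be inserted into the state trie at the end."""
        return [(key, self.values[key]) for key in self.modified]


state = TrackedState({"a": 1, "b": 2, "c": 3})
state.write("b", 20)
state.write("a", 10)
state.write("b", 21)
state.read("c")
print(state.pending_updates())
```

Keys that were only read (like `"c"` above) never reach the tracking list, so they cost nothing when the final root is computed.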

The issue is that bundling all transactions together makes for a giant TrieDataSegment (hundreds of millions of cells), and that memory must be carried through all segments until the end; this is not viable.

I'm not familiar with how trees are implemented, but it sounds like every tree node is stored in memory, which is problematic when working with a lot of tree data. The way it worked in Polygon Miden is that we only ever store the tree root in memory. When we perform an operation on the tree (read or write), we request the data associated with the current node hash, and the prover provides it via non-determinism; the node data can be verified by simply hashing it. We then select the appropriate child from the node data and make the next request. In Polygon Miden this process inserts node data directly into the hash chiplet AIR trace; the node data is never stored in VM memory. It sounds like the zk_evm works a bit differently, in that it stores nodes in memory? If that is the case, then perhaps you could store only the trie root in memory. For each trie operation you request the required trie witness from the prover; it is provided via non-determinism and stored in memory. The trie operation is performed on this data, and at the end of the operation the data is cleared and only the new root is retained. Alternatively, you could read all trie data for all trees into memory at once, but this would yield the same issue relating to the size of the "TrieDataSegment", even if only for the final segment.
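The "only keep the root, ask the prover for node data, verify it by hashing" pattern can be sketched in Python as below. This is not Miden's or zk_evm's actual design: it assumes a plain binary Merkle tree with a power-of-two number of leaves, SHA-256 from `hashlib`, no domain separation between leaves and internal nodes, and a `dict` playing the role of the non-deterministic prover input. `build` and `lookup` are hypothetical names.

```python
import hashlib


def h(data: bytes) -> bytes:
    return hashlib.sha256(data).digest()


def build(leaves, store):
    """Prover side: build the tree, record every preimage, return the root."""
    level = []
    for value in leaves:
        store[h(value)] = value
        level.append(h(value))
    while len(level) > 1:
        parents = []
        for left, right in zip(level[0::2], level[1::2]):
            store[h(left + right)] = (left, right)
            parents.append(h(left + right))
        level = parents
    return level[0]


def lookup(root, path_bits, oracle):
    """Verifier side: holds only `root`; every answer from `oracle` is re-hashed."""
    node = root
    for bit in path_bits:
        left, right = oracle[node]       # untrusted node data
        if h(left + right) != node:
            raise ValueError("node data does not match its hash")
        node = right if bit else left
    value = oracle[node]                 # untrusted leaf data
    if h(value) != node:
        raise ValueError("leaf data does not match its hash")
    return value


store = {}
root = build([b"a", b"b", b"c", b"d"], store)
print(lookup(root, [1, 0], store))
```

Because each answer is checked against a hash the verifier already trusts, the node data can be discarded as soon as the operation finishes.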

4l0n50 commented 2 months ago

Hi @frisitano, and thanks for your input! Maybe you would like to take a look at this PR, because many of the problems you're mentioning were solved there (I hope!). If you have any questions, don't hesitate to ask us!

I'm not familiar with how trees are implemented, but it sounds like every tree node is stored in memory, which is problematic when working with a lot of tree data. The way it worked in Polygon Miden is that we only ever store the tree root in memory.

The main issue appears to be that the trie data is stored in a somewhat "append-only" fashion, as Hamy mentioned. We think we can afford to store all the data (the leaves), but we want to avoid rehashing. For this reason, we only need to hash at the beginning to ensure it matches the input root. Once you check that, you know the linked list is correct, and you can access or modify it at a very low cost. At the end, you will need to hash again to obtain the final root. The only significant remaining task is to develop an MPT hashing algorithm for the list (you have only the leaves, but you want to obtain the same hash you would get if you had a full trie).

frisitano commented 2 months ago

Hi @4l0n50, I took a look at the PR and that provided some good insights. I have a few questions / comments that would aid my understanding and potentially add value to the discussion.

Some observations:

The method described above does not introduce a new algorithm for virtually hashing the linked list; instead it leverages the pre-existing mechanisms around partial trees and computes the new root by batching operations. I think virtual hashing of the linked list would possibly be more efficient, but quite a bit more complex.

4l0n50 commented 2 months ago

When you say the lists are sorted, what are they sorted by? Are they sorted by state / storage key?

Yes, by the key (IIRC keccak(address) or keccak(slot), respectively).

With the lists being "append-only" will the appended data be inserted in sorted order?

The lists are append-only (though I think it is possible to reuse deleted nodes). The data is always sorted by key, in the order induced by the linked list (though the memory addresses are not necessarily sorted).
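To illustrate "sorted in the order induced by the linked list, but not by memory address", here is a small Python sketch. It assumes non-negative integer keys, a sentinel node at index 0 closing a circular list, and a predecessor index supplied by the prover as a hint that is checked rather than trusted. `SortedLinkedList` is a hypothetical name and this is not the kernel's actual layout.

```python
class SortedLinkedList:
    """Append-only node storage; `next` pointers define the sorted order."""

    def __init__(self):
        self.keys = [-1]      # index 0 is a sentinel smaller than every key
        self.values = [None]
        self.next = [0]       # circular: the last node points back to 0

    def insert(self, key, value, pred):
        """Insert after node `pred`, a prover-supplied hint that is verified."""
        if not 0 <= pred < len(self.keys):
            raise ValueError("hint out of range")
        succ = self.next[pred]
        if not self.keys[pred] < key:
            raise ValueError("bad hint: predecessor key is not smaller")
        if succ != 0 and not key < self.keys[succ]:
            raise ValueError("bad hint: successor key is not larger")
        # New nodes are always appended at the end of memory.
        self.keys.append(key)
        self.values.append(value)
        self.next.append(succ)
        self.next[pred] = len(self.keys) - 1

    def items(self):
        """Walk the list in key order."""
        out = []
        i = self.next[0]
        while i != 0:
            out.append((self.keys[i], self.values[i]))
            i = self.next[i]
        return out


lst = SortedLinkedList()
lst.insert(50, "a", pred=0)
lst.insert(10, "b", pred=0)
lst.insert(30, "c", pred=2)
print(lst.keys)     # memory order
print(lst.items())  # key order
```

The two strict comparisons are what enforce both sortedness and key uniqueness: a duplicate key fails one of them regardless of which hint is given.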

I believe the tree will have to be hashed twice, once for the read operation to authenticate the linked lists and once for the write operations to compute the new root.

Indeed, you must do that. But you could delete the hash preimages (the accounts or storage values) and later retrieve them by rehashing, which is what we want to avoid.

What are we trying to optimise for? Minimise max memory overhead? Minimise memory transfer between segments?

The latter, which is mostly data in Segment::TrieData. We also want to avoid traversing the trie for each access, because that takes O(log n) cycles, while with the list it takes O(1).

Keys in the linked list are sorted and as such neighbouring items in the list will share a large number of nodes in the tree.

I don't know if I understand. The list only contains the leaves, not the internal nodes. I think the same applies to the next three observations. Maybe I still need to digest what you have written.

Thanks for your questions! (I'll get back to you in a couple of hours.)

frisitano commented 2 months ago

I don't know if I understand. The list only contains the leaves, not the internal nodes. I think the same applies to the next three observations. Maybe I still need to digest what you have written.

So I am proposing that when we need to compute the new state root we request the partial tree (provided via non-determinism) required to perform the batch of tree operations for that specific segment.

We can perform one batch of tree operations per proving segment, i.e. request the state witness for the batch of operations from the prover via non-determinism, perform the batch of operations, compute the new root, drop the tree, retain the root, end the segment.

The partial tree / state witness is essentially all nodes required for the batch of operations. However, these nodes are ephemeral - at the end of the batch / segment we have a new state root - we can drop Segment::TrieData memory. Then at the start of the next segment we request (once again via non-determinism) the state witness (associated with this new state root) required to perform the next batch of operations.

4l0n50 commented 2 months ago

Ah yes, I see. I guess in the end it's the same as what we were proposing. But there's a subtlety, because you don't hash at the end of the segment but at the end of the txn. This means there must be some correlation between the partial trie guessed at the beginning of the txn and what you use at the end (maybe in another segment), so you can't really drop those partial tries. Otherwise a malicious prover could guess a different partial trie (with different hashed nodes), changing, for example, untouched accounts.

4l0n50 commented 2 months ago

Or you could even hash only at the end of the block (or whenever your state transition finishes)

frisitano commented 2 months ago

Yes the idea would be that you only hash at the end of the block / when state transitions are completed. I think we are saying the same thing generally, the only difference is that I am proposing reusing the existing machinery relating to partial tries to achieve it instead of introducing a new algorithm for computing the root from the linked lists.

I think we could even defer authentication of the original linked lists to the end of the block as well. The way this could work is that the prover provides the linked lists at the start of the block; these would be unauthenticated. As transactions are executed, if there is a change to some original value in the linked list, we keep a copy of the original value. At the end of the block, when we read in the partial tries to compute the new state root, we can authenticate the original values and apply any updates at the same time.
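The deferral described above could look roughly like the following Python sketch. It makes several simplifying assumptions: the prover's claimed list is a `dict`, the authenticated pre-state (what the partial trie would prove) is also a `dict` named `trusted_pre_state`, and `DeferredState` and `finalize` are hypothetical names.

```python
class DeferredState:
    """State taken from the prover on trust, authenticated only at the end."""

    def __init__(self, claimed_entries):
        self.values = dict(claimed_entries)  # unauthenticated at this point
        self.originals = {}                  # original value of each modified key

    def read(self, key):
        return self.values[key]

    def write(self, key, value):
        # Keep the claimed original value the first time a key is modified.
        self.originals.setdefault(key, self.values[key])
        self.values[key] = value

    def finalize(self, trusted_pre_state):
        """Authenticate every original value, then return the updates to apply."""
        for key, current in self.values.items():
            original = self.originals.get(key, current)
            if trusted_pre_state.get(key) != original:
                raise ValueError(f"claimed original value for {key!r} is wrong")
        return {key: self.values[key] for key in self.originals}


pre_state = {"a": 1, "b": 2}
state = DeferredState({"a": 1, "b": 2})
state.write("a", 5)
state.write("a", 6)
print(state.finalize(pre_state))
```

Keys that were only read are authenticated against their current value, since it never changed; modified keys are authenticated against the saved copy.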

4l0n50 commented 2 months ago

Yes, I agree. And it's a very good idea to defer hashing to the end of the state transition! Actually, it was not clear to me how to reuse the original hashing machinery because of this required "correlation" (maybe it was still possible and I was just confused). But with your idea it's pretty direct to do so!

4l0n50 commented 2 months ago

I just wonder if there is any problem with deferring the validation of the list. Something like making the prover run for a lot of time just to realize at the end that the initial list was incorrect. But right now I can't see that this could be an issue.

frisitano commented 2 months ago

You are correct: without the introduction of this deferral you wouldn't be able to reuse the hashing from the linked list authentication. I wasn't sure what you meant by "rehash", but I understand this is what you are referring to.

I just wonder if there is any problem with deferring the validation of the list. Something like making the prover run for a lot of time just to realize at the end that the initial list was incorrect. But right now I can't see that this could be an issue.

It would be the prover itself that is providing the initial linked list, there would be no external source that would be able to influence the lists. So provided the prover code is sound I don't think this would be a problem. It would be a case of the prover trying to fool themselves if they provided an invalid initial list.