vyperlang / vyper

Pythonic Smart Contract Language for the EVM
https://vyperlang.org

VIP: Jump table optimization #2386

Closed: charles-cooper closed this issue 12 months ago

charles-cooper commented 3 years ago

Simple Summary

Reduce code size and gas usage for jump tables. Introduce a JUMPTABLE macro in LLL which can be optimized for any backend. This is a non-user facing optimization.

Motivation

Vyper currently uses a jump table at the contract entry point to determine which function to call into. The format of the jump table looks something like

(if caller_supplied_method_id == method_identifier(A())
  (... A() body...)
)
(if caller_supplied_method_id == method_identifier(B())
  (... B() body...)
)
; and so on

In assembly this translates to

# "jump table fragment"
# pretend caller_supplied_method_id is already on the stack
PUSH4 0x12345678 # method_identifier(A())
EQ # caller_supplied_method_id == 0x12345678
ISZERO # invert the comparison
PUSH2 __next__
JUMPI # skip A() body on mismatch
# ... A() body...
__next__ JUMPDEST
# repeat for B and so on

There are a couple of issues with this. First, the search is linear. If a contract has 80 methods and the most commonly called method happens to be the last one compiled into the contract code, that "jump table fragment" needs to be executed 80 times. On average it is executed 40 times, for an average overhead of 23 gas per fragment, which translates into 920 gas just to find the correct entry point for a function.

Second, there is code size overhead. The assembly costs 12 bytes per signature; while that looks fairly efficient, I'll show below that the overhead can be reduced to 3 bytes per signature, since only the jumptable locations need to be stored in the code.
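As a sanity check on these numbers, the per-fragment size and gas can be tallied from the standard EVM opcode tables (a Python sketch; the costs are the usual per-opcode values):

```python
# Per-signature dispatch cost of the "jump table fragment", using standard
# EVM opcode sizes (bytes) and gas costs.
fragment = [
    ("PUSH4", 5, 3),     # 1 opcode byte + 4 immediate bytes
    ("EQ", 1, 3),
    ("ISZERO", 1, 3),
    ("PUSH2", 3, 3),     # 1 opcode byte + 2 immediate bytes
    ("JUMPI", 1, 10),
    ("JUMPDEST", 1, 1),  # the __next__ landing pad
]
size = sum(b for _, b, _ in fragment)  # bytes per signature
gas = sum(g for _, _, g in fragment)   # gas per skipped fragment

print(size, gas)  # 12 23
# With 80 methods the average call walks ~40 fragments:
print(40 * gas)   # 920
```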

Specification

Our basic goal is to minimize the lookup cost and space for determining the entry point of a function.

On x86, jump tables are a common optimization strategy for switch/case statements. When the variable being "switched" has relatively dense values, the jumptable will have a layout like the following:

switch(x) {
  case '1': ...; break;
  case '2': ...; break;
  case '4': ...; break;
  ...
}

The code first indexes into a lookup table of jump destinations, then jumps from there to the correct code, something like

__jumptable = {
  0x1234, // case 1
  0x5678, // case 2
  0x0000, // fail/invalid
  0xabcd, // case 4
}
goto __jumptable[x - '1'] // glossing over the detail of handling the "fail/invalid" case
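As a runnable sketch of the dense-table idea (handler names here are illustrative, not from the source):

```python
# Dense jump table: the handler index is computed directly from the switch
# value, so dispatch is O(1). A None entry marks the fail/invalid case.
def case1(): return "one"
def case2(): return "two"
def case4(): return "four"

jumptable = [case1, case2, None, case4]  # slots for '1', '2', '3', '4'

def dispatch(x):
    idx = ord(x) - ord("1")
    if not (0 <= idx < len(jumptable)) or jumptable[idx] is None:
        return "invalid"  # the fail/invalid slot
    return jumptable[idx]()

print(dispatch("2"))  # two
print(dispatch("3"))  # invalid
```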

This gives us constant lookup time for the code location of any given method. The problem with using this approach directly for our use case is that (a) code size is at a high premium on the EVM, so we would like to avoid as much wasted space as possible, and (b) we can't use the dense lookup table representation anyway, because our method identifiers are 4 bytes: a dense table would require contracts 2**32 bytes long, mostly filled with zeros.

If we use a so-called "perfect" hash function, we can have our cake and eat it too. We need to construct a hash function which is dense and maps every input to a unique location in the lookup table. The simplest hash function we could use is mulmod(method_id, MAGIC_NUMBER, num_methods), where MAGIC_NUMBER is computed at compile time (yes, this could be very expensive unless we find a good algorithm). Then, our entry point computation would look like this:

# Copy jump table into memory
codecopy(0, __jumptable, 2 * NUM_METHODS)

# Grab the method id from calldata
method_id = shr(224, calldataload(0))

# Index into our jumptable
idx = mulmod(method_id, MAGIC_NUMBER, NUM_METHODS)

# Get the 2-byte jumpdest (entries are 2 bytes each; mload reads a full
# 32-byte word, so keep only the top 2 bytes)
_dst = shr(240, mload(2 * idx))
jump(_dst)

__jumptable
0x1234
0x5678
0xabcd

... # rest of contract

This reduces the per-method overhead to 3 bytes (2 bytes for each entry in the jump table and 1 byte to mark the entry points as valid with JUMPDESTs) / 0.1875 gas (CODECOPY requires 3 gas per word, and we can load 16 jump locations per word), while introducing global overhead of roughly 24 bytes / 40 gas.
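To get a feel for the compile-time cost, here is a brute-force search for MAGIC_NUMBER (a sketch, not the proposed algorithm); note it fails outright whenever two method ids share a residue mod the table size, since no multiplier can separate them:

```python
# Brute-force search for a MAGIC_NUMBER making
# idx = (method_id * magic) % num_methods a perfect (collision-free) hash.
def find_magic(method_ids, max_tries=1_000_000):
    n = len(method_ids)
    for magic in range(1, max_tries):
        idxs = {(m * magic) % n for m in method_ids}
        if len(idxs) == n:  # all slots distinct: perfect hash found
            return magic
    return None  # e.g. two ids congruent mod n can never be separated

ids = [0x12345678, 0xDEADBEEF, 0xCAFEBABE]
magic = find_magic(ids)
assert magic is not None

# Fails when two ids share a residue mod n, whatever the magic:
assert find_magic([3, 6, 10], max_tries=1000) is None
```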

Drawbacks

  • It's difficult to detect invalid method identifiers since the identifiers are no longer stored in the code. The fallback function would have to be a special case. For the same reason, it wouldn't work well for common switch/case statements which have a default clause.

  • It's hard to disassemble code since identifiers are no longer in the code. All that's left of them is the MAGIC_NUMBER that is calculated at compile-time.

Alternatives

If the drawbacks are too great, we can look into alternative strategies which would still retain most of the benefits (i.e., O(1) or O(log n) entry point calculation, but perhaps less space / gas efficient).

Helpful Links

https://en.wikipedia.org/wiki/Branch_table#Jump_table_example_in_C
https://en.wikipedia.org/wiki/Perfect_hash_function
https://en.wikipedia.org/wiki/Cuckoo_hashing

Copyright

Copyright and related rights waived via CC0

fubuloubu commented 3 years ago
  • It's difficult to detect invalid method identifiers since the identifiers are no longer stored in the code. The fallback function would have to be a special case. For the same reason, it wouldn't work well for common switch/case statements which have a default clause.

Does this matter? As long as it's a guarantee that any other method ID triggers the fallback, I think this is okay

  • It's hard to disassemble code since identifiers are no longer in the code. All that's left of them is the MAGIC_NUMBER that is calculated at compile-time.

Can you backsolve to obtain the set of method IDs that work based on MAGIC_NUMBER?

charles-cooper commented 3 years ago

Does this matter? As long as it's a guarantee that any other method ID triggers the fallback, I think this is okay

We wouldn't have a guarantee that other method IDs trigger the fallback. (Unless we found an amazing hash function that somehow magically sends our N method IDs to valid lookup table locations and every other possible method ID to the fallback location). Instead, if method IDs are invalid they would more likely just jump to some random method and the behavior is undefined.

But realistically speaking, if a caller wants to trigger the fallback they should send no message data. So we can handle that with a simple if (iszero (calldatasize)) goto fallback at the beginning of the contract. On the other hand, if the method ID is just garbage, I don't really have a problem with "not playing nice" in that case.

Can you backsolve to obtain the set of method IDs that work based on MAGIC_NUMBER?

Not really, I mean the problem is you would be reverse mapping our N method IDs to the entire space of possible method IDs. So there could be practically infinitely many sets of method IDs which work with MAGIC_NUMBER. It's like reversing a hash function: it's computationally difficult to find a single match, and even if you do, there are infinitely many strings which could map to a given output.
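To see why back-solving is impractical: with N table slots, roughly 2**32 / N of all four-byte ids land in any given slot, so every slot has an enormous preimage set. A small sampling sketch (MAGIC here is an arbitrary illustrative multiplier, not from the source):

```python
# Count ids mapping to slot 0 under idx = (id * MAGIC) % N, sampling a
# small range instead of the full 2**32 selector space.
MAGIC = 2654435761  # illustrative odd multiplier (Knuth's constant)
N = 10

hits = sum(1 for i in range(100_000) if (i * MAGIC) % N == 0)
print(hits)  # 10000 candidate preimages for slot 0 in this tiny sample alone
```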

fubuloubu commented 3 years ago

On the other hand if the method ID is just garbage, I don't really have a problem with "not playing nice" in that case.

Is there a way to exploit this behavior to start execution at a random section of code you prefer?

Not really, I mean the problem is you would be reverse mapping our N method IDs to the entire space of possible method IDs. So there could be practically infinitely many sets of method IDs which work with MAGIC_NUMBER. It's like reversing a hash function: it's computationally difficult to find a single match, and even if you do, there are infinitely many strings which could map to a given output.

This is an interesting property, not sure how I feel about it.

charles-cooper commented 3 years ago

Is there a way to exploit this behavior to start execution at a random section of code you prefer?

I don't really see an exploit, but I could just be naive. Thing is, the only JUMPDESTs you could access would be the existing method starts since those are the only ones in the jumptable. So .. you couldn't trigger anything that you wouldn't have been able to trigger by passing a valid method ID to begin with.

fubuloubu commented 3 years ago

but could you jumpdest to an internal method? or some other equally contrived but serious outcome?

charles-cooper commented 3 years ago

but could you jumpdest to an internal method? or some other equally contrived but serious outcome?

You could only jump to locations which are hardcoded in the jump table, so I'm not too concerned about this.

charles-cooper commented 3 years ago

Notes from meeting: We can't erase the inputs from the code because sometimes folks will provide non-null calldata to the fallback function

charles-cooper commented 2 years ago

So we can implement this using a hash table with probing. Instead of the jumptable only being jumpdest1 jumpdest2 ..., we include the inputs in the jumptable

val1 jumpdest1 val2 jumpdest2

For a total of 6 bytes overhead per case.

To handle collisions, we check if input == val1. The entries in the jumptable which collide form a linked list of sorts. To reach the next element of the list, we rehash the input. Since a jumpdest must be smaller than 24576 (that is, it fits in 15 bits), we can use the top bit to encode whether we have reached the end of the linked list.
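The 6-byte entry layout described above (a 4-byte method id, then a 2-byte field whose top bit is the has-more flag and whose low 15 bits are the jumpdest) can be modeled like this (a sketch):

```python
# Pack/unpack one 6-byte jumptable entry:
#   bytes 0-3: method id
#   bytes 4-5: has_more flag (bit 15) | jumpdest (low 15 bits)
def pack_entry(method_id, jumpdest, has_more):
    assert jumpdest < (1 << 15)  # jumpdests are < 24576, so they fit in 15 bits
    word = (method_id << 16) | (int(has_more) << 15) | jumpdest
    return word.to_bytes(6, "big")

def unpack_entry(raw):
    word = int.from_bytes(raw, "big")
    return word >> 16, word & 0x7FFF, bool((word >> 15) & 1)

entry = pack_entry(0x12345678, 0x1ABC, has_more=True)
assert unpack_entry(entry) == (0x12345678, 0x1ABC, True)
```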

The pseudocode (C-esque; I will rewrite this in LLL at some point) to resolve the jumptable then is

SCRATCH_SPACE = 128; // FREE_VAR_SPACE
JUMPTABLE_ENTRY_SIZE = 6; // number of bytes a single entry takes
jumptable_offset = <some offset in the data section>;
method_id = calldataload(0) >> 224;
jumptable_entry = 0; // seed value for hash
while (true) {
    jumptable_entry = HASH(method_id, jumptable_entry) * JUMPTABLE_ENTRY_SIZE + jumptable_offset;
    // codecopy(dest, src, len); right-align the 6-byte entry in the
    // 32-byte scratch word so the shifts below read it from the bottom
    codecopy(SCRATCH_SPACE + 26, jumptable_entry, JUMPTABLE_ENTRY_SIZE);
    val = (mload(SCRATCH_SPACE) >> 16) & 0xFFFFFFFF; // the stored 4-byte method id
    if (val == method_id) {
        jumpdest = mload(SCRATCH_SPACE) & 0x7FFF; /* (1<<15) - 1 */
        goto jumpdest; // we are done, we found the right jumpdest
    }
    has_more_entries = (mload(SCRATCH_SPACE) >> 15) & 1;
    if (!has_more_entries) {
        goto fallback; // we are done, unrecognized method_id
    }
}

A good HASH function could be hash(val, seed) = mulmod(val, seed + 1, <a good prime>). Note that finding <a good prime> is a lot easier than finding the magic number originally proposed: it can be any prime larger than all the inputs, which is trivial for 4-byte method ids.
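A Python model of this probe function, assuming the prime-modded value is further reduced modulo the table size (a detail left implicit above; PRIME is an arbitrary illustrative choice):

```python
# Probe-position function for the coalesced jumptable: rehashing with the
# previous position as the seed yields the next slot in the collision chain.
PRIME = 4_294_967_311  # an illustrative prime > 2**32; any prime above all ids works

def probe_slot(method_id, seed, table_size):
    return (method_id * (seed + 1)) % PRIME % table_size

# Deterministic: the same id and seed always give the same slot.
assert probe_slot(0x12345678, 0, 16) == probe_slot(0x12345678, 0, 16)
# Rehashing with a new seed generally moves a colliding id to a new slot:
print(probe_slot(0x12345678, 0, 16), probe_slot(0x12345678, 5, 16))
```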

k06a commented 2 years ago

If we assume the hash function's output is random enough, then the brute-forcing complexity seems to be n^n / n!, because there are n^n possible hash outputs for n inputs and n! non-colliding ones (basically the number of all permutations).

You will need to brute-force:

Splitting across 2-3 functions applied sequentially could help: compute h1(selector) to get a value in 0..3, then h2(selector) to get a value in 0..4; these can be combined into 20 distinct numbers as h1(selector)*5 + h2(selector), brute-forcing only 50 (= 24 + 26) combinations.

So we can split the number of selectors by factorizing it. Having a prime number of selectors is not good for us; I'm not sure what we can do except add extra empty slots that jump to something like revert("selector not found"). On the other hand, primes work better as the modulus.

⚠️ The following calculations are not fully correct, because the two hash function searches are dependent, not independent, events. The two hash functions must never cluster the same pair of selectors together in both hashes, since that would make those selectors indistinguishable. We need a way to recompute the probabilities for this case.

We could try to extend the number of selectors by some percentage (say up to 5%) to find the best factorization. By the best factorization I mean two factors close to sqrt(n), which minimizes the total number of combinations brute-forced. For example, with 23 selectors we can add 1 and split 24 into the factor pairs 2*12 and 4*6:

brute(2) + brute(12) > brute(4) + brute(6)
2 + 479m > 24 + 720

Having 2 extra selectors would make 25 = 5*5, which can be computed with only 120 + 120 combinations, far fewer than even 24 + 720.
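Taking brute(n) = n!, which matches the numbers quoted above (479m ≈ 12!, 720 = 6!, 120 = 5!), these comparisons can be checked directly:

```python
# Cost model for brute-forcing a group of n selectors: brute(n) = n!.
from math import factorial as brute

# Splitting 24 selectors as 2*12 vs 4*6:
print(brute(2) + brute(12))  # 2 + 479_001_600
print(brute(4) + brute(6))   # 24 + 720
assert brute(2) + brute(12) > brute(4) + brute(6)

# Padding to 25 = 5*5 is cheaper still:
assert brute(5) + brute(5) < brute(4) + brute(6)  # 240 < 744
```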

k06a commented 2 years ago

Looks like the function mulmod(selector, magic, length) can never work for selectors a, b which have the same remainder, a % length == b % length: under the modulus these selectors become indistinguishable. Maybe we could use keccak256 for this purpose instead; it works for any length:

keccak256(abi.encodePacked(selector + (magic << 32))) % length
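The collision claim is easy to verify: if a and b are congruent mod n, then a*m and b*m are too, for every magic m. A quick check:

```python
# If a % n == b % n, then (a * m) % n == (b * m) % n for every magic m,
# so mulmod(selector, magic, n) can never place a and b in distinct slots.
n = 7
a = 0x12345678
b = a + n  # congruent to a mod n
assert all((a * m) % n == (b * m) % n for m in range(1, 10_000))
```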

k06a commented 2 years ago

Works like a charm, spent 4486 combinations for 10 selectors: https://gist.github.com/k06a/737fd1f389bb6f2e66f0ef76ea73729b

fubuloubu commented 2 years ago

keccak256(abi.encodePacked(selector + (magic << 32))) % length

What's the estimated gas usage for this vs. current O(n) dispatch methods? (say for ~80 method IDs e.g. Yearn Vaults v2)

k06a commented 2 years ago

@fubuloubu we can find 1 magic number to split 80 selectors into 9 groups, then find 9 magic numbers to split each group. Gas cost would be 36 + 36 for the two sequential hashes; I assume that in practice, with all the necessary checks, it will cost around 100 gas for every method.

k06a commented 2 years ago

Using mulmod directly is not possible: if selectorA % selectors.length == selectorB % selectors.length, then mulmod gives a collision no matter which magic is used.

charles-cooper commented 12 months ago

closing as completed as of https://github.com/vyperlang/vyper/pull/3496