solana-labs / rbpf

Rust virtual machine and JIT compiler for eBPF programs
Apache License 2.0

Extending pc integral trick to non-uniform op costs #185

jon-chuang closed this issue 2 years ago

jon-chuang commented 3 years ago

It seems likely that certain ops, especially memory ops (e.g. a register load, or a jmp, which requires a random instruction seek), might need to be charged a higher number of compute units, because they can be several times slower than an ALU op (a 100ns RAM seek is ~300x slower than an i32 add). With a budget of 200_000 compute units at 100ns per RAM operation, a program that spends its whole budget on RAM operations would run for 20ms or more, which I think is a lot higher than desired.

This is a good explanation for Serum's poor CU/us ratio (a median of 7 CU/us), which is consistent with roughly 100ns per opcode. Serum's critbit tree data structure does a lot of random reads from a 1MB+ memory region, which may live in L3 cache or even be loaded cold from RAM, so its execution is dominated by opcodes 30-300x slower than an ALU opcode.
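The arithmetic behind the two paragraphs above, as a quick sanity check. The ~3 GHz clock speed is an added assumption used only to convert a single-cycle add into nanoseconds; the other figures are the ones quoted:

```python
# Assumption: a ~3 GHz core, so one cycle (about one i32 add) is ~0.33 ns.
CYCLE_NS = 1 / 3.0
RAM_ACCESS_NS = 100          # figure from the comment above
COMPUTE_BUDGET_CU = 200_000  # figure from the comment above
SERUM_MEDIAN_CU_PER_US = 7   # figure from the comment above

# A program that spends its whole budget on RAM-bound ops.
worst_case_ms = COMPUTE_BUDGET_CU * RAM_ACCESS_NS / 1_000_000
# How many single-cycle ALU ops fit into one RAM access.
ram_vs_add = RAM_ACCESS_NS / CYCLE_NS
# Serum's median throughput expressed as time per compute unit.
serum_ns_per_cu = 1000 / SERUM_MEDIAN_CU_PER_US

print(f"worst case: {worst_case_ms:.0f} ms")      # worst case: 20 ms
print(f"RAM access vs add: ~{ram_vs_add:.0f}x")    # RAM access vs add: ~300x
print(f"Serum: ~{serum_ns_per_cu:.0f} ns per CU")  # Serum: ~143 ns per CU
```

At ~143 ns per CU, Serum's median sits in the same range as the ~100 ns RAM access figure.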

In the model we consider, the majority of opcodes have a cost of 1, and they dominate.

For every expensive opcode, during emission, we emit an instruction that increases the compute_meter by the opcode cost minus 1. Opcode extra costs are stored in an array, so the emitter can index directly into it with the u8 insn.opc. The load-and-add instruction is not emitted when the extra cost is 0 (obviously).
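A minimal Python sketch of this emission step. EXTRA_COST, emit and the pseudo-assembly strings are hypothetical illustrations, not rbpf code, and the 29 CU surcharge is a placeholder; only the opcode values (0x07 for add64 with an immediate, 0x79 for ldxdw) come from the eBPF encoding:

```python
# Per-opcode *extra* cost (cost - 1), indexed directly by the u8 opcode.
EXTRA_COST = [0] * 256   # 0 means the opcode costs exactly 1
LDXDW = 0x79             # eBPF ldxdw: load 64 bits from memory
EXTRA_COST[LDXDW] = 29   # hypothetical: a load costs 30 CU in total

def emit(opcodes):
    """Return a pseudo-assembly listing for a sequence of opcodes."""
    listing = []
    for opc in opcodes:
        extra = EXTRA_COST[opc]
        if extra:  # no load-and-add at all when the extra cost is 0
            listing.append(f"add compute_meter, {extra}")
        listing.append(f"op 0x{opc:02x}")
    return listing

print(emit([0x07, 0x79, 0x07]))
# ['op 0x07', 'add compute_meter, 29', 'op 0x79', 'op 0x07']
```

Ordinary opcodes produce no extra instructions, so a program made only of cost-1 opcodes compiles exactly as before.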

Adding a constant value to the compute meter during execution is an ALU operation, so its cost is small compared to that of an expensive operation. For instance, a memory operation will probably always cost at least 30x the CUs of an ALU op, so loading the compute meter (which probably lives in the L1 cache) and adding a constant to it will consume far fewer cycles by comparison.

Another option is to analyse the BPF program and segment it into basic blocks, which reduces the need to deal with arbitrary jumps. (Edit: I guess this is ongoing work with static analysis.)

For every ordinary opcode (cost 1), we emit nothing extra, since the pc integral already charges 1 per instruction.

jon-chuang commented 3 years ago

Thankfully, static analysis is always possible as there are no arbitrary code jumps in eBPF - jumps are always to an inlined offset.
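Because jump targets are immediate offsets, the basic-block segmentation suggested earlier can be done in one linear pass. A minimal sketch, assuming a hypothetical (kind, offset) instruction model instead of real eBPF decoding; as in eBPF, a jump at pc lands on pc + 1 + offset. Register-indirect jumps are deliberately absent from this model:

```python
# Hypothetical instruction model: (kind, offset) tuples, where kind is
#   "op"   ordinary instruction, falls through
#   "ja"   unconditional jump by offset
#   "jcc"  conditional jump by offset (may also fall through)
#   "exit" end of program

def basic_blocks(program):
    """Return basic blocks as half-open (start, end) pc ranges."""
    n = len(program)
    leaders = {0}
    for pc, (kind, off) in enumerate(program):
        if kind in ("ja", "jcc"):
            leaders.add(pc + 1 + off)  # a jump target starts a block
        if kind in ("ja", "jcc", "exit"):
            leaders.add(pc + 1)        # so does the instruction after a branch
    starts = sorted(pc for pc in leaders if 0 <= pc < n)
    ends = starts[1:] + [n]
    return list(zip(starts, ends))

def block_costs(blocks, cost_of):
    """Total cost of each block, which could be charged once on block entry."""
    return [sum(cost_of[s:e]) for s, e in blocks]

prog = [("op", 0), ("jcc", 2), ("op", 0), ("op", 0),
        ("op", 0), ("ja", -4), ("exit", 0)]
print(basic_blocks(prog))  # [(0, 2), (2, 4), (4, 6), (6, 7)]
```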

jon-chuang commented 3 years ago

@jackcmay @Lichtso, because static analysis may be expensive in the malicious or even the non-malicious scenario, I think the above approach may be a good way to perform an "intermediate JIT" when opcode costs are non-uniform.

The static analysis could then be applied only to programs that live for a long time in the executor cache, as a one-time cost, rather than eagerly JITing and statically analysing every program.

So the compilation hierarchy would be:

  1. Interpreted
  2. Inline JIT (above approach)
  3. Static analysis "full" compilation
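One way the hierarchy above could be driven, sketched in Python under loud assumptions: invocation count stands in for "lives a long time in the executor cache", and the promotion thresholds in PROMOTE_AFTER are invented. Nothing here reflects how rbpf's executor cache actually works:

```python
from enum import IntEnum

class Tier(IntEnum):
    INTERPRETED = 1
    INLINE_JIT = 2    # per-opcode extra-cost adds emitted inline
    FULL_COMPILE = 3  # static analysis "full" compilation

# Hypothetical thresholds: invocations spent in a tier before promotion.
PROMOTE_AFTER = {Tier.INTERPRETED: 1, Tier.INLINE_JIT: 100}

class CacheEntry:
    def __init__(self):
        self.tier = Tier.INTERPRETED
        self.invocations = 0

    def on_invoke(self):
        """Return the tier used for this call, then promote if warranted."""
        used = self.tier
        self.invocations += 1
        threshold = PROMOTE_AFTER.get(self.tier)
        if threshold is not None and self.invocations >= threshold:
            self.tier = Tier(self.tier + 1)
            self.invocations = 0
        return used
```

With these thresholds, the first call is interpreted, the next 100 use the inline JIT, and every later call uses the fully compiled version.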
jon-chuang commented 3 years ago

Even if the static analysis is expensive, it would probably be possible to compile asynchronously in a background thread, and silently replace the JITed binary by acquiring a write lock to the cache.
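A sketch of the background upgrade: the expensive compile runs without holding any lock, and only the final swap into the cache is a short critical section. ProgramCache, compile_full and upgrade_in_background are hypothetical names, and since Python's standard library has no reader-writer lock, a plain threading.Lock stands in for the write lock:

```python
import threading

class ProgramCache:
    def __init__(self):
        self._lock = threading.Lock()
        self._binaries = {}  # program id -> compiled artifact

    def get(self, program_id):
        with self._lock:
            return self._binaries.get(program_id)

    def put(self, program_id, binary):
        with self._lock:
            self._binaries[program_id] = binary

def compile_full(program_id):
    # Stand-in for expensive static analysis plus full compilation.
    return f"full:{program_id}"

def upgrade_in_background(cache, program_id):
    def worker():
        binary = compile_full(program_id)  # slow part, no lock held
        cache.put(program_id, binary)      # brief swap under the lock
    thread = threading.Thread(target=worker)
    thread.start()
    return thread
```

Callers keep getting the inline-JIT binary from get() until the swap happens, after which they transparently get the fully compiled one.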

Lichtso commented 3 years ago

Thankfully, static analysis is always possible as there are no arbitrary code jumps in eBPF - jumps are always to an inlined offset.

Unfortunately, we have ebpf::CALL_REG (callx instruction) which is a dynamic (unpredictable) jump.

May I know if a non-uniform opcode compute unit cost table is on the horizon?

Yes, I don't see why not. In fact it is a lot simpler than you think and works with the current approach out-of-the-box. We only have to add an explicit cumulative cost table instead of using the pc directly. There is no need to emit any extra x86 instructions for more expensive BPF instructions.
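The cumulative cost table is a prefix sum over per-instruction costs. A small sketch (cumulative_cost_table is a hypothetical helper name) showing that it reduces to the plain pc when every cost is 1:

```python
from itertools import accumulate

def cumulative_cost_table(cost_of):
    """Entry i is the total cost of instructions 0..i-1 (a prefix sum)."""
    return [0] + list(accumulate(cost_of))

# With a uniform cost of 1 the table is the identity, i.e. the plain pc.
print(cumulative_cost_table([1, 1, 1, 1]))  # [0, 1, 2, 3, 4]

# With non-uniform costs, a straight run [a, b) costs table[b] - table[a].
table = cumulative_cost_table([1, 3, 1, 2])
print(table)                # [0, 1, 4, 5, 7]
print(table[3] - table[1])  # 4, the cost of instructions 1 and 2
```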

Lichtso commented 3 years ago

Here, I sketched a small Python prototype to illustrate how it works:

import random
from itertools import accumulate

def generate_randomized_trace(cost_limit, cost_of):
    # cumulative_cost[i] is the total cost of instructions 0..i-1; with a
    # uniform cost of 1 it is the identity, i.e. the plain pc.
    cumulative_cost = [0] + list(accumulate(cost_of))
    program_length = len(cost_of)
    # The interpreter adds each instruction's cost as it executes it.
    interpreter_instruction_meter = 0
    # The JIT meter starts at the limit and is only rebased at branches.
    jit_instruction_meter = cost_limit
    cost_limit_reached = False
    ip_pc = 0
    while True:
        if ip_pc + 1 >= program_length:
            print("Program reached the end")
            return
        interpreter_instruction_meter += cost_of[ip_pc]
        print("ip/pc={} interpreter={}/{} jit={}/{}".format(ip_pc,
            interpreter_instruction_meter, cost_limit,
            cumulative_cost[ip_pc + 1], jit_instruction_meter,
        ))
        interpreter_limit_reached = cost_limit <= interpreter_instruction_meter
        jit_limit_reached = jit_instruction_meter <= cumulative_cost[ip_pc + 1]
        if (interpreter_limit_reached or jit_limit_reached) and not cost_limit_reached:
            cost_limit_reached = True
            print("Cost limit reached by interpreter={} jit={}".format(
                interpreter_limit_reached, jit_limit_reached))
        # Take a branch to a random target roughly 20% of the time.
        if random.randint(0, 100) < 20:
            target_ip_pc = random.randint(0, program_length - 1)
            # Rebase the JIT meter from the fall-through point onto the target.
            jit_instruction_meter -= cumulative_cost[ip_pc + 1]
            jit_instruction_meter += cumulative_cost[target_ip_pc]
            print("Jump to ip/pc={}".format(target_ip_pc))
            if cost_limit_reached:
                # The JIT only acts on an exhausted budget at a branch.
                print("JIT aborted branch")
                return
            else:
                ip_pc = target_ip_pc
        else:
            ip_pc += 1

generate_randomized_trace(64, [random.randint(1, 3) for _ in range(0, 100)])

For example, one run generates the following trace:

ip/pc=0 interpreter=1/64 jit=1/64
ip/pc=1 interpreter=4/64 jit=4/64
ip/pc=2 interpreter=6/64 jit=6/64
ip/pc=3 interpreter=9/64 jit=9/64
ip/pc=4 interpreter=11/64 jit=11/64
Jump to ip/pc=8
ip/pc=8 interpreter=12/64 jit=17/69
Jump to ip/pc=63
ip/pc=63 interpreter=14/64 jit=128/178
ip/pc=64 interpreter=17/64 jit=131/178
Jump to ip/pc=89
ip/pc=89 interpreter=18/64 jit=177/223
ip/pc=90 interpreter=19/64 jit=178/223
ip/pc=91 interpreter=20/64 jit=179/223
ip/pc=92 interpreter=22/64 jit=181/223
Jump to ip/pc=45
ip/pc=45 interpreter=23/64 jit=90/131
ip/pc=46 interpreter=25/64 jit=92/131
ip/pc=47 interpreter=26/64 jit=93/131
ip/pc=48 interpreter=28/64 jit=95/131
Jump to ip/pc=69
ip/pc=69 interpreter=29/64 jit=142/177
ip/pc=70 interpreter=30/64 jit=143/177
ip/pc=71 interpreter=32/64 jit=145/177
ip/pc=72 interpreter=33/64 jit=146/177
Jump to ip/pc=29
ip/pc=29 interpreter=34/64 jit=59/89
ip/pc=30 interpreter=36/64 jit=61/89
ip/pc=31 interpreter=37/64 jit=62/89
ip/pc=32 interpreter=38/64 jit=63/89
ip/pc=33 interpreter=39/64 jit=64/89
ip/pc=34 interpreter=40/64 jit=65/89
ip/pc=35 interpreter=43/64 jit=68/89
ip/pc=36 interpreter=46/64 jit=71/89
ip/pc=37 interpreter=49/64 jit=74/89
ip/pc=38 interpreter=50/64 jit=75/89
ip/pc=39 interpreter=53/64 jit=78/89
ip/pc=40 interpreter=54/64 jit=79/89
ip/pc=41 interpreter=57/64 jit=82/89
Jump to ip/pc=47
ip/pc=47 interpreter=58/64 jit=93/99
ip/pc=48 interpreter=60/64 jit=95/99
ip/pc=49 interpreter=63/64 jit=98/99
ip/pc=50 interpreter=64/64 jit=99/99
Cost limit reached by interpreter=True jit=True
ip/pc=51 interpreter=67/64 jit=102/99
ip/pc=52 interpreter=68/64 jit=103/99
ip/pc=53 interpreter=71/64 jit=106/99
ip/pc=54 interpreter=73/64 jit=108/99
ip/pc=55 interpreter=76/64 jit=111/99
Jump to ip/pc=75
JIT aborted branch

So the only difference from how it currently works is that cumulative_cost is no longer the identity function, but the lookup can happen at compile time.
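To make the compile-time lookup concrete: for a branch whose target is known statically, both the threshold the JIT compares against and the meter adjustment on the taken path are constants read from cumulative_cost. A sketch mirroring the prototype above; branch_immediates is a hypothetical helper, and a real JIT would encode these values as x86 immediates:

```python
from itertools import accumulate

def branch_immediates(cost_of, branches):
    """Map each branch pc to the (check, adjust) constants a JIT could bake in.

    branches maps a branch's pc to its statically known target.
    check:  the budget is exhausted if meter <= check (jit_limit_reached)
    adjust: added to the meter when the jump is taken
    """
    cumulative_cost = [0] + list(accumulate(cost_of))
    out = {}
    for pc, target in branches.items():
        check = cumulative_cost[pc + 1]
        adjust = cumulative_cost[target] - cumulative_cost[pc + 1]
        out[pc] = (check, adjust)
    return out

print(branch_immediates([1, 3, 1, 2, 2], {3: 1}))  # {3: (7, -6)}
print(branch_immediates([1] * 5, {3: 1}))          # {3: (4, -3)}
```

With uniform costs the constants collapse to pc + 1 and target - (pc + 1), which is the existing pc-based scheme.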

jon-chuang commented 3 years ago

Yep, that integral trick is very elegant. Still makes me worried about long straight-line programs.

Lichtso commented 3 years ago

Still makes me worried about long straight-line programs.

We could emit phony branches if we detect long linear runs, but such programs are not particularly useful anyway. And yes, I am aware that these programs can still be Turing complete and that things like the movfuscator exist, but the movfuscator also shows how limited this approach is, as it bloats the program size tremendously.

Edit: I just remembered that the final return / exit counts as a branch too, so it would check whether the program ran too long.
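A rough sketch of the phony-branch idea: scan the per-instruction costs and place an extra meter check whenever a linear run has accumulated a given amount of cost since the last real branch (the final exit counts as one, per the edit above). phony_checkpoints and max_run_cost are hypothetical names. With this placement, the overrun between checks is bounded by about max_run_cost rather than by the length of the longest straight-line run:

```python
def phony_checkpoints(cost_of, branch_pcs, max_run_cost):
    """Return the pcs after which a phony meter check would be emitted.

    The meter is already checked at every pc in branch_pcs (including the
    final exit). A phony check is added once the cost accumulated since the
    last check reaches max_run_cost.
    """
    checkpoints = []
    since_check = 0
    for pc, cost in enumerate(cost_of):
        since_check += cost
        if pc in branch_pcs:
            since_check = 0  # a real branch checks the meter
        elif since_check >= max_run_cost:
            checkpoints.append(pc)
            since_check = 0
    return checkpoints

print(phony_checkpoints([1] * 10, {9}, 4))  # [3, 7]
```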

jon-chuang commented 3 years ago

@Lichtso my worry is not the detection per se, but rather catching such programs in the act, so that they don't run for, say, 100ms instead of a target of 10ms.

There's actually a separate issue involved, which is that if the program throws an error or panics after it was supposed to have run out of budget, then it may return a different return type. I guess this wouldn't be a problem if every error, or branch into an error, also first checked the compute meter (perhaps this is already subsumed under the conditions for returning).
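The fix suggested in the last sentence, sketched in Python: every error exit consults the meter first, so the reported error does not depend on how far past its budget the program got before failing. The error classes and error_to_report are hypothetical and not taken from rbpf:

```python
# Hypothetical error types standing in for the VM's real ones.
class ExceededComputeBudget(Exception):
    pass

class ProgramError(Exception):
    pass

def error_to_report(err, consumed, limit):
    """Pick the error to report when a program fails at a given meter value.

    The meter is checked first, using the same condition as the prototype
    above (limit <= consumed), so a program that fails after exhausting its
    budget always reports the budget error instead of its own error.
    """
    if limit <= consumed:
        return ExceededComputeBudget(f"consumed {consumed}, limit {limit}")
    return err
```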

Lichtso commented 3 years ago

There's actually a separate issue involved, which is that if the program throws an error or panics after it was supposed to have run out of budget, then it may return a different return type

Interesting idea. It might not be covered so far; we would have to write a test case for that. This is what we currently have: https://github.com/solana-labs/rbpf/blob/42ca251fdb77ce106a57b37f1ad353a0b0207209/src/jit.rs#L436