avanhatt / wasmtime

Standalone JIT-style runtime for WebAssembly, using Cranelift
https://wasmtime.dev/
Apache License 2.0
0 stars 1 forks source link

Verify clz #15

Closed mpardesh closed 1 year ago

mpardesh commented 1 year ago

This PR implements verification for `clz`, which counts the number of leading zero bits in a bitvector (bv).

Hacks to watch out for

How clz works: here's a runnable C program that computes clz of a 64-bit input by binary search, halving the shift amount each round:

#include <stdio.h>
#include <stdint.h>

uint8_t foo(uint64_t x);

int main()
{
    printf("%d\n", foo(16));
    return 0;
}

uint8_t foo(uint64_t x) {
    uint64_t y;
    uint8_t num_zeros = 0;

    y = x >> 32;
    if (y != 0) x = y; else num_zeros += 32;

    y = x >> 16;
    if (y != 0) x = y; else num_zeros += 16;

    y = x >> 8;
    if (y != 0) x = y; else num_zeros += 8;

    y = x >> 4;
    if (y != 0) x = y; else num_zeros += 4;

    y = x >> 2;
    if (y != 0) x = y; else num_zeros += 2;

    y = x >> 1;
    if (y != 0) x = y; else num_zeros += 1;

    if (x == 0) num_zeros += 1;

    return num_zeros;
}

And the same algorithm translated to SMT-LIB, here for a 32-bit input (so the first round shifts by 16 rather than 32):

; Count leading zeros (clz) of the 32-bit bitvector x as straight-line
; SMT-LIB. Each round shifts the remaining window right by half its width:
; if the shifted value is non-zero the window narrows to it, otherwise the
; discarded upper half was all zeros and the shift amount is added to the
; running count. The count is threaded through ret0..ret6, and each round's
; variables are named after its shift amount (y16/x16, y8/x8, ...).
; Every declaration and assertion is kept on its own line with no trailing
; comment, because the Python converter processes this text line by line.
(declare-fun x () (_ BitVec 32))
; placeholder to test x = 5 (expected result: 29)
(assert (= x (_ bv5 32)))

; total zeros counter, starts at 0
(declare-fun ret0 () (_ BitVec 32))
(assert (= ret0 (_ bv0 32)))

; round 1: shift by 16 (#x00000010 is the 32-bit constant 16)
(declare-fun ret1 () (_ BitVec 32))
(declare-fun y16 () (_ BitVec 32))
(declare-fun x16 () (_ BitVec 32))

(assert (= y16 (bvlshr x #x00000010)))
(assert (ite (not (= y16 (_ bv0 32))) (= ret1 ret0) (= ret1 (bvadd ret0 (_ bv16 32)))))
(assert (ite (not (= y16 (_ bv0 32))) (= x16 y16) (= x16 x)))

; round 2: shift by 8
(declare-fun ret2 () (_ BitVec 32))
(declare-fun y8 () (_ BitVec 32))
(declare-fun x8 () (_ BitVec 32))

(assert (= y8 (bvlshr x16 #x00000008)))
(assert (ite (not (= y8 (_ bv0 32))) (= ret2 ret1) (= ret2 (bvadd ret1 (_ bv8 32)))))
(assert (ite (not (= y8 (_ bv0 32))) (= x8 y8) (= x8 x16)))

; round 3: shift by 4
(declare-fun ret3 () (_ BitVec 32))
(declare-fun y4 () (_ BitVec 32))
(declare-fun x4 () (_ BitVec 32))

(assert (= y4 (bvlshr x8 #x00000004)))
(assert (ite (not (= y4 (_ bv0 32))) (= ret3 ret2) (= ret3 (bvadd ret2 (_ bv4 32)))))
(assert (ite (not (= y4 (_ bv0 32))) (= x4 y4) (= x4 x8)))

; round 4: shift by 2
(declare-fun ret4 () (_ BitVec 32))
(declare-fun y2 () (_ BitVec 32))
(declare-fun x2 () (_ BitVec 32))

(assert (= y2 (bvlshr x4 #x00000002)))
(assert (ite (not (= y2 (_ bv0 32))) (= ret4 ret3) (= ret4 (bvadd ret3 (_ bv2 32)))))
(assert (ite (not (= y2 (_ bv0 32))) (= x2 y2) (= x2 x4)))

; round 5: shift by 1
(declare-fun ret5 () (_ BitVec 32))
(declare-fun y1 () (_ BitVec 32))
(declare-fun x1 () (_ BitVec 32))

(assert (= y1 (bvlshr x2 #x00000001)))
(assert (ite (not (= y1 (_ bv0 32))) (= ret5 ret4) (= ret5 (bvadd ret4 (_ bv1 32)))))
(assert (ite (not (= y1 (_ bv0 32))) (= x1 y1) (= x1 x2)))

; last round: x1 is now 0 or 1, and it is 0 only when x itself was 0;
; that case contributes one more zero, for a total of 32
(declare-fun ret6 () (_ BitVec 32))
(assert (ite (not (= x1 (_ bv0 32))) (= ret6 ret5) (= ret6 (bvadd ret5 (_ bv1 32)))))

; final return
(declare-fun ret () (_ BitVec 32))
(assert (= ret ret6))

; solver commands for running this file directly; the Python converter
; does not translate them
(check-sat)
(get-model)

And a Python script that converts the SMT above into Rust statements, which add its declarations and assertions to the solver:

import sys
import re

# Path of the SMT-LIB file to convert, given as the first CLI argument.
# Fail with a usage message rather than an IndexError traceback.
if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} <file.smt2>")
filename = sys.argv[1]
# Prefixes identifying the two SMT-LIB line kinds the converter translates.
decl = "(declare-fun "
assertion = "(assert "
# Matches a `{name}` placeholder; each becomes a named argument of a Rust
# `format!` call, bound to the Rust variable of the same name.
pattern = re.compile(r'\{.*?\}')


def _rust_string(text):
    """Return Rust source for an expression that evaluates to *text* as a String.

    If *text* contains ``{name}`` placeholders, the result is
    ``format!("...", name = name, ...)``. Otherwise it is
    ``String::from("...")``. Backslashes and double quotes are escaped so the
    emitted Rust string literal is well-formed.
    """
    literal = text.replace('\\', '\\\\').replace('"', '\\"')
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # emitted argument list is stable across runs (a set's iteration order
    # is not, because str hashing is randomized per process).
    names = list(dict.fromkeys(m[1:-1] for m in pattern.findall(text)))
    if not names:
        return f'String::from("{literal}")'
    named_params = ', '.join(f'{n} = {n}' for n in names)
    return f'format!("{literal}", {named_params})'


def parse_decl(line):
    """Convert ``(declare-fun <name> () <type>)`` into Rust ``(name, type)`` exprs.

    Only nullary declarations written on a single line are supported. Both
    returned strings are Rust expressions evaluating to ``String``; any
    ``{var}`` placeholders in the name or the type are filled via ``format!``.
    """
    line = line.strip()
    name = line.split()[1]
    # The sort is everything after the empty argument list "()", minus the
    # separating space and the closing paren of the declare-fun form.
    sort = line.split("()", 1)[1][1:-1]
    return _rust_string(name), _rust_string(sort)


def parse_assertion(line):
    """Convert ``(assert <term>)`` into a Rust expression evaluating to ``<term>``.

    The line must hold exactly one assert form with no trailing comment.
    ``{var}`` placeholders in the term are filled via ``format!``.
    """
    term = line.strip()[len(assertion):-1]
    return _rust_string(term)

def _split_inline_comment(text):
    """Split an SMT-LIB line into ``(code, comment)`` at its first comment.

    A ``;`` starts a comment unless it is inside a ``"..."`` string literal
    or a ``|...|`` quoted symbol. ``code`` has trailing whitespace removed;
    ``comment`` is the stripped text after the ``;``, or ``""`` if there is none.
    """
    in_string = in_symbol = False
    for i, ch in enumerate(text):
        if ch == '"' and not in_symbol:
            # SMT-LIB escapes a quote inside a string as "", which toggles
            # twice and so leaves the state unchanged.
            in_string = not in_string
        elif ch == '|' and not in_string:
            in_symbol = not in_symbol
        elif ch == ';' and not in_string and not in_symbol:
            return text[:i].rstrip(), text[i + 1:].strip()
    return text, ""


# Convert each SMT-LIB line to Rust and print the result to stdout.
# The generated statements assume `self` is accessible where they are pasted.
with open(filename, encoding='utf-8') as f:
    for raw in f:
        line = raw.strip()

        # leave blank lines
        if not line:
            print("")
            continue

        # convert whole-line comments (`;`, `;;`, ...) to `//` comments
        if line.startswith(';'):
            print(f'//{line.lstrip(";")}')
            continue

        # Detach a trailing comment such as `(assert ...) ; note`: the
        # parsers assume the line ends with the closing paren of its form.
        code, comment = _split_inline_comment(line)
        trailer = f' // {comment}' if comment else ''

        if code.startswith(decl):
            name, ty = parse_decl(code)
            print(f'self.additional_decls.push(({name}, {ty}));{trailer}')
        elif code.startswith(assertion):
            a = parse_assertion(code)
            print(f'self.additional_assumptions.push({a});{trailer}')
        else:
            # Commands such as (check-sat) and (get-model) have no Rust
            # translation here; report them instead of dropping them silently.
            print(f'skipping unsupported line: {line}', file=sys.stderr)
avanhatt commented 1 year ago

Overall, this is GREAT! Some comments inline but I think we might want to review this on a synchronous call.