AnyDSL / thorin

The Higher-Order Intermediate Representation
https://anydsl.github.io
GNU Lesser General Public License v3.0
151 stars 15 forks source link

Partial evaluation/hashing bug #62

Closed madmann91 closed 7 years ago

madmann91 commented 7 years ago

There seems to be some sort of bug in the hashing/partial evaluation:

impala: /space/perard/sources/anydsl/thorin/src/thorin/util/hash.h:113: void thorin::HashTable<Key, T, H>::iterator_base<is_const>::verify() const [with bool is_const = true; Key = thorin::Use; T = void; H = thorin::UseHash]: Assertion `table_->id_ == id_' failed.

This bug is triggered by the following piece of code when compiled with `-emit-llvm`. Interestingly, if you replace `let iterate = |begin, end, body| { vectorize(4, begin, end, body); }` with `let iterate = range;` (in the function `traverse_single`), then the code compiles perfectly. Moving the partial evaluation symbols also triggers another assertion:

impala: /space/perard/sources/anydsl/thorin/src/thorin/transform/partial_evaluation.cpp:135: void thorin::PartialEvaluator::eval(thorin::Continuation*, thorin::Continuation*): Assertion `ncur != nullptr' failed.

This is the culprit:

// Iteration function: invoked as iterate(begin, end, body); `traverse` uses it
// through `for i in iterate(0, 4)` and `traverse_single` supplies the lambda.
type IterateFn = fn (i32, i32, fn(i32) -> ()) -> ();

// Functions declared with "thorin" linkage; no bodies are given in this file.
extern "thorin" {
    // Called in `traverse_single` as vectorize(4, begin, end, body).
    fn vectorize(i32, i32, i32, fn(i32) -> ()) -> ();
    // Used in `prodsign` as bitcast[f32](..) / bitcast[i32](..).
    fn bitcast[D, S](S) -> D;
    // Used in `iminf` / `imaxf` as select(condition, a, b).
    fn select[A, B](A, B, B) -> B;
}

// Starting value of the running minimum in the "sort" step of `traverse`.
static flt_max = 1.0e+37f;

// Calls body(a), then recurses with a + 1, until a >= b; does nothing when a >= b.
// NOTE(review): `return` is passed as an extra argument to the recursive call —
// presumably the function's return continuation; confirm against Impala semantics.
fn range(a: i32, b: i32, body: fn(i32) -> ()) -> () {
    if a < b {
        body(a);
        range(a + 1, b, body, return)
    }
}

// Same recursion as `range` (body(a), then recurse with a + 1 while a < b), but the
// function body and the branch carry `@` markers.
// NOTE(review): the `@` markers are presumably the "partial evaluation symbols"
// mentioned in the issue text — confirm.
fn unroll(a: i32, b: i32, body: fn(i32) -> ()) -> () @{
    if a < b @{
        body(a);
        unroll(a + 1, b, body, return)
    }
}

// Vector of size 3 (three f32 components: x, y, z)
struct Vec3 {
    x: f32, y: f32, z: f32
}

// Node for a 4-ary BVH.
// `traverse` reads min_*(i) / max_*(i) to build the bounding box of child i
// and child(i) to decide whether that child is pushed on the stack.
struct Node {
    // Minimum bounding box coord. for 4 children
    min_x: [f32 * 4],
    min_y: [f32 * 4],
    min_z: [f32 * 4],

    // Maximum bounding box coord. for 4 children
    max_x: [f32 * 4],
    max_y: [f32 * 4],
    max_z: [f32 * 4],

    // Child index (>0: inner node, 0: disabled, <0: leaf)
    child: [i32 * 4]
}

// Flattened triangle: each array holds one component for 4 packed triangles.
struct FlatTri {
    // Packed x coords. for v0, e1, e2 and normal
    v0_x: [f32 * 4],
    e1_x: [f32 * 4],
    e2_x: [f32 * 4],
     n_x: [f32 * 4],

    // Packed y coords. for v0, e1, e2 and normal
    v0_y: [f32 * 4],
    e1_y: [f32 * 4],
    e2_y: [f32 * 4],
     n_y: [f32 * 4],

    // Packed z coords. for v0, e1, e2 and normal
    v0_z: [f32 * 4],
    e1_z: [f32 * 4],
    e2_z: [f32 * 4],
     n_z: [f32 * 4],

    // Index of the triangle (<0: sentinel, >=0: valid triangle index)
    // `traverse` stops walking a leaf's triangles once id(3) < 0.
    id:   [i32 * 4]
}

// Manual stack for the traversal, exposed as a record of closures.
// The behaviour noted per field is the one implemented by `alloc_stack`.
struct Stack {
    // Stores a node in the backing array without changing the current top.
    swap: fn(i32) -> (),
    // Saves the current top in the backing array and makes the node the new top.
    push: fn(i32) -> (),
    // Returns the current top and reloads the top from the backing array.
    pop:  fn() -> i32,
    // True when the top equals the sentinel value.
    is_empty: fn() -> bool
}

// Structure that holds the hit information.
// `traverse` initialises these as t = ray.tmax, u = v = 0, id = -1.
struct Hit {
    t:  f32,
    u:  f32,
    v:  f32,
    id: i32
}

// Structure that holds a ray.
// As filled in by `ray`: idir is the component-wise reciprocal of dir and
// oidir is the component-wise product org * idir.
struct Ray {
    org:   Vec3,
    dir:   Vec3,
    oidir: Vec3,
    idir:  Vec3,
    tmin:  f32,
    tmax:  f32
}

// Builds a Vec3 from its three components.
fn vec3(x: f32, y: f32, z: f32) -> Vec3 {
    Vec3 { x: x, y: y, z: z }
}

// Component-wise difference a - b.
fn vec3_sub(a: Vec3, b: Vec3) -> Vec3 {
    Vec3 {
        x: a.x - b.x,
        y: a.y - b.y,
        z: a.z - b.z
    }
}

// Component-wise product a * b (not a dot or cross product).
fn vec3_mul(a: Vec3, b: Vec3) -> Vec3 {
    Vec3 {
        x: a.x * b.x,
        y: a.y * b.y,
        z: a.z * b.z
    }
}

// Cross product a x b.
fn vec3_cross(a: Vec3, b: Vec3) -> Vec3 {
    Vec3 {
        x: a.y * b.z - a.z * b.y,
        y: a.z * b.x - a.x * b.z,
        z: a.x * b.y - a.y * b.x
    }
}

// Dot product of a and b.
fn vec3_dot(a: Vec3, b: Vec3) -> f32 {
    a.x * b.x + a.y * b.y + a.z * b.z
}

// Builds a Ray from origin, direction and the [tmin, tmax] interval, and
// precomputes the two derived vectors used by `intersect_ray_box`.
fn ray(org: Vec3, dir: Vec3, tmin: f32, tmax: f32) -> Ray {
    // Component-wise reciprocal of the direction; no guard against zero components is visible here.
    let idir = vec3(1.0f / dir.x, 1.0f / dir.y, 1.0f / dir.z);
    // org * idir, so that b * idir - oidir equals (b - org) / dir per component.
    let oidir = vec3_mul(org, idir);
    Ray {
        org: org,
        dir: dir,
        idir: idir,
        oidir: oidir,
        tmin: tmin,
        tmax: tmax
    }
}

// Creates a Stack whose closures share the local state below: the current top
// is kept in `top`, the remaining entries in `items`, and `ptr` is the next
// free slot in `items`.
// NOTE(review): none of the closures checks `ptr` against the 64-entry array,
// and `pop` does not check for emptiness (calling it with ptr == 0 decrements
// ptr below 0) — confirm callers never overflow or over-pop.
fn alloc_stack() -> Stack {
    // Value of `top` that marks the stack as empty.
    let sentinel = 0x7FFFFFFF;
    let mut top : i32 = sentinel;
    let mut items : [i32 * 64];
    let mut ptr = 0;
    Stack {
        // Stores `node` in `items` and leaves `top` unchanged.
        swap: |node| { items(ptr++) = node; },
        // Saves the old `top` in `items`, then makes `node` the new top.
        push: |node| { items(ptr++) = top; top = node; },
        // Returns the old `top` and reloads `top` from the last stored item.
        pop:  || { let old = top; top = items(--ptr); old },
        is_empty: || { top == sentinel }
    }
}

// A node index is a leaf when it is negative (matches the `child` convention in Node).
fn is_leaf(node: i32) -> bool { node < 0 }

// Returns select(a < b, a, b) using a floating-point comparison.
// NOTE(review): presumably `select` yields a when the condition holds and b
// otherwise, i.e. the minimum — confirm `select` semantics.
fn iminf(a: f32, b: f32) -> f32 {
    // Disabled alternative that compares the bit patterns as integers:
    /*let (a_, b_) = (bitcast[i32](a), bitcast[i32](b));
    bitcast[f32](select(a_ < b_, a_, b_))*/
    select(a < b, a, b)
}
// Returns select(a > b, a, b) using a floating-point comparison.
// NOTE(review): presumably `select` yields a when the condition holds and b
// otherwise, i.e. the maximum — confirm `select` semantics.
fn imaxf(a: f32, b: f32) -> f32 {
    // Disabled alternative that compares the bit patterns as integers:
    /*let (a_, b_) = (bitcast[i32](a), bitcast[i32](b));
    bitcast[f32](select(a_ > b_, a_, b_))*/
    select(a > b, a, b)
}
// Three-operand combinations of iminf / imaxf. The name spells the order of
// application: the first operation combines a and b, the second combines that result with c.
fn iminminf(a: f32, b: f32, c: f32) -> f32 { iminf(iminf(a, b), c) }
fn iminmaxf(a: f32, b: f32, c: f32) -> f32 { imaxf(iminf(a, b), c) }
fn imaxminf(a: f32, b: f32, c: f32) -> f32 { iminf(imaxf(a, b), c) }
fn imaxmaxf(a: f32, b: f32, c: f32) -> f32 { imaxf(imaxf(a, b), c) }

// Absolute value: negates x when x < 0.0, otherwise returns x unchanged.
fn fabsf(x: f32) -> f32 { if x < 0.0f { -x } else { x } }

// XORs the bit pattern of x with bit 31 of y's bit pattern (mask 0x80000000,
// the IEEE-754 sign bit of an f32), so the result is x with its sign flipped
// when y's sign bit is set.
fn prodsign(x: f32, y: f32) -> f32 {
    bitcast[f32](bitcast[i32](x) ^ (bitcast[i32](y) & bitcast[i32](0x80000000u)))
}

// Ray/box intersection over the interval [tmin, tmax].
// Computes the per-axis entry/exit parameters from bmin/bmax using the
// precomputed idir and oidir (see `ray`), reduces them to an overall entry t0
// and exit t1, and calls intr(t0, t1) only when t0 <= t1.
fn intersect_ray_box(bmin: Vec3, bmax: Vec3, oidir: Vec3, idir: Vec3, tmin: f32, tmax: f32, intr: fn(f32, f32) -> ()) -> () @{
    // Largest of: min(a, b), min(c, d), min(e, f) and g (g is passed tmin below).
    fn span_begin(a: f32, b: f32, c: f32, d: f32, e: f32, f: f32, g: f32) -> f32 {
        imaxmaxf(iminf(a, b), iminf(c, d), iminmaxf(e, f, g))
    }

    // Smallest of: max(a, b), max(c, d), max(e, f) and g (g is passed tmax below).
    fn span_end(a: f32, b: f32, c: f32, d: f32, e: f32, f: f32, g: f32) -> f32 {
        iminminf(imaxf(a, b), imaxf(c, d), imaxminf(e, f, g))
    }

    // Ray parameters at the two bounding planes of each axis.
    let t0_x = bmin.x * idir.x - oidir.x;
    let t1_x = bmax.x * idir.x - oidir.x;
    let t0_y = bmin.y * idir.y - oidir.y;
    let t1_y = bmax.y * idir.y - oidir.y;
    let t0_z = bmin.z * idir.z - oidir.z;
    let t1_z = bmax.z * idir.z - oidir.z;

    let t0 = span_begin(t0_x, t1_x, t0_y, t1_y, t0_z, t1_z, tmin);
    let t1 = span_end  (t0_x, t1_x, t0_y, t1_y, t0_z, t1_z, tmax);

    if (t0 <= t1) { intr(t0, t1) }
}

// Ray/triangle intersection over [tmin, tmax].
// Works with quantities scaled by |det| and only divides once a hit is
// confirmed; on a hit it calls intr(t, u, v) with the unscaled values.
// NOTE(review): n is presumably the triangle normal matching e1 and e2 (FlatTri
// stores "v0, e1, e2 and normal") — confirm its orientation/scale convention.
fn intersect_ray_tri(v0: Vec3, e1: Vec3, e2: Vec3, n: Vec3, org: Vec3, dir: Vec3, tmin: f32, tmax: f32, intr: fn(f32, f32, f32) -> ()) -> () @{
    let c = vec3_sub(v0, org);
    let r = vec3_cross(dir, c);
    let det = vec3_dot(n, dir);
    let abs_det = fabsf(det);

    // u, v, w are scaled by |det|; prodsign applies det's sign bit to each dot product.
    let u = prodsign(vec3_dot(r, e2), det);
    let v = prodsign(vec3_dot(r, e1), det);
    let w = abs_det - u - v;

    if u >= 0.0f & v >= 0.0f & w >= 0.0f {
        let t = prodsign(vec3_dot(n, c), det);
        // Range test done on the scaled t, i.e. against |det| * tmin and |det| * tmax.
        if t >= abs_det * tmin & abs_det * tmax >= t {
            let inv_det = 1.0f / abs_det;
            intr(t * inv_det, u * inv_det, v * inv_det);
        }
    }
}

// Stack-based traversal of the 4-ary BVH for a single ray.
// Repeatedly pops a node, tests its 4 child boxes through `iterate`, pushes the
// children that were hit, and then walks the triangles of leaf nodes.
// Returns a Hit built from t_hit / u_hit / v_hit / id_hit. The triangle test in
// the leaf loop is commented out, so as written these keep their initial values
// (ray.tmax, 0, 0, -1).
fn traverse(iterate: IterateFn, stack: Stack, nodes: &[Node * 1], tris: &[FlatTri * 1], ray: Ray) -> Hit @{
    let mut t_hit = ray.tmax;
    let mut u_hit = 0.0f;
    let mut v_hit = 0.0f;
    let mut id_hit = -1;

    while !stack.is_empty() {
        let mut node_id = stack.pop();
        let node = nodes(node_id);

        // Intersect each child
        // found(i) is set only for children whose box is hit and whose index is non-zero;
        // entry(i) is the box entry distance and is only written for hit boxes.
        let mut found = [false, false, false, false];
        let mut entry : [f32 * 4]; 
        for i in iterate(0, 4) @{
            let min = vec3(node.min_x(i), node.min_y(i), node.min_z(i));
            let max = vec3(node.max_x(i), node.max_y(i), node.max_z(i));
            with t0, t1 in intersect_ray_box(min, max, ray.oidir, ray.idir, ray.tmin, t_hit) @{
                found(i) = node.child(i) != 0;
                entry(i) = t0;
            }
        }

        // "Sort" them
        // A child closer than every one seen so far becomes the new stack top (push);
        // any other hit child is stored beneath the top (swap).
        let mut tmin = flt_max;
        for i in @unroll(0, 4) {
            if found(i) {
                if tmin > entry(i) @{
                    stack.push(node.child(i));
                    tmin = entry(i)
                } else @{
                    stack.swap(node.child(i))
                }
            }
        }

        // Process leaves for as long as the current node id is negative.
        while is_leaf(node_id) @{
            // NOTE(review): presumably bitwise NOT, mapping the negative leaf id to a
            // non-negative index into `tris` — confirm.
            let mut tri_id = !node_id;
            while true {
                let cur = tris(tri_id);

                for i in iterate(0, 4) @{
                    /*let v0 = vec3(cur.v0_x(i), cur.v0_y(i), cur.v0_z(i));
                    let e1 = vec3(cur.e1_x(i), cur.e1_y(i), cur.e1_z(i));
                    let e2 = vec3(cur.e2_x(i), cur.e2_y(i), cur.e2_z(i));
                    let n  = vec3(cur. n_x(i), cur. n_y(i), cur. n_z(i));

                    with t, u, v in intersect_ray_tri(v0, e1, e2, n, ray.org, ray.dir, ray.tmin, t_hit) @{
                        t_hit = t;
                        u_hit = u;
                        v_hit = v;
                        id_hit = cur.id(i);
                    }*/
                }

                // A negative id in the last slot is the sentinel that ends this leaf's triangle list.
                if cur.id(3) < 0 { break() }
                tri_id++;
            }

            node_id = stack.pop()
        }
    }

    Hit { t: t_hit, u: u_hit, v: v_hit, id: id_hit }
}

// Traverses the BVH for one ray: builds the iteration function, allocates a
// stack seeded with node index 0, and delegates to `traverse`.
fn traverse_single(nodes: &[Node * 1], tris: &[FlatTri * 1], ray: Ray) -> Hit @{
    // Iteration goes through the `vectorize` intrinsic with 4 as its first argument.
    // The issue text reports that using `range` here instead makes the code compile.
    let iterate = |begin, end, body| { vectorize(4, begin, end, body) };
    let stack = alloc_stack();
    stack.push(0);
    traverse(iterate, stack, nodes, tris, ray)
}

// Exported entry point: builds a Ray from the pointed-to origin and direction
// plus tmin/tmax, runs `traverse_single`, and stores the result through `hit`.
extern fn from_c_traverse_single(nodes: &[Node * 1], tris: &[FlatTri * 1], org: &Vec3, dir: &Vec3, tmin: f32, tmax: f32, hit: &Hit) -> () @{
    *hit = traverse_single(nodes, tris, ray(*org, *dir, tmin, tmax));
}
leissa commented 7 years ago

I won't have time to debug today, but I guess the bug is something like this:

hashmap[key] = do_sth_which_provokes_a_rehash(foo);

The fix would be:

auto x = do_sth_which_provokes_a_rehash(foo);
hashmap[key] = x;

Can you give me a backtrace?

madmann91 commented 7 years ago

Fixed.