google / xls

XLS: Accelerated HW Synthesis
http://google.github.io/xls/
Apache License 2.0
1.21k stars 179 forks source link

Miscompare between IR JIT and DSLX interpreter for unroll_for! #1686

Open mikex-oss opened 3 weeks ago

mikex-oss commented 3 weeks ago

Describe the bug: `xls_dslx_test` with `compare` set to `"jit"` fails for the repro below with:

: internal error: INTERNAL: IR JIT produced a different value from the DSL interpreter for __foo__foo; IR JIT: bits[32]:0 DSL interpreter: bits[32]:511

I initially thought this might be working as intended (WAI), since there are no ordering guarantees, but @grebe pointed out that this doesn't make sense given the accumulator value.

To Reproduce

// Repro for the IR JIT vs. DSLX interpreter miscompare.
//
// The outer `unroll_for!` runs j = 0..9. For each j, the inner `for` runs
// k = 0 .. 2^j - 1, and its body evaluates to `k`. The accumulator handed to
// the next outer iteration is therefore the last k, i.e. 2^j - 1. After the
// final outer iteration (j = 9) the expected result is 2^9 - 1 = 511.
fn foo() -> u32 {
    unroll_for! (j, result): (u32, u32) in range(u32:0, u32:10) {
        // The inner loop ignores its incoming accumulator (`_`) and yields `k`;
        // `result` only seeds it.
        for (k, _): (u32, u32) in range(u32:0, std::upow(u32:2, j)) {
            k
        }(result)
    }(u32:0)
}

// Checks foo() against 511 (= 2^9 - 1, the last k of the inner loop in the
// final outer iteration j = 9). The DSLX interpreter agrees with this value;
// the IR JIT returns 0.
#[test]
fn foo_test() { assert_eq(foo(), u32:511); }

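Expected behavior: Both results should be bits[32]:511. Each inner loop evaluates to its last index, k = 2^j - 1, so after the final outer iteration (j = 9) the result is 2^9 - 1 = 511.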

Additional context: The converted IR is shown below:

package foo

file_number 0 "xls/dslx/stdlib/std.x"
file_number 1 "foo.x"

// Loop body of the 32-iteration counted_for in __std__upow__32
// (square-and-multiply).
// Carry layout, per tuple.5 in the caller and tuple.30 below:
//   (n: remaining exponent bits, p: current power of the base, result)
// Each step does three things:
//   - multiplies result by p when the low bit of n is 1;
//   - shifts n right by one;
//   - squares p.
// The following nodes are not reachable from the returned tuple:
//   literal.7, add.8, literal.10, literal.12, and.13, literal.15, and.16,
//   literal.18, and.19, literal.20
fn ____std__upow__32_counted_for_0_body(i: bits[32] id=6, __loop_carry: (bits[32], bits[32], bits[32]) id=9) -> (bits[32], bits[32], bits[32]) {
  tuple_index.11: bits[32] = tuple_index(__loop_carry, index=0, id=11)
  literal.21: bits[32] = literal(value=1, id=21, pos=[(0,634,29)])
  literal.10: bits[1] = literal(value=1, id=10)
  literal.12: bits[1] = literal(value=1, id=12)
  and.22: bits[32] = and(tuple_index.11, literal.21, id=22, pos=[(0,634,27)])
  literal.23: bits[32] = literal(value=1, id=23, pos=[(0,634,41)])
  tuple_index.17: bits[32] = tuple_index(__loop_carry, index=2, id=17)
  tuple_index.14: bits[32] = tuple_index(__loop_carry, index=1, id=14)
  and.13: bits[1] = and(literal.10, literal.12, id=13)
  literal.15: bits[1] = literal(value=1, id=15)
  literal.27: bits[1] = literal(value=1, id=27, pos=[(0,636,14)])
  eq.24: bits[1] = eq(and.22, literal.23, id=24, pos=[(0,634,38)])
  umul.25: bits[32] = umul(tuple_index.17, tuple_index.14, id=25, pos=[(0,634,58)])
  literal.7: bits[32] = literal(value=0, id=7)
  and.16: bits[1] = and(and.13, literal.15, id=16)
  literal.18: bits[1] = literal(value=1, id=18)
  shrl.28: bits[32] = shrl(tuple_index.11, literal.27, id=28, pos=[(0,636,11)])
  umul.29: bits[32] = umul(tuple_index.14, tuple_index.14, id=29, pos=[(0,636,19)])
  // Selector 0 (low bit of n clear) keeps result; selector 1 takes result * p.
  result: bits[32] = sel(eq.24, cases=[tuple_index.17, umul.25], id=26, pos=[(0,634,21)])
  add.8: bits[32] = add(i, literal.7, id=8)
  and.19: bits[1] = and(and.16, literal.18, id=19)
  literal.20: bits[32] = literal(value=32, id=20, pos=[(0,629,12)])
  // Next carry: (n >> 1, p * p, result).
  ret tuple.30: (bits[32], bits[32], bits[32]) = tuple(shrl.28, umul.29, result, id=30, pos=[(0,636,8)])
}

// Unsigned power: returns p raised to the n-th power, truncated to 32 bits.
// It runs a 32-iteration counted_for over the carry (n, p, result), starting
// from result = 1, and returns element 2 of the final carry.
// N (32) and literal.32 do not feed the return value.
// NOTE: __foo__foo in this package never invokes this function; its inner-loop
// trip counts appear as literals (1, 2, 4, ..., 512) instead.
fn __std__upow__32(p: bits[32] id=1, n: bits[32] id=2) -> bits[32] {
  result: bits[32] = literal(value=1, id=4, pos=[(0,630,17)])
  tuple.5: (bits[32], bits[32], bits[32]) = tuple(n, p, result, id=5, pos=[(0,637,6)])
  work: (bits[32], bits[32], bits[32]) = counted_for(tuple.5, trip_count=32, stride=1, body=____std__upow__32_counted_for_0_body, id=31)
  N: bits[32] = literal(value=32, id=3, pos=[(0,629,12)])
  literal.32: bits[32] = literal(value=2, id=32, pos=[(0,638,9)])
  ret tuple_index.33: bits[32] = tuple_index(work, index=2, id=33, pos=[(0,638,8)])
}

// Inner-loop body for unrolled outer iteration j = 0. __foo__foo runs it via
// counted_for with trip_count=1 (= 2^0).
// add.37 (k + 0) is presumably the DSLX-level `k` (range start 0 plus the
// induction variable). It is unused: the body returns its __loop_carry
// parameter unchanged, whereas the DSLX body evaluates to `k`.
// NOTE(review): likely should return add.37 -- TODO confirm in the converter.
fn ____foo__foo_counted_for_0_body(k: bits[32] id=35, __loop_carry: bits[32] id=38) -> bits[32] {
  literal.36: bits[32] = literal(value=0, id=36)
  add.37: bits[32] = add(k, literal.36, id=37)
  ret __loop_carry: bits[32] = param(name=__loop_carry, id=38)
}

// Inner-loop body for unrolled outer iteration j = 1. __foo__foo runs it via
// counted_for with trip_count=2 (= 2^1).
// add.42 (k + 0) is presumably the DSLX-level `k` (range start 0 plus the
// induction variable). It is unused: the body returns its __loop_carry
// parameter unchanged, whereas the DSLX body evaluates to `k`.
// NOTE(review): likely should return add.42 -- TODO confirm in the converter.
fn ____foo__foo_counted_for_1_body(k: bits[32] id=40, __loop_carry: bits[32] id=43) -> bits[32] {
  literal.41: bits[32] = literal(value=0, id=41)
  add.42: bits[32] = add(k, literal.41, id=42)
  ret __loop_carry: bits[32] = param(name=__loop_carry, id=43)
}

// Inner-loop body for unrolled outer iteration j = 2. __foo__foo runs it via
// counted_for with trip_count=4 (= 2^2).
// add.47 (k + 0) is presumably the DSLX-level `k` (range start 0 plus the
// induction variable). It is unused: the body returns its __loop_carry
// parameter unchanged, whereas the DSLX body evaluates to `k`.
// NOTE(review): likely should return add.47 -- TODO confirm in the converter.
fn ____foo__foo_counted_for_2_body(k: bits[32] id=45, __loop_carry: bits[32] id=48) -> bits[32] {
  literal.46: bits[32] = literal(value=0, id=46)
  add.47: bits[32] = add(k, literal.46, id=47)
  ret __loop_carry: bits[32] = param(name=__loop_carry, id=48)
}

// Inner-loop body for unrolled outer iteration j = 3. __foo__foo runs it via
// counted_for with trip_count=8 (= 2^3).
// add.52 (k + 0) is presumably the DSLX-level `k` (range start 0 plus the
// induction variable). It is unused: the body returns its __loop_carry
// parameter unchanged, whereas the DSLX body evaluates to `k`.
// NOTE(review): likely should return add.52 -- TODO confirm in the converter.
fn ____foo__foo_counted_for_3_body(k: bits[32] id=50, __loop_carry: bits[32] id=53) -> bits[32] {
  literal.51: bits[32] = literal(value=0, id=51)
  add.52: bits[32] = add(k, literal.51, id=52)
  ret __loop_carry: bits[32] = param(name=__loop_carry, id=53)
}

// Inner-loop body for unrolled outer iteration j = 4. __foo__foo runs it via
// counted_for with trip_count=16 (= 2^4).
// add.57 (k + 0) is presumably the DSLX-level `k` (range start 0 plus the
// induction variable). It is unused: the body returns its __loop_carry
// parameter unchanged, whereas the DSLX body evaluates to `k`.
// NOTE(review): likely should return add.57 -- TODO confirm in the converter.
fn ____foo__foo_counted_for_4_body(k: bits[32] id=55, __loop_carry: bits[32] id=58) -> bits[32] {
  literal.56: bits[32] = literal(value=0, id=56)
  add.57: bits[32] = add(k, literal.56, id=57)
  ret __loop_carry: bits[32] = param(name=__loop_carry, id=58)
}

// Inner-loop body for unrolled outer iteration j = 5. __foo__foo runs it via
// counted_for with trip_count=32 (= 2^5).
// add.62 (k + 0) is presumably the DSLX-level `k` (range start 0 plus the
// induction variable). It is unused: the body returns its __loop_carry
// parameter unchanged, whereas the DSLX body evaluates to `k`.
// NOTE(review): likely should return add.62 -- TODO confirm in the converter.
fn ____foo__foo_counted_for_5_body(k: bits[32] id=60, __loop_carry: bits[32] id=63) -> bits[32] {
  literal.61: bits[32] = literal(value=0, id=61)
  add.62: bits[32] = add(k, literal.61, id=62)
  ret __loop_carry: bits[32] = param(name=__loop_carry, id=63)
}

// Inner-loop body for unrolled outer iteration j = 6. __foo__foo runs it via
// counted_for with trip_count=64 (= 2^6).
// add.67 (k + 0) is presumably the DSLX-level `k` (range start 0 plus the
// induction variable). It is unused: the body returns its __loop_carry
// parameter unchanged, whereas the DSLX body evaluates to `k`.
// NOTE(review): likely should return add.67 -- TODO confirm in the converter.
fn ____foo__foo_counted_for_6_body(k: bits[32] id=65, __loop_carry: bits[32] id=68) -> bits[32] {
  literal.66: bits[32] = literal(value=0, id=66)
  add.67: bits[32] = add(k, literal.66, id=67)
  ret __loop_carry: bits[32] = param(name=__loop_carry, id=68)
}

// Inner-loop body for unrolled outer iteration j = 7. __foo__foo runs it via
// counted_for with trip_count=128 (= 2^7).
// add.72 (k + 0) is presumably the DSLX-level `k` (range start 0 plus the
// induction variable). It is unused: the body returns its __loop_carry
// parameter unchanged, whereas the DSLX body evaluates to `k`.
// NOTE(review): likely should return add.72 -- TODO confirm in the converter.
fn ____foo__foo_counted_for_7_body(k: bits[32] id=70, __loop_carry: bits[32] id=73) -> bits[32] {
  literal.71: bits[32] = literal(value=0, id=71)
  add.72: bits[32] = add(k, literal.71, id=72)
  ret __loop_carry: bits[32] = param(name=__loop_carry, id=73)
}

// Inner-loop body for unrolled outer iteration j = 8. __foo__foo runs it via
// counted_for with trip_count=256 (= 2^8).
// add.77 (k + 0) is presumably the DSLX-level `k` (range start 0 plus the
// induction variable). It is unused: the body returns its __loop_carry
// parameter unchanged, whereas the DSLX body evaluates to `k`.
// NOTE(review): likely should return add.77 -- TODO confirm in the converter.
fn ____foo__foo_counted_for_8_body(k: bits[32] id=75, __loop_carry: bits[32] id=78) -> bits[32] {
  literal.76: bits[32] = literal(value=0, id=76)
  add.77: bits[32] = add(k, literal.76, id=77)
  ret __loop_carry: bits[32] = param(name=__loop_carry, id=78)
}

// Inner-loop body for unrolled outer iteration j = 9. __foo__foo runs it via
// counted_for with trip_count=512 (= 2^9).
// add.82 (k + 0) is presumably the DSLX-level `k` (range start 0 plus the
// induction variable). It is unused: the body returns its __loop_carry
// parameter unchanged, whereas the DSLX body evaluates to `k`.
// NOTE(review): likely should return add.82 -- TODO confirm in the converter.
fn ____foo__foo_counted_for_9_body(k: bits[32] id=80, __loop_carry: bits[32] id=83) -> bits[32] {
  literal.81: bits[32] = literal(value=0, id=81)
  add.82: bits[32] = add(k, literal.81, id=82)
  ret __loop_carry: bits[32] = param(name=__loop_carry, id=83)
}

// Top function: the fully unrolled outer loop from the DSLX repro.
// The initial accumulator (literal 0) is threaded through ten chained
// counted_for nodes, one per outer iteration j = 0..9, each with
// trip_count = 2^j. The std::upow trip-count computation appears here only as
// these literal trip counts.
// Every ____foo__foo_counted_for_N_body in this package returns its
// __loop_carry parameter unchanged. The chain therefore yields the initial 0,
// which matches the IR JIT result, rather than the interpreter's 511.
top fn __foo__foo() -> bits[32] {
  result: bits[32] = literal(value=0, id=34, pos=[(1,7,6)])
  result__1: bits[32] = counted_for(result, trip_count=1, stride=1, body=____foo__foo_counted_for_0_body, id=39)
  result__2: bits[32] = counted_for(result__1, trip_count=2, stride=1, body=____foo__foo_counted_for_1_body, id=44)
  result__3: bits[32] = counted_for(result__2, trip_count=4, stride=1, body=____foo__foo_counted_for_2_body, id=49)
  result__4: bits[32] = counted_for(result__3, trip_count=8, stride=1, body=____foo__foo_counted_for_3_body, id=54)
  result__5: bits[32] = counted_for(result__4, trip_count=16, stride=1, body=____foo__foo_counted_for_4_body, id=59)
  result__6: bits[32] = counted_for(result__5, trip_count=32, stride=1, body=____foo__foo_counted_for_5_body, id=64)
  result__7: bits[32] = counted_for(result__6, trip_count=64, stride=1, body=____foo__foo_counted_for_6_body, id=69)
  result__8: bits[32] = counted_for(result__7, trip_count=128, stride=1, body=____foo__foo_counted_for_7_body, id=74)
  result__9: bits[32] = counted_for(result__8, trip_count=256, stride=1, body=____foo__foo_counted_for_8_body, id=79)
  ret counted_for.84: bits[32] = counted_for(result__9, trip_count=512, stride=1, body=____foo__foo_counted_for_9_body, id=84)
}