Closed Munksgaard closed 1 year ago
It's dead code; the compiler is not guaranteed to preserve nontermination.
My mistake, I simplified too much. Here is the real code which, as far as I can tell, does not have dead code:
-- Code and comments based on
-- https://github.com/kkushagra/rodinia/blob/master/openmp/nw
--
-- ==
-- entry: main_2d
-- compiled random input { 10i32 [129][129]i32 [129][129]i32 }
-- compiled random input { 10i32 [1025][1025]i32 [1025][1025]i32 }
-- compiled random input { 10i32 [1537][1537]i32 [1537][1537]i32 }
-- ==
-- compiled script input { mk_input 128i64 }
-- compiled script input { mk_input 1024i64 }
-- compiled script input { mk_input 1536i64 }
-- compiled script input { mk_input 8192i64 }
-- compiled script input { mk_input 16384i64 }
-- compiled script input { mk_input 32768i64 }
entry mk_input (row_length: i64) =
-- Input generator used by the `compiled script input` test cases.
-- (This comment is deliberately placed inside the body so that it does not
-- become contiguous with the `-- ==` test block preceding the definition.)
-- Produces the argument triple: a penalty of 10 and two zero-filled flat
-- arrays, each with (row_length + 1)^2 elements.
let n = (row_length + 1)**2
in (10i32, replicate n 0i32 , replicate n 0i32)
import "intrinsics"
-- | Compute the value of cell (y, x) of `block` as the maximum of three
-- candidates:
--   * the cell to the left,      block[y, x-1]   minus `pen`
--   * the cell above,            block[y-1, x]   minus `pen`
--   * the diagonal predecessor,  block[y-1, x-1] plus ref[y-1, x-1]
-- `ref` is indexed with (y-1, x-1) while `block` is indexed with (y, x);
-- presumably `block` carries one extra border row/column that `ref` lacks
-- (see how `process_block` builds it) -- confirm against the caller.
let mkVal [bp1][b] (y:i32) (x:i32) (pen:i32) (block:[bp1][]i32) (ref:[b][b]i32) : i32 =
-- The attribute below applies to the whole index expression; per the Futhark
-- language reference it disables dynamic safety checks such as bounds checks.
#[unsafe]
i32.max (block[y, x - 1] - pen) (block[y - 1, x] - pen)
|> i32.max (block[y - 1, x - 1] + ref[y - 1, x - 1])
-- | Fill in one b-by-b block, one anti-diagonal at a time.
--
-- Parameters:
--   penalty - value passed to `mkVal` as `pen`.
--   above   - bp1 values written into row 0 of the working block.
--   left    - b values written into column 0 (rows 1 and onwards).
--   ref     - b-by-b values passed on to `mkVal`.
--
-- Returns the interior of the working block (rows 1.., columns 1..bp1-1),
-- size-coerced to [b][b].  Fails at run time if b + 1 != bp1.
let process_block [b][bp1]
(penalty: i32)
(above: [bp1]i32)
(left: [b]i32)
(ref: [b][b]i32): *[b][b]i32 =
-- Working block of zeros with bp1 rows and bp1+1 columns; the assertion
-- ties the two size parameters together.
-- NOTE(review): only columns 0..bp1-1 are ever read back at the end; the
-- reason for the extra column is not visible here (padding?) -- confirm.
let block = assert (b + 1 == bp1) (tabulate_2d bp1 (bp1+1) (\_ _ -> 0))
-- Install the top border row and the left border column.
let block[0, 0:bp1] = above
let block[1:, 0] = left
-- Take a private copy of `ref` and rewrite element [0,0] with its own value.
-- `opaque` hides the values from the optimiser, so presumably this forces the
-- copy to actually be materialised -- TODO confirm the intent.
let ref_block = copy (opaque ref)
let ref_block[0,0] = opaque ref_block[0,0]
let ref = ref_block
-- Process the first half (anti-diagonally) of the block
let block =
loop block for m < b do
-- Anti-diagonal m has m+1 cells.  Lane tx (0 <= tx <= m) targets row
-- m - tx + 1 and column tx + 1; the remaining lanes get the index (-1, -1),
-- which is out of bounds and therefore skipped by scatter_2d.
let inds =
tabulate b (\tx ->
if tx > m then (-1, -1)
-- The indices are narrowed to i32 (the type `mkVal` takes) and widened back.
else let ind_x = i32.i64 (tx + 1)
let ind_y = i32.i64 (m - tx + 1)
in (i64.i32 ind_y, i64.i32 ind_x))
let vals =
-- tabulate over the m'th anti-diagonal before the middle
tabulate b
(\tx ->
-- Inactive lanes yield a dummy 0; it is paired with the (-1, -1) index.
if tx > m then 0
else let ind_x = i32.i64 (tx + 1)
let ind_y = i32.i64 (m - tx + 1)
let v = mkVal ind_y ind_x penalty block ref
in v)
in scatter_2d block inds vals
-- Process the second half (anti-diagonally) of the block
let block = loop block for m < b-1 do
-- Shadow the loop counter so that m counts down from b-2 to 0; the
-- anti-diagonals below the middle shrink from b-1 cells to a single cell.
let m = b - 2 - m
-- Lane tx (0 <= tx <= m) targets row b - tx and column tx + b - m.
let inds = tabulate b (\tx -> (
if tx > m then (-1, -1)
else let ind_x = i32.i64 (tx + b - m)
let ind_y = i32.i64 (b - tx)
in ((i64.i32 ind_y, i64.i32 ind_x)) )
)
let vals =
-- tabulate over the m'th anti-diagonal after the middle
tabulate b (\tx -> (
if tx > m then (0)
else let ind_x = i32.i64 (tx + b - m)
let ind_y = i32.i64 (b - tx)
let v = mkVal ind_y ind_x penalty block ref
in v ))
in scatter_2d block inds vals
-- Drop the border row/column (and the extra trailing column) and coerce the
-- result to the declared [b][b] size.
in block[1:, 1:bp1] :> *[b][b]i32
-- | Run the blocked computation over a flat array of n elements that is
-- treated as a square matrix with row length sqrt(n).
--
-- Parameters:
--   penalty - forwarded to `process_block`.
--   input   - flat matrix, consumed and updated in place.
--   refs    - flat matrix of the same length, only read.
--
-- Returns the updated `input`.
--
-- The block size is hard-coded to 64.  The assertions below require
-- (row_length - 1) to be a multiple of 64, row_length > 3 * 64 = 192,
-- row_length > 3 and 2 * 64 < row_length.
-- NOTE(review): the [129][129] and `mk_input 128` datasets in the test header
-- give row_length = 129, which does not satisfy row_length > 192.
def main [n]
(penalty: i32)
(input: *[n]i32)
(refs: [n]i32)
: *[n]i32 =
let block_size = 64
-- Side length of the square matrix: sqrt(n) converted back to i64.
let row_length = i64.f64 <| f64.sqrt <| f64.i64 n
-- Number of blocks along one side.  The inner assert wraps only `row_length`
-- (function application binds tighter than `- 1`), so the value is
-- (row_length - 1) / block_size, guarded by both assertions.
let num_blocks = assert ((row_length - 1) % block_size == 0)
((assert (row_length > block_size * 3) row_length - 1) / block_size)
-- Block size plus one; this becomes the length of each `above` row handed to
-- `process_block`.
let bp1 = assert (row_length > 3) (assert (2 * block_size < row_length) (block_size + 1))
-- First phase: iteration i processes i + 1 blocks with one map3.
-- NOTE(review): judging by the offsets and strides, these are presumably the
-- blocks of the i'th block anti-diagonal in the upper-left half of the matrix;
-- this depends on the flat_index_*/flat_update_* intrinsics -- confirm.
let input =
loop input for i < num_blocks do
let ip1 = i + 1
let v =
-- Attribute that steers incremental flattening for this map; see the Futhark
-- language reference for the exact meaning of `only_intra`.
#[incremental_flattening(only_intra)]
map3 (process_block penalty)
-- `above` argument: ip1 slices of bp1 elements each (inner stride 1).
(flat_index_2d input (i * block_size)
ip1 (row_length * block_size - block_size)
bp1 1)
-- `left` argument: ip1 slices of block_size elements each (inner stride
-- row_length).
(flat_index_2d input (row_length + i * block_size)
ip1 (row_length * block_size - block_size)
block_size row_length)
-- `ref` argument: ip1 slices of block_size by block_size elements taken
-- from `refs`.
(flat_index_3d refs (row_length + 1 + i * block_size)
ip1 (row_length * block_size - block_size)
block_size row_length
block_size 1i64)
-- Write the computed blocks back into `input`, using the same offset and
-- strides as the `refs` slice above.
in flat_update_3d
input
(row_length + 1 + i * block_size)
(row_length * block_size - block_size)
(row_length)
1
v
-- Second phase: iteration i processes num_blocks - i - 1 blocks, so the
-- amount of work per iteration shrinks from num_blocks - 1 blocks to one.
let input =
loop input for i < num_blocks - 1 do
let v =
#[incremental_flattening(only_intra)]
map3 (process_block penalty)
-- `above` argument.
(flat_index_2d input (((i + 1) * block_size + 1) * row_length - block_size - 1)
(num_blocks - i - 1) (row_length * block_size - block_size)
bp1 1i64)
-- `left` argument: same base offset as `above`, plus row_length.
(flat_index_2d input (((i + 1) * block_size + 1) * row_length - block_size - 1 + row_length)
(num_blocks - i - 1) (row_length * block_size - block_size)
block_size row_length)
-- `ref` argument.
(flat_index_3d refs (((i + 1) * block_size + 2) * row_length - block_size)
(num_blocks - i - 1) (row_length * block_size - block_size)
block_size row_length
block_size 1i64)
-- Write back with the same offset and strides as the `refs` slice above.
in flat_update_3d
input
(((i + 1) * block_size + 2) * row_length - block_size)
(row_length * block_size - block_size)
(row_length)
1
v
in input
-- | Two-dimensional entry point: flattens the n-by-n `input` and `refs`
-- matrices to n*n elements each, runs `main` on them, and reshapes the result
-- back to n-by-n.  This is the entry point named by the first test block
-- (`entry: main_2d`).
entry main_2d [n] (penalty: i32) (input: *[n][n]i32) (refs: [n][n]i32) : *[n][n]i32 =
let k = n*n
in main penalty (flatten_to k input) (flatten_to k refs) |> unflatten n n
Compare:
$ futhark-0.21.2 test --backend=opencl ~/src/futhark-mem-sc22/benchmarks/nw/futhark/nw.fut
/home/munksgaard/src/futhark-mem-sc22/benchmarks/nw/futhark/nw.fut:
Compiling with --backend=opencl:
Running compiled program:
Running /home/munksgaard/src/futhark-mem-sc22/benchmarks/nw/futhark/nw:
Entry point: main_2d; dataset: 10i32 [129][129]i32 [129][129]i32:
Function failed with error:
Error: Assertion is false: (row_length > block_size * 3)
Backtrace:
-> #0 /home/munksgaard/src/futhark-mem-sc22/benchmarks/nw/futhark/nw.fut:92:29-75
#1 /home/munksgaard/src/futhark-mem-sc22/benchmarks/nw/futhark/nw.fut:146:6-58
#2 /home/munksgaard/src/futhark-mem-sc22/benchmarks/nw/futhark/nw.fut:144:1-146:75
Entry point: main; dataset: mk_input 128i64:
Function failed with error:
Error: Assertion is false: (row_length > block_size * 3)
Backtrace:
-> #0 /home/munksgaard/src/futhark-mem-sc22/benchmarks/nw/futhark/nw.fut:92:29-75
#1 /home/munksgaard/src/futhark-mem-sc22/benchmarks/nw/futhark/nw.fut:84:1-142:10
┌──────────┬────────┬────────┬───────────┐
│ │ passed │ failed │ remaining │
├──────────┼────────┼────────┼───────────┤
│ programs │ 0 │ 1 │ 0/1 │
├──────────┼────────┼────────┼───────────┤
│ runs │ 4 │ 2 │ 0/6 │
└──────────┴────────┴────────┴───────────┘
$ futhark test --backend=opencl ~/src/futhark-mem-sc22/benchmarks/nw/futhark/nw.fut
┌──────────┬────────┬────────┬───────────┐
│ │ passed │ failed │ remaining │
├──────────┼────────┼────────┼───────────┤
│ programs │ 1 │ 0 │ 0/1 │
├──────────┼────────┼────────┼───────────┤
│ runs │ 6 │ 0 │ 0/6 │
└──────────┴────────┴────────┴───────────┘
Oh, and it depends on this `intrinsics.fut`:
-- | Wrapper around the compiler intrinsic `intrinsics.flat_index_2d`, with the
-- result size-coerced to [n1][n2].
-- Presumably element [i][j] of the result is as[offset + i*s1 + j*s2]
-- (n1, n2 being extents and s1, s2 strides) -- the intrinsic is not visible
-- here, confirm against the compiler.
def flat_index_2d [n] 'a (as: [n]a) (offset: i64) (n1: i64) (s1: i64) (n2: i64) (s2: i64) : [n1][n2]a =
intrinsics.flat_index_2d(as, offset, n1, s1, n2, s2) :> [n1][n2]a
-- | Wrapper around the compiler intrinsic `intrinsics.flat_update_2d`.
-- Consumes `as` and returns the updated flat array.
-- Presumably asss[i][j] is written to position offset + i*s1 + j*s2, with the
-- extents taken from the shape [k][l] of `asss` -- confirm against the
-- compiler.
def flat_update_2d [n][k][l] 'a (as: *[n]a) (offset: i64) (s1: i64) (s2: i64) (asss: [k][l]a) : *[n]a =
intrinsics.flat_update_2d(as, offset, s1, s2, asss)
-- | Wrapper around the compiler intrinsic `intrinsics.flat_index_3d`, with the
-- result size-coerced to [n1][n2][n3].
-- Presumably element [i][j][l] of the result is
-- as[offset + i*s1 + j*s2 + l*s3] -- the intrinsic is not visible here,
-- confirm against the compiler.
def flat_index_3d [n] 'a (as: [n]a) (offset: i64) (n1: i64) (s1: i64) (n2: i64) (s2: i64) (n3: i64) (s3: i64) : [n1][n2][n3]a =
intrinsics.flat_index_3d(as, offset, n1, s1, n2, s2, n3, s3) :> [n1][n2][n3]a
-- | Wrapper around the compiler intrinsic `intrinsics.flat_update_3d`.
-- Consumes `as` and returns the updated flat array.
-- Presumably asss[i][j][m] is written to position
-- offset + i*s1 + j*s2 + m*s3, with the extents taken from the shape
-- [k][l][p] of `asss` -- confirm against the compiler.
def flat_update_3d [n][k][l][p] 'a (as: *[n]a) (offset: i64) (s1: i64) (s2: i64) (s3: i64) (asss: [k][l][p]a) : *[n]a =
intrinsics.flat_update_3d(as, offset, s1, s2, s3, asss)
-- | Wrapper around the compiler intrinsic `intrinsics.flat_index_4d`, with the
-- result size-coerced to [n1][n2][n3][n4].
-- Presumably the four (extent, stride) pairs describe a strided view of `as`
-- starting at `offset`, analogous to the 2-d and 3-d variants -- confirm
-- against the compiler.
def flat_index_4d [n] 'a (as: [n]a) (offset: i64) (n1: i64) (s1: i64) (n2: i64) (s2: i64) (n3: i64) (s3: i64) (n4: i64) (s4: i64) : [n1][n2][n3][n4]a =
intrinsics.flat_index_4d(as, offset, n1, s1, n2, s2, n3, s3, n4, s4) :> [n1][n2][n3][n4]a
-- | Wrapper around the compiler intrinsic `intrinsics.flat_update_4d`.
-- Consumes `as` and returns the updated flat array.
-- Presumably the elements of `asss` are written to the strided positions
-- given by `offset` and the strides s1..s4, with the extents taken from the
-- shape [k][l][p][q] of `asss` -- confirm against the compiler.
def flat_update_4d [n][k][l][p][q] 'a (as: *[n]a) (offset: i64) (s1: i64) (s2: i64) (s3: i64) (s4: i64) (asss: [k][l][p][q]a) : *[n]a =
intrinsics.flat_update_4d(as, offset, s1, s2, s3, s4, asss)
Uh, this should be an issue on diku-dk/futhark
In contrast: