diku-dk / futhark

💥💻💥 A data-parallel functional programming language
http://futhark-lang.org
ISC License
2.41k stars · 166 forks · source link

Bug in pass ad (related to accumulators) #1515

Closed — coancea closed this issue 2 years ago.

coancea commented 2 years ago

The compiler crashes with an internal error in the `ad` pass:

$ futhark dev --gpu step.fut > step.ir -v
[  +0.000037] Reading and type-checking source program
[  +0.155007] Defunctorising
[  +0.000759] Monomorphising
[  +0.024246] Lifting lambdas
[  +0.001207] Defunctionalising
[  +0.024621] Converting to core IR
[  +0.021020] Type-checking internalised program
[  +0.010557] Running pass simplify
[  +0.023221] Running pass Inline conservatively
[  +0.050901] Running pass simplify
[  +0.002154] Running pass Inline aggressively
[  +0.003308] Running pass simplify
[  +0.002508] Running pass CSE
[  +0.000325] Running pass simplify
[  +0.001785] Running pass Fuse SOACs
[  +0.006300] Running pass CSE
[  +0.000299] Running pass simplify
[  +0.002396] Running pass Fuse SOACs
[  +0.008510] Running pass CSE
[  +0.000468] Running pass simplify
[  +0.001112] Running pass Remove dead functions
[  +0.000297] Running pass ad
Internal compiler error.  Please report this:
  https://github.com/diku-dk/futhark/issues
Internal compiler error (unhandled IO exception).
Please report this at https://github.com/diku-dk/futhark/issues
Invalid slice for accumulator update: [i_12152, 0i64 :+ bs_11935 * 1i64]
CallStack (from HasCallStack):
  error, called at src/Futhark/AD/Rev/Monad.hs:367:9 in futhark-0.21.0-IdT9ZmWZph2HTE2hcJfaJe:Futhark.AD.Rev.Monad

The input program (step.fut):

-- Scalar element type used throughout the program, plus thin aliases for
-- the f32 operations the rest of the file relies on. Changing `real` to a
-- different float type requires updating every alias below to match.
type real = f32
let zero    : real = 0f32
let sum     = f32.sum
let log     = f32.log
let tanh    = f32.tanh
let exp     = f32.exp
let fromi64 = f32.i64

-- Inner product of two equal-length vectors: elementwise product, then sum.
let dotproduct [n] (a: [n]real) (b: [n]real) : real =
  sum (map2 (*) a b)

-- Matrix-vector product: element k of the result is the dot product of
-- `vec` with row k of `mat`, giving a vector of length m.
let matvec [m][n] (mat: [m][n]real) (vec: [n]real) =
  map (\row -> dotproduct vec row) mat

-- Matrix product of `ass` (m x q) and `bss` (q x n). Each row of `ass` is
-- multiplied against the transpose of `bss`, producing one row of the
-- m x n result.
let matmul [m][n][q] (ass: [m][q]real) (bss: [q][n]real) : [m][n]real =
  let bss_t = transpose bss
  in map (\a_row -> matvec bss_t a_row) ass

-- Logistic function: 1 / (1 + e^(-x)).
let sigmoid (x: real) : real =
  let denom = 1.0 + exp (-x)
  in 1.0 / denom

-- Builds one gate's pre-activation from the stacked projections. It takes
-- the `h` consecutive rows starting at row `beg` and, for each row r,
-- returns mm_ih[r] + mm_hh[r] + bias[r] elementwise across the batch.
-- Rows beg .. beg+h-1 are indexed, so beg + h must not exceed hx4.
let mkGate [hx4][bs]
       (h: i64) (beg: i64)
       (mm_ih: [hx4][bs]real)
       (mm_hh: [hx4][bs]real)
       (bias:  [hx4]real) : [h][bs]real =
  tabulate h (\i ->
    let r = beg + i
    let b = bias[r]
    in map2 (\x y -> x + y + b) mm_ih[r] mm_hh[r])

-- One LSTM step for a batch of `bs` inputs. Hidden and cell state are stored
-- as [h][bs] matrices (one column per batch element). The four gates are read
-- from consecutive h-row slices of the stacked pre-activations:
-- input gate rows 0..h-1, forget gate h..2h-1, cell gate 2h..3h-1, and
-- output gate 3h..4h-1. hx4 must therefore be at least 4*h.
--
-- Returns (hidden state', cell state'), where
--   c' = sigmoid(f) * c + sigmoid(i) * tanh(g)
--   h' = sigmoid(o) * tanh(c')
let step [bs] [hx4] [h] [d]
         (inp_els: [bs][d]real)
         (hidn_st: [h][bs]real, cell_st: [h][bs]real)
         -- weights
         ( wght_ih: [hx4][d]real
         , wght_hh: [hx4][h]real
         , bias:    [hx4]real
         )
         : ([h][bs]real, [h][bs]real) =
  -- Input-to-hidden projections: mm_ih[r][j] = dot(wght_ih[r], inp_els[j]).
  let mm_ih = map (matvec inp_els) wght_ih |> opaque

  -- Hidden-to-hidden projections: wght_hh (hx4 x h) times hidn_st (h x bs).
  let mm_hh = matmul wght_hh hidn_st |> opaque

  let ingate0     = mkGate h 0     mm_ih mm_hh bias
  let forgetgate0 = mkGate h h     mm_ih mm_hh bias
  let cellgate0   = mkGate h (2*h) mm_ih mm_hh bias
  let outgate0    = mkGate h (3*h) mm_ih mm_hh bias

  let ingate     = map (map sigmoid) ingate0
  let forgetgate = map (map sigmoid) forgetgate0
  let cellgate   = map (map tanh   ) cellgate0
  -- The output gate must be driven by its own slice (rows 3h..4h-1), not
  -- by the forget gate's pre-activation.
  let outgate    = map (map sigmoid) outgate0

  let cell_st' =  map2 (map2 (*)) ingate cellgate
               |> map2 (map2 (+))
                       (map2 (map2 (*)) forgetgate cell_st)
  let hidn_st' = map (map tanh) cell_st' |> map2 (map2 (*)) outgate

  in  (hidn_st', cell_st')

-- ==
-- compiled random input { [1024][80]f32 [256][1024]f32 [256][1024]f32 [1024][80]f32 [1024][256]f32 [1024]f32 [256][1024]f32 [256][1024]f32 } auto output

-- Entry point: reverse-mode derivative (vjp) of `step` with respect to the
-- weight triple (wght_ih, wght_hh, bias). The input batch and incoming
-- hidden/cell state are held fixed. `hidn_st_adj` and `cell_st_adj` seed
-- the adjoints of step's two outputs. The result is the adjoint of the
-- weight triple.
let main [bs] [hx4] [h] [d]
         (inp_els: [bs][d]real)
         (hidn_st: [h][bs]real)
         (cell_st: [h][bs]real)
         -- weights
         (wght_ih: [hx4][d]real)
         (wght_hh: [hx4][h]real)
         (bias:    [hx4]real)
         -- adjoints
         (hidn_st_adj: [h][bs]real)
         (cell_st_adj: [h][bs]real) =
  let weights = (wght_ih, wght_hh, bias)
  let out_adj = (hidn_st_adj, cell_st_adj)
  in vjp (step inp_els (hidn_st, cell_st)) weights out_adj
athas commented 2 years ago

Works fine for me on the `clean-ad` branch.