diku-dk / futhark

:boom::computer::boom: A data-parallel functional programming language
http://futhark-lang.org
ISC License
2.41k stars 166 forks source link

rev-ad: bug after pass extract-kernels and simplify #1511

Closed coancea closed 3 years ago

coancea commented 3 years ago

The bug appears in the file `futhark-ad/rnns/lstm/lstm.fut` when it is compiled with `futhark cuda lstm.fut`. I am on the branch `ad-genred-opt`, but `clean-ad` is very likely affected too, since the error is probably produced by the `extract-kernels` pass.

$ futhark cuda lstm.fut 
Internal compiler error.  Please report this:
  https://github.com/diku-dk/futhark/issues
Type error after pass 'simplify':
In function entry_main
When checking function body
In expression of statement
  {loop_adj_15511 : ({}, [n_15307][h_15309][bs_15306]f32),
   loop_adj_15512 : ({}, [h_15309][bs_15306]f32),
   loop_adj_15513 : ({}, [h_15309][bs_15306]f32),
   loop_adj_15514 : ({}, [hx4_15310][d_15308]f32),
   loop_adj_15515 : ({}, [hx4_15310][h_15309]f32),
   loop_adj_15516 : ({}, [hx4_15310]f32)}
Inside the loop body
In expression of statement
  {withacc_res_15671 : ({}, [bs_15306][h_15309]f32),
   withacc_res_15672 : ({}, [hx4_15310][h_15309]f32),
   withacc_res_15673 : ({x_adj_17996}, [hx4_15310]f32)}
Type error:
Variable acc_p_15680 referenced after being consumed.

The (long) program is pasted below:

-- | Scalar type and helper functions used throughout the program.
-- All helpers are bound to their `f32` versions, so changing `real`
-- to another float type also requires updating these bindings.
type real= f32
let zero = 0f32
let sum  = f32.sum
let log  = f32.log
let tanh = f32.tanh
let exp  = f32.exp
let fromi64 = f32.i64

-- | Dot product of two equal-length vectors: the sum of their
-- element-wise products.
let dotproduct [n] (a: [n]real) (b: [n]real) : real =
    map2 (*) a b |> sum

-- | Matrix-vector product: element `i` of the `[m]` result is the
-- dot product of row `i` of `mat` with `vec`.
let matvec [m][n] (mat: [m][n]real) (vec: [n]real) =
    map (dotproduct vec) mat

-- | Matrix product of `ass` ([m][q]) and `bss` ([q][n]), giving [m][n].
-- `bss` is transposed so that each output row is a `matvec` of the
-- columns of `bss` against the corresponding row of `ass`.
let matmul [m][n][q] (ass: [m][q]real) (bss: [q][n]real) : [m][n]real =
    map (matvec (transpose bss)) ass

-- | Logistic sigmoid: 1 / (1 + e^(-x)).
let sigmoid (x: real) : real =
    1.0 / (1.0 + exp(-x))

-- | Mean squared error over `n` pairs: the average of `(a - b)^2`
-- over all pairs `(a, b)` in `y_y_hat`.
let meanSqr [n]
            (y_y_hat : [n](real, real)) =
  let s  = map (\(a, b) -> (a - b) * (a - b)) y_y_hat
        |> sum
  in  s / (fromi64 n)

-- | One LSTM time step over a batch of `bs` inputs.
--
-- Computes the gate pre-activations
--   wght_ih * inp_els^T + wght_hh * hidn_st + bias      (shape [hx4][bs]),
-- splits them into four [h][bs] gates (input, forget, cell, output),
-- and returns the new `(hidden, cell)` states, each of shape [h][bs].
-- The `assert` below requires `hx4 == 4*h`.
let step [bs] [hx4] [h] [d]
         (wght_ih: [hx4][d]real)
         (wght_hh: [hx4][h]real)
         (bias:    [hx4]real)
         (inp_els: [bs][d]real)
         (hidn_st: [h][bs]real, cell_st: [h][bs]real)
         : ([h][bs]real, [h][bs]real) =
  -- [hx4][bs]: row j holds wght_ih[j] dotted with every batch input.
  let mm_ih = map (matvec inp_els) wght_ih |> opaque
    -- map (matvec wght_ih) inp_els |> opaque

  -- [hx4][h] x [h][bs] -> [hx4][bs]
  let mm_hh = matmul wght_hh hidn_st
    -- map (matvec (transpose hidn_st)) wght_hh |> opaque
    -- map (matvec wght_hh) hidn_st |> opaque

  -- Sum both contributions and add bias[j] to every element of row j.
  let gates = map2 (map2 (+)) mm_ih mm_hh
           |> map2 (\b row -> map (+b) row) bias

  -- Split the [hx4][bs] gate matrix into four consecutive [h][bs] blocks.
  let gates'     = assert (4*h == hx4)
                          (unflatten 4 h gates)
  let ingate     = map (map sigmoid) (gates'[0])
  let forgetgate = map (map sigmoid) (gates'[1])
  let cellgate   = map (map tanh   ) (gates'[2])
  let outgate    = map (map sigmoid) (gates'[3])

  -- c' = ingate * cellgate + forgetgate * c   (element-wise)
  let cell_st' =  map2 (map2 (*)) ingate cellgate
               |> map2 (map2 (+))
                       (map2 (map2 (*)) forgetgate cell_st)
  -- h' = outgate * tanh(c')                   (element-wise)
  let hidn_st' = map (map tanh) cell_st' |> map2 (map2 (*)) outgate

  in  (hidn_st', cell_st')

-- | Runs the LSTM over all `n` time steps of `input`, then applies a
-- fully connected output layer.
--
-- Returns `(y_hat, hidn_stack', cell_st)`:
--   * `y_hat`       : [n][bs][d], hidden state of each step and batch
--                     element multiplied by `wght_y`, plus `bias_y`;
--   * `hidn_stack'` : [n][bs][h], hidden state after each step;
--   * `cell_st`     : [h][bs], cell state after the final step.
let lstmPrd [bs][n][d][h][hx4]
            (input: [n][bs][d]real)
            (wght_ih: [hx4][d]real)
            (wght_hh: [hx4][h]real)
            (wght_y:    [h][d]real)
            (bias:       [hx4]real)
            (bias_y:       [d]real)
            (hidn_st0: [h][bs]real)
            (cell_st0: [h][bs]real)
          : ([n][bs][d]real, [n][bs][h]real, [h][bs]real) =
  -- rnn component
  let hidn_stack0  = replicate bs zero
                  |> replicate h
                  |> replicate n

  -- hidn_stack0 :: [n][h][bs], all zeros; iteration i of the loop
  -- overwrites slice i with the hidden state produced by step i.
  let (hidn_stack, (_, cell_st)) =
    loop (hidn_stack, (hidn_st, cell_st)) = (hidn_stack0, (hidn_st0, cell_st0))
    for i < n do
        let (hidn_st', cell_st') = step wght_ih wght_hh bias input[i] (hidn_st, cell_st)
        let hidn_stack[i] = hidn_st'
        in  (hidn_stack, (hidn_st', cell_st'))

  -- fully connected output
  -- Transpose each slice: [n][h][bs] -> [n][bs][h].
  let hidn_stack'  = hidn_stack
                  |> map transpose

  -- [n*bs][h] x [h][d] -> [n*bs][d], add bias_y to each row, then
  -- restore the [n][bs][d] shape.
  let y_hat  = matmul (flatten hidn_stack') wght_y
            |> map (map2 (+) bias_y)
            |> unflatten n bs

  in  (y_hat, hidn_stack', cell_st)
  -- hidden_states[:, h-1] instead of hidn_stack in the return?

----------------------------------------------------------------
-- `bs`  is the batch size (for the moment `bs = 1`)          --
-- `n`   is the length of a time series;                      --
-- `d`   is the dimensionality of a point of a time series;   --
-- `h`   is the length of the hidden layer                    --
-- `hx4` is `4 x h`                                           --
-- `layers`: not used, i.e., we assume num layers is 1        --
----------------------------------------------------------------
-- | Scalar loss of the LSTM on `input`. Runs `lstmPrd` and returns
-- the mean squared error between the predictions `input_hat` and
-- `input` itself, so the prediction target is the input series.
-- The five parameters are taken as one tuple so `main` can
-- differentiate with respect to all of them at once.
let lstmObj [bs][n][d][h][hx4]
            (input: [n][bs][d]real)
            (hidn_st0: [h][bs]real)
            (cell_st0: [h][bs]real)
            ( wght_ih: [hx4][d]real
            , wght_hh: [hx4][h]real
            , wght_y:    [h][d]real
            , bias:       [hx4]real
            , bias_y:       [d]real
            )
          : real =
  let (input_hat, _, _) =
        lstmPrd input wght_ih wght_hh wght_y bias bias_y hidn_st0 cell_st0
  -- Pair each prediction with its target, then flatten [n][bs][d] to one array.
  let y_y_hat  = map2 (map2 zip) input_hat input
              |> flatten
              |> flatten
  let loss = meanSqr y_y_hat
  in  loss

-- | Entry point: reverse-mode AD (`vjp`) of `lstmObj` with respect to
-- its five trainable parameters, seeded with the adjoint `loss_adj`.
-- The gradients are returned in the same order as the parameters.
let main [bs][n][d][h][hx4]
         (input: [n][bs][d]real)
         (hidn_st0: [h][bs]real)
         (cell_st0: [h][bs]real)
         --- parameters to differentiate with respect to
         (wght_ih: [hx4][d]real)
         (wght_hh: [hx4][h]real)
         (wght_y:    [h][d]real)
         (bias_h:    [hx4]real)
         (bias_y:       [d]real)
         --- adjoints ---
         (loss_adj : real) :
         ( [hx4][d]real
         , [hx4][h]real
         , [h][d]real
         , [hx4]real
         , [d]real
         ) =
  vjp (lstmObj input hidn_st0 cell_st0)
      (wght_ih, wght_hh, wght_y, bias_h, bias_y) loss_adj
athas commented 3 years ago

I suspect this is the same issue as #1505. I have not merged `master` into `clean-ad` since fixing that one.

athas commented 3 years ago

Yes, merging `master` into `clean-ad` fixed it.