diku-dk / futhark

A data-parallel functional programming language
http://futhark-lang.org
ISC License

Forward/reverse mode mismatch #1878

Closed by athas 1 year ago

athas commented 1 year ago
def dficfj_test_ad (x : [8]f64) =
  #[unsafe] -- For simplicity of generated code.
  let col_w_pre_red =
    tabulate_3d 4 2 4 (\ k i j -> x[k+j]*x[i+j])
  let col_w_red =
    map (map f64.sum) col_w_pre_red
  let col_eq : [4]f64 =
    map (\w -> w[0] - w[1]) col_w_red
  in col_eq

entry fwd_J (x : [8]f64) =
  tabulate 8 (\ i ->
                jvp dficfj_test_ad x (replicate 8 0 with [i] = 1))

entry rev_J (x : [8]f64) =
  transpose (tabulate 4 (\ i ->
                           vjp dficfj_test_ad x (replicate 4 0 with [i] = 1)))

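fwd_J and rev_J should return the same [8][4] matrix (the transposed Jacobian of dficfj_test_ad), but they do not. Since reverse mode compiles this into a purely zero Jacobian, I'm betting it's reverse mode that is wrong.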
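
As a sanity check on why an all-zero Jacobian cannot be right, here is a hand derivation read off the program text. It is a sketch, not compiler output. The symbols $f_k$ (for `col_eq[k]`), $x_m$ (for `x[m]`) and the Kronecker delta $\delta$ are notation introduced here, with $k, j \in \{0,\dots,3\}$ and $m \in \{0,\dots,7\}$.

```latex
% col_w_red[k][i] = \sum_j x[k+j] * x[i+j], and col_eq[k] = col_w_red[k][0] - col_w_red[k][1]
f_k(x) = \sum_{j=0}^{3} x_{k+j}\, x_{j} - \sum_{j=0}^{3} x_{k+j}\, x_{j+1}
       = \sum_{j=0}^{3} x_{k+j}\,\bigl(x_{j} - x_{j+1}\bigr)

% Product rule, term by term:
\frac{\partial f_k}{\partial x_m}
  = \sum_{j=0}^{3} \Bigl[ \delta_{m,\,k+j}\,\bigl(x_{j} - x_{j+1}\bigr)
                        + x_{k+j}\,\bigl(\delta_{m,\,j} - \delta_{m,\,j+1}\bigr) \Bigr]

% Evaluated at x = (1, 1, \dots, 1): the first term vanishes and the second telescopes.
\left.\frac{\partial f_k}{\partial x_m}\right|_{x = \mathbf{1}}
  = \delta_{m,0} - \delta_{m,4}
```

So at the all-ones input every output has derivative $+1$ with respect to $x_0$ and $-1$ with respect to $x_4$, which a zero Jacobian cannot reproduce. Only $x_7$ is never read (the largest index used is $k + j = 6$), so that is the one input whose derivatives are expected to be zero everywhere.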

athas commented 1 year ago

It is the sparsity optimisation that goes nuts. Probably it mishandles accumulator adjoints.