diku-dk / futhark

:boom::computer::boom: A data-parallel functional programming language
http://futhark-lang.org
ISC License

AD Performance bug: probably related to dependency tracking in the presence of accumulators #1989

Open coancea opened 1 year ago

coancea commented 1 year ago

src.tar.gz

I apologize, but the compression is needed because GitHub does not support .fut files.

Decompress and compile with

$ futhark dev -s driver-knn.fut > debug.txt

Examining the debug.txt file reveals two performance issues.

First Issue: dependency checking in the presence of accumulators

The first issue is critical: the reverse trace of the (outer) loop at line 33 of driver-knn.fut computes 10 adjoints, although only the last two are live after the loop and, importantly, their corresponding loop variants do NOT depend on the other 8 (loop-variant) adjoints. I suspect that dependency tracking in the presence of accumulators is more conservative than it should be. The corresponding loop starts at about line 599 of debug.txt. (I have verified by hand, in the generated code, that the last two adjoints are independent of the rest.)

I assume this because I have tested the cases where the loop variants are scalars or arrays, and the loops are simplified as expected in both of the programs presented below:

entry main (n: i32) (q : f32) =
  let (r1, u1, u2, u3, r2) =
    loop (a,x,y,z,b) = (q, 2*q, 3*q, 4*q, 5*q)
      for i < n do
        (b-a, (x+y)*z, y*(x+z), z*(x+y), a+b)
  in  (r1, r2)

entry main [n] (us1: *[n]i32) (us2: *[n]i32) (xs: *[n]i32) (ys: *[n]i32) : *[n]i32 =
  let (xs', us1', us2', ys') =
    loop (xs, us1, us2, ys)
      for i < n do
        let us1' = map2 (\ u1 u2 -> u1*u1 - u2*u2) us1 us2
        let us2[i:] = map2 (*) us1'[i:] us2[i:]
        let xs' = map2 (-) xs ys
        let ys' = map2 (+) xs ys
      in (xs', us1', us2, ys')
  in  map2 (+) xs' ys'
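
The simplification expected for the two test programs above can be illustrated with a toy dependency closure. This is only a sketch in Python, not Futhark's actual pass: the function name live_loop_params and the hand-written dependency tables are hypothetical, and accumulators are not modeled at all. Each table entry lists the loop parameters read when computing the next value of a given parameter.

```python
def live_loop_params(deps, live_results):
    """Return the loop parameters that can influence the live results.

    deps[p] is the set of loop parameters read by the expression that
    produces the next-iteration value of parameter p (hand-written here).
    """
    live = set(live_results)
    worklist = list(live_results)
    while worklist:
        p = worklist.pop()
        for q in deps[p]:
            if q not in live:
                live.add(q)
                worklist.append(q)
    return live

# Scalar program: body is (b-a, (x+y)*z, y*(x+z), z*(x+y), a+b),
# and only the results bound to a and b (r1, r2) are used afterwards.
scalar_deps = {
    "a": {"a", "b"},
    "x": {"x", "y", "z"},
    "y": {"x", "y", "z"},
    "z": {"x", "y", "z"},
    "b": {"a", "b"},
}

# Array program: only xs' and ys' are used after the loop.
array_deps = {
    "xs": {"xs", "ys"},
    "us1": {"us1", "us2"},
    "us2": {"us1", "us2"},
    "ys": {"xs", "ys"},
}

print(sorted(live_loop_params(scalar_deps, ["a", "b"])))   # ['a', 'b']
print(sorted(live_loop_params(array_deps, ["xs", "ys"])))  # ['xs', 'ys']
```

In both programs the closure never reaches the remaining parameters, so they can be dropped from the loop. The same outcome would be expected for the 8 unused adjoints if accumulators did not make the dependency information more conservative.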

Second Issue: aggressive simplification rule for computing zeroes inside a loop

The attached code also reveals another issue that might bite us someday in the context of rev-AD code generation: the 8 loop-variant adjoints (which are dead code anyway) are also provably zeros, but the simplifier is not yet smart enough to prove it. The problem stems from the fact that the loop-variant adjoints are initialized to zeros and are then combined with other zeros, but the fixed point is not discovered. Below is a simple illustrative example that remains unsimplified for both loop variants, although it is easy to see that both are zeros:

entry main [n] (bs: [n]bool) =
  let res =
    loop (x, y) = (0i32, false)
      for i < n do
        let y' = bs[i] && y
        let x' = x + (i32.bool y') 
        in  (x', y')
   in res

Since zero initialization followed by updates is common in how rev-AD code generation works, I would argue that we should extend the simplifier to be more aggressive in recognizing the case where the loop variants are initially zeros, or replicates of zeros, and remain so throughout the loop.

Intuitively the technique would be to extract the zero-based parameter initialization together with the body of the loop (but without the loop):

let x = 0i32
let y = false
let y' = bs[i] && y
let x' = x + (i32.bool y') 
in  (x', y')

and to run the simplifier on the resulting expression to a fixed point, which, in our case, will hopefully yield (0, false). The results that are identical to zeros can (I hypothesize) be safely eliminated from the loop variants and initialized to zeros before the loop.
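
The proposal can be sketched on a toy expression IR. This is Python rather than Futhark compiler code; the IR tags, the three rewrite rules, and the name zero_invariant_params are all assumptions made for illustration. The sketch assumes every zero-initialized parameter stays at its initial value, simplifies the loop body under that assumption, drops the parameters whose results do not come back as the same zero, and repeats until nothing changes.

```python
def is_zero(v):
    # False, 0 and 0.0 count as zeros; None means "unknown initial value".
    return isinstance(v, (bool, int, float)) and v == 0

def simplify(e, env):
    """Simplify expression e; env maps names to known constant values."""
    tag = e[0]
    if tag == "var":
        return ("const", env[e[1]]) if e[1] in env else e
    if tag == "and":
        l, r = simplify(e[1], env), simplify(e[2], env)
        if any(s[0] == "const" and s[1] is False for s in (l, r)):
            return ("const", False)          # e && false = false
        return ("and", l, r)
    if tag == "add":
        l, r = simplify(e[1], env), simplify(e[2], env)
        if l[0] == "const" and r[0] == "const":
            return ("const", l[1] + r[1])
        if l[0] == "const" and l[1] == 0:
            return r                         # 0 + e = e
        if r[0] == "const" and r[1] == 0:
            return l                         # e + 0 = e
        return ("add", l, r)
    if tag == "i32_bool":
        a = simplify(e[1], env)
        return ("const", int(a[1])) if a[0] == "const" else ("i32_bool", a)
    return e  # "const" and "opaque" nodes are left unchanged

def zero_invariant_params(inits, bindings, results):
    """inits: param -> initial constant (None if unknown);
    bindings: ordered (name, expr) lets of the loop body;
    results: param -> body variable holding its next value."""
    candidates = {p for p, v in inits.items() if is_zero(v)}
    while True:
        env = {p: inits[p] for p in candidates}
        for name, expr in bindings:
            s = simplify(expr, env)
            if s[0] == "const":
                env[name] = s[1]
        kept = {p for p in candidates
                if results[p] in env
                and type(env[results[p]]) is type(inits[p])
                and env[results[p]] == inits[p]}
        if kept == candidates:
            return candidates
        candidates = kept  # retry without the parameters that failed

# The loop from the example: y' = bs[i] && y, x' = x + i32.bool y'
issue_bindings = [
    ("y'", ("and", ("opaque", "bs[i]"), ("var", "y"))),
    ("x'", ("add", ("var", "x"), ("i32_bool", ("var", "y'")))),
]
issue_results = {"x": "x'", "y": "y'"}
print(sorted(zero_invariant_params({"x": 0, "y": False},
                                   issue_bindings, issue_results)))  # ['x', 'y']
```

The retry loop matters because one parameter can stay zero only as long as another does: once a parameter is removed from the candidate set, every parameter that relied on it has to be rechecked.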

athas commented 1 year ago

See here: https://github.com/diku-dk/futhark/pull/1990