Closed melsman closed 7 years ago
I came up with this reduced error case:
import "/futlib/linalg"
import "/futlib/math"
module array = import "/futlib/array"
module Linalg = linalg(f64)
-- Logistic sigmoid of a single f64 value: 1 / (1 + exp(-x)).
let sigmoid (x:f64) =
1.0f64 / (1.0f64 + f64.exp(-x))
-- Derivative of the logistic sigmoid at x, computed from the sigmoid
-- value itself as s * (1 - s) where s = sigmoid x.
let sigmoid_prime (x:f64) =
let s = sigmoid x
in s * (1.0f64 - s)
-- Elementwise difference output_activations - y of two length-n vectors.
let cost_derivative [n] (output_activations:[n]f64) (y:[n]f64) : [n]f64 =
map (-) output_activations y
-- Outer product of a (length m) and b (length n): an m-by-n array whose
-- element at row r, column c is a[r] * b[c].
let outer_prod [m][n] (a:[m]f64) (b:[n]f64): [m][n]f64 =
map (\x -> map (\y -> x * y) b) a
-- Entry point of the reduced error case.
--
-- Parameters:
--   ((b2, w2), (b3, w3)) - two (vector, matrix) pairs with shapes
--                          [j], [j][i] and [k], [k][j].
--   (x, y)               - a vector of length i and a vector of length k.
--
-- Returns the pair (delta2, nabla_w2).
--
-- NOTE(review): Linalg.matvecmul and array.transpose come from futlib and
-- are not visible here; presumably matrix-vector multiplication and matrix
-- transposition respectively.
-- NOTE(review): this is a reproducer for a compiler problem, so the
-- expression structure (including the unused binding below) should not be
-- tidied up without re-checking that the error still occurs.
let main [i] [j] [k] (((b2: [j]f64,w2: [j][i]f64),(b3: [k]f64, w3: [k][j]f64))) (x:[i]f64,y:[k]f64) =
-- z2 = (w2 applied to x) + b2, elementwise addition.
let z2 = map (+) (Linalg.matvecmul w2 x) b2
-- z3 = (w3 applied to z2) + b3, elementwise addition.
let z3 = map (+) (Linalg.matvecmul w3 z2) b3
-- delta3 = (z3 - y) * sigmoid_prime z3, elementwise; the expression spans
-- the next two lines.
let delta3 = map (*) (cost_derivative z3 y)
(map sigmoid_prime z3)
-- nabla_b3 is bound but never referenced in the rest of this function.
let nabla_b3 = delta3
let sp = map sigmoid_prime z2
-- delta2 = (transposed w3 applied to delta3) * sp, elementwise.
let delta2 = map (*) (Linalg.matvecmul (array.transpose w3) delta3) sp
-- delta2 is consumed here and also returned directly below.
let nabla_w2 = outer_prod delta2 x
in (delta2,nabla_w2)
Fusion is by far the part of the compiler that I most dislike debugging, but I should be able to fix this.
Even smaller:
-- Dot product of two length-n vectors: the sum (starting from 0.0) of the
-- elementwise products of xs and ys.
let dotprod [n] (xs: [n]f64) (ys: [n]f64): f64 =
reduce (+) 0.0 (map (*) xs ys)
-- Matrix-vector product: for each of the n rows of xss, the dot product
-- of ys with that row. ys is passed as the first argument of dotprod.
let matvecmul [n] [m] (xss: [n][m]f64) (ys: [m]f64) =
map (dotprod ys) xss
-- Outer product of a (length m) and b (length n): an m-by-n array whose
-- element at row r, column c is a[r] * b[c].
let outer_prod [m][n] (a:[m]f64) (b:[n]f64): [m][n]f64 =
map (\x -> map (\y -> x * y) b) a
-- Entry point of the smaller reproducer.
--
-- Parameters: b2 (length j), b3 (length k), w3 (k-by-j), x (length i).
-- Returns the pair (delta2, nabla_w2).
--
-- NOTE(review): this is a reproducer for a compiler problem; keep the
-- expression structure as-is unless the error is re-checked afterwards.
let main [i] [j] [k] (b2: [j]f64) (b3: [k]f64) (w3: [k][j]f64) (x:[i]f64) =
-- rearrange (1,0) swaps the two dimensions of w3 (a transpose), giving a
-- j-by-k matrix; its product with b3 is multiplied elementwise by b2.
let delta2 = map (*) (matvecmul (rearrange (1,0) w3) b3) b2
-- delta2 is consumed here and also returned directly below.
let nabla_w2 = outer_prod delta2 x
in (delta2,nabla_w2)
A variant that actually showcases the reported bug (I hope; the smaller case above actually exhibited a different bug):
-- Dot product of two length-n vectors: the sum (starting from 0.0) of the
-- elementwise products of xs and ys.
let dotprod [n] (xs: [n]f64) (ys: [n]f64): f64 =
reduce (+) 0.0 (map (*) xs ys)
-- Matrix-vector product: for each of the n rows of xss, the dot product
-- of ys with that row. ys is passed as the first argument of dotprod.
let matvecmul [n] [m] (xss: [n][m]f64) (ys: [m]f64) =
map (dotprod ys) xss
-- Elementwise difference output_activations - y of two length-n vectors.
let cost_derivative [n] (output_activations:[n]f64) (y:[n]f64) : [n]f64 =
map (-) output_activations y
-- Outer product of a (length m) and b (length n): an m-by-n array whose
-- element at row r, column c is a[r] * b[c].
-- Unlike the earlier variants of this function, the return type carries a
-- `*` (uniqueness) annotation.
let outer_prod [m][n] (a:[m]f64) (b:[n]f64) : *[m][n]f64 =
map (\x -> map (\y -> x * y) b) a
-- Entry point of the variant intended to showcase the reported bug.
--
-- Parameters:
--   w3     - k-by-j matrix.
--   (x, y) - a vector of length i and a vector of length k.
--   z2     - vector whose size is left anonymous (`[]`) in the signature.
--   z3     - vector of length k.
--
-- Returns ((nabla_b2, nabla_w2), (nabla_b3, nabla_w3)).
--
-- NOTE(review): this is a reproducer for a compiler problem; keep the
-- expression structure as-is unless the error is re-checked afterwards.
let main [i] [j] [k] (w3:[k][j]f64) (x:[i]f64,y:[k]f64) (z2: []f64) (z3: [k]f64) =
-- delta3 = (z3 - y) * z3, elementwise.
let delta3 = map (*) (cost_derivative z3 y) z3
let nabla_b3 = delta3
let nabla_w3 = outer_prod delta3 z2
-- rearrange (1,0) swaps the two dimensions of w3 (a transpose); the
-- product with delta3 is then multiplied elementwise by z2.
let delta2 = map (*) (matvecmul (rearrange (1,0) w3) delta3) z2
let nabla_b2 = delta2
let nabla_w2 = outer_prod delta2 x
-- Pair each vector with the outer-product matrix derived from it.
let nabla2 = (nabla_b2,nabla_w2)
let nabla3 = (nabla_b3,nabla_w3)
in (nabla2,nabla3)
I get the following error:
Here is the code: