dfdx / XGrad.jl

eXpression gradients in Julia

Seed != 1.0 #8

Closed: dfdx closed this issue 6 years ago

dfdx commented 6 years ago

For example, in one iteration of an RNN we have:

h[t+1] = RNN(W, x, h[t])

where h[t+1] is a tensor itself. We can't compute the full Jacobian dh[t+1]/dh[t] efficiently, so we can't make a separate function to calculate it; but if we already have dy/dh[t+1], we can easily write down an expression for dy/dh[t].
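To make this concrete, here is a minimal hand-written sketch in plain Julia (not XGrad output; the tanh cell and the names W, U, x_t are assumptions for illustration) of how an upstream gradient acts as the seed for one backward step:

# One RNN step: h[t+1] = tanh.(W * h[t] + U * x[t])
W = rand(4, 4); U = rand(4, 3)
h_t = rand(4); x_t = rand(3)

pre    = W * h_t + U * x_t            # pre-activation
h_next = tanh.(pre)                   # h[t+1]

seed = rand(4)                        # stands in for the known dy/dh[t+1]

# Chain rule: dy/dh[t] = W' * (dy/dh[t+1] .* tanh'(pre)), with tanh'(z) = 1 - tanh(z)^2.
# Note that the full Jacobian dh[t+1]/dh[t] is never materialized.
dy_dh_t = transpose(W) * (seed .* (1 .- h_next .^ 2))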

Essentially, to do so we need to change the seed from the constant 1.0 to an input parameter seed. As far as I can see, two pieces of code need to be updated:

dfdx commented 6 years ago

Implemented via the context parameter :seed, e.g.:

using XGrad                       # provides xdiff
# note: VectorCodeGen may come from the Espresso dependency (using Espresso)

ex = :(y = W * x)
W = rand(2, 3); x = rand(3)
ctx = Dict(:seed => [1.0, 0.0], :codegen => VectorCodeGen())  # codegen chosen just for readability

xdiff(ex; ctx=ctx, W=W, x=x)

which produces:

quote
    tmp659 = transpose(x)
    tmp661 = transpose(W)
    y = W * x
    dy!dy = [1.0, 0.0]
    dy!dx = tmp661 * dy!dy
    dy!dW = dy!dy * tmp659
    tmp663 = (y, dy!dW, dy!dx)
end
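As a sanity check (plain Julia, independent of XGrad), the generated expressions match the hand-derived vector-Jacobian products for y = W * x:

W = rand(2, 3); x = rand(3)
seed = [1.0, 0.0]                     # the seed passed via ctx, i.e. dy!dy

y = W * x
dy_dx = transpose(W) * seed           # same as dy!dx = tmp661 * dy!dy
dy_dW = seed * transpose(x)           # same as dy!dW = dy!dy * tmp659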