Open amb83 opened 3 years ago
I want to bump this because I feel pretty sure `gradient` is giving a bad result, or otherwise somehow gets confused about `stack`, RNNs, or their combination. I tried another sanity check:
```julia
using Flux        # for Flux.stack
using ForwardDiff

# same x and y as used throughout this thread
x = [[0.3], [2.5]]
y = [0.5, 1.0]

h = [0.0]
function rnn(x, Wx)
    global h = Wx .* x .+ [0.001] .* h .+ [0.85]
    return h
end

loss2(Wx) = sum(Flux.stack(rnn.(x, Wx), 1) .- y)
g2(Wx) = ForwardDiff.gradient(loss2, Wx)
g2([0.7]) # = 2.8003
```
Here, `rnn` should behave the same as `Flux.RNN`. `ForwardDiff.gradient` gives the result I expect, but `Flux.gradient` does not. What is going on?
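One more AD-free cross-check (my sketch, not from the original post, assuming the same `x`, `y`, and coefficients as above): unroll the two-step recurrence by hand and differentiate by central finite differences.

```julia
# Hand-unrolled scalar version of the recurrence above, differentiated by
# central finite differences (no AD involved).
# Assumes x = [0.3, 2.5], y = [0.5, 1.0], Wh = 0.001, b = 0.85.
function unrolled_loss(Wx)
    x = [0.3, 2.5]
    y = [0.5, 1.0]
    h = 0.0
    total = 0.0
    for (xi, yi) in zip(x, y)
        h = Wx * xi + 0.001 * h + 0.85   # one RNN step
        total += h - yi                  # sum(outputs .- y)
    end
    return total
end

δ = 1e-6
fd = (unrolled_loss(0.7 + δ) - unrolled_loss(0.7 - δ)) / (2δ)
# fd ≈ 2.8003, agreeing with ForwardDiff and the hand derivation
```

Since the loss is linear in `Wx`, the finite difference is exact up to rounding.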
We have recently switched backends for RNNs, so if you could run the same tests on the previous tag of Flux, it should help isolate what is happening.
Doesn't look like that applies here because no GPU acceleration is in play, though I second the recommendation to try on a prior version.
Well, this actually turned up an interesting result. I tried this:
```julia
using Flux

x = [[rand()] for i = 1:2]
y = rand(2)
m = Flux.RNN(1, 1, x -> x)
L(x, y) = sum(Flux.stack(m.(x), 1) .- y)

g = Flux.gradient(() -> L(x, y), params(m.cell.Wi))
@show g[m.cell.Wi] ≈ (x[1] + x[2] + m.cell.Wh*x[1]) # should be true
m.state *= 0
@show g[m.cell.Wi] ≈ (x[1] + x[2] + m.cell.Wh*x[2]) # should be false
```
Using Flux v0.9 I get the result I expect, but using Flux v0.11 I get the wrong result (the true and false outputs are reversed).
Can you test on v0.10? 0.9 -> 0.10 was a complete overhaul of Flux's internals, so it's not surprising to see that the results are different. Also, can you verify that e.g. PyTorch gives the results you expect? I don't think it's possible to rule out that the RNN formulation used by these libraries is different from what you'd expect.
To update: with v0.10 I also get the incorrect result. Here's a comparison of v0.9 and v0.10:
```julia
# v0.9
using Flux
using Pkg
println(Pkg.status("Flux")) # Flux v0.9.0

x = [[rand()] for i = 1:2]
y = rand(2)
m = Flux.RNN(1, 1, x -> x)
L(x, y) = sum(Flux.stack(m.(x), 1) .- y)

m.state *= 0
g = Flux.gradient(() -> L(x, y), params(m.cell.Wi))
m.state *= 0
@show g[m.cell.Wi] ≈ (x[1] + x[2] + m.cell.Wh*x[1]) # true (should be true)
m.state *= 0
@show g[m.cell.Wi] ≈ (x[1] + x[2] + m.cell.Wh*x[2]) # false (should be false)
```
```julia
# v0.10
using Flux
using Pkg
println(Pkg.status("Flux")) # Flux v0.10.4

x = [[rand()] for i = 1:2]
y = rand(2)
m = Flux.RNN(1, 1, x -> x)
L(x, y) = sum(Flux.stack(m.(x), 1) .- y)

m.state *= 0
g = Flux.gradient(() -> L(x, y), params(m.cell.Wi))
m.state *= 0
@show g[m.cell.Wi] ≈ (x[1] + x[2] + m.cell.Wh*x[1]) # false (should be true)
m.state *= 0
@show g[m.cell.Wi] ≈ (x[1] + x[2] + m.cell.Wh*x[2]) # true (should be false)
```
As soon as I figure out how to use PyTorch I'll report those results.
Hi! After rewriting the loss in different ways, I found that the problem is related to broadcasting:
```julia
using Flux

h = [0.0]
function m(x)
    y = Wx .* x .+ Wh .* h .+ b
    global h = y
    return y
end

x = [[0.3], [2.5]]
y = [0.5, 1.0]
Wx = [0.5]
Wh = [0.001]
b = [0.85]

# Theoretical derivative
x[1] + x[2] + Wh .* x[1] # = 2.8003

loss(x, y) = sum(Flux.stack(m.(x), 1) .- y)
gs = gradient(() -> loss(x, y), params(Wx, Wh, b))
gs[Wx] # = 2.8025

loss2(x, y) = sum(Flux.stack([m(xᵢ) for xᵢ in x], 1) .- y)
gs2 = gradient(() -> loss2(x, y), params(Wx, Wh, b))
gs2[Wx] # = 2.8003
```
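For reference (my quick check, not part of the original report): the two numbers above match two closed-form expressions that differ only in which input multiplies `Wh`, which is what points the finger at the hidden state being stale during the broadcast pullback.

```julia
x1, x2, Wh = 0.3, 2.5, 0.001

# Gradient when step 2 correctly back-propagates through h from step 1:
correct = x1 + x2 + Wh * x1   # ≈ 2.8003 (what the comprehension gives)

# Gradient if the Wh term picks up the wrong step's input:
stale = x1 + x2 + Wh * x2     # ≈ 2.8025 (what the broadcast gives)
```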
I tested it with an RNN layer and in this case the gradients match:
```julia
using Random
Random.seed!(1)

m2 = RNN(1, 1, x -> x)
x = [rand(Float32, 1) for i = 1:2]
y = rand(Float32, 2)
p = params(m2)

loss(x, y) = sum(Flux.stack(m2.(x), 1) .- y)
gs = gradient(() -> loss(x, y), p)

Flux.reset!(m2)
loss2(x, y) = sum(Flux.stack([m2(xᵢ) for xᵢ in x], 1) .- y)
gs2 = gradient(() -> loss2(x, y), p)

[gs[pᵢ] .== gs2[pᵢ] for pᵢ in p]
# 4-element Vector{BitArray}:
#  [1]
#  [1]
#  [1]
#  [1]
```
However, if I add a trivial Chain, the gradients don't match anymore:
```julia
Random.seed!(1)

m2 = Chain(RNN(1, 1, x -> x))
x = [rand(Float32, 1) for i = 1:2]
y = rand(Float32, 2)
p = params(m2)

loss(x, y) = sum(Flux.stack(m2.(x), 1) .- y)
gs = gradient(() -> loss(x, y), p)

Flux.reset!(m2)
loss2(x, y) = sum(Flux.stack([m2(xᵢ) for xᵢ in x], 1) .- y)
gs2 = gradient(() -> loss2(x, y), p)

[gs[pᵢ] .== gs2[pᵢ] for pᵢ in p]
# 4-element Vector{BitArray}:
#  [0]
#  [1]
#  [0]
#  [0]
```
For Dense layers, the gradients match in both cases.
Maybe I'm doing something wrong, but if not, this could point to a serious bug, since most recurrent models are trained with broadcasting. I hope this is helpful, and I'll take the chance to thank you for this amazing (and elegant) package!!
```
julia> versioninfo()
Julia Version 1.6.0
Commit f9720dc2eb (2021-03-24 12:55 UTC)
Platform Info:
  OS: macOS (x86_64-apple-darwin19.6.0)
  CPU: Intel(R) Core(TM) i5-7360U CPU @ 2.30GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-11.0.1 (ORCJIT, skylake)
Environment:
  JULIA_EDITOR = code
  JULIA_NUM_THREADS = 4

(Flux_gradient_test) pkg> st
Status `~/Documents/github issues/Flux_gradient_test/Project.toml`
  [587475ba] Flux v0.12.3
```
@gabrevaya the reason wrapping with a `Chain` doesn't work is https://github.com/FluxML/Flux.jl/issues/1209. This was fixed for broadcasting `Recur` in https://github.com/FluxML/Flux.jl/pull/1358, but as you can see from the signature, the fix won't apply to a `Chain` or any other layer wrapper.

An easy workaround that doesn't require an array comprehension is to use `map` (whose adjoint https://github.com/FluxML/Flux.jl/pull/1358 uses), like so:

```julia
loss2(x, y) = sum(Flux.stack(map(m2, x), 1) .- y)
```

For a proper fix, I think it's worth revisiting https://github.com/FluxML/Zygote.jl/pull/807.
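To see why call order matters at all here, a tiny AD-free sketch (an illustration, not from the issue): a stateful step like the RNN cell threads `h` through successive calls, so any adjoint that replays the calls against a differently-threaded state produces different values.

```julia
# The outputs of a stateful step depend on the order of the calls,
# because h is threaded through them.
h = 0.0
step(x) = (global h = 0.5 * x + 0.001 * h + 0.85; h)

h = 0.0
forward = [step(0.3), step(2.5)]            # h flows x1 -> x2

h = 0.0
backward = reverse([step(2.5), step(0.3)])  # h flows x2 -> x1

forward == backward  # false: same inputs, different state threading
```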
I was still confused after my last issue, so I dove a little deeper. Specifically, I wanted to understand how RNN training works with the loss function used here:

```julia
loss(x, y) = sum((Flux.stack(m.(x), 1) .- y) .^ 2)
```

I tested a much-simplified version of this, using a 1 -> 1 RNN cell without an activation function, and the same loss function without the square.
Here's where I had a problem: when there's more than one input/output sample, what's the derivative of `L` with respect to `Wi` (`m.cell.Wi`)? Using some semi-random values (the same as elsewhere in this thread: `Wi = 0.7`, `Wh = 0.001`, `b = 0.85`, `x = [[0.3], [2.5]]`, `y = [0.5, 1.0]`), if you evaluate `m.(x)` or `L(x, y)`, you get the result you expect (`[[1.06], [2.60106]]` and `2.16106`, respectively). For dL/dWi, it's easy to derive by inspection that it should be `x1 + x2 + Wh*x1 = 2.8003`. You could also get it by finite differences. But when you use `gradient`, the result is equal to `x1 + x2 + Wh*x2` instead of `x1 + x2 + Wh*x1`. Am I overlooking something, or is something weird happening here?