Strange result with gradient #1547

amb83 opened 3 years ago

amb83 commented 3 years ago

I was still confused after my last issue, so I dove a little deeper. Specifically, I wanted to understand how RNN training works with the loss function used here:

loss(x, y) = sum((Flux.stack(m.(x),1) .- y) .^ 2)

I tested a much simplified version of this, using a 1 -> 1 RNN cell without an activation function, and the same loss function without the square:

m = Flux.RNN(1, 1, x -> x)
L(x, y) = sum((Flux.stack(m.(x), 1) .- y))

Here's where I had a problem: when there's more than one input/output sample, what's the derivative of L with respect to Wi (m.cell.Wi)? Using some semi-random values:

x = [[0.3], [2.5]]
y = [0.5, 1.0]
m.cell.Wi .= [0.7]   #Wi
m.cell.Wh .= [0.001] #Wh
m.cell.b .= [0.85]   #b
m.state = [0.0]     #h0, h1, h2, etc.

If you evaluate m.(x) or L(x, y), you get the result you expect ([[1.06], [2.60106]] and 2.16106, respectively). dL/dWi is easy to derive by inspection: with h0 = 0, h1 = Wi*x1 + b and h2 = Wi*x2 + Wh*h1 + b, so dL/dWi = x1 + x2 + Wh*x1 = 2.8003. You could also get it by finite difference:

m.state = [0.0]     # reset the hidden state before each evaluation
q = L(x, y)
m.cell.Wi .+= 0.01
m.state = [0.0]
r = L(x, y)
abs(q - r)/0.01 # = 2.8003

But when you use gradient:

m.state = [0.0]
m.cell.Wi .= [0.7]
g = gradient(() -> L(x, y), params(m.cell.Wi))
g[m.cell.Wi] # = 2.8025

This result is equal to x1 + x2 + Wh*x2 instead of x1 + x2 + Wh*x1. Am I overlooking something, or is something weird happening here?
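
The expected value can also be checked without Flux at all. Below is a minimal sketch that unrolls the same 1 -> 1 recurrence by hand, assuming h0 = 0 and the identity activation used above, and compares a central finite difference with the closed form. The names unrolled_loss and analytic_dWi are made up for this illustration and are not part of Flux.

```julia
# Hand-unrolled 1 -> 1 RNN with identity activation; no Flux involved.
# `unrolled_loss` and `analytic_dWi` are illustrative names, not Flux API.
function unrolled_loss(Wi, Wh, b, x, y; h0 = 0.0)
    h = h0
    total = 0.0
    for (xt, yt) in zip(x, y)
        h = Wi * xt + Wh * h + b   # h_t = Wi*x_t + Wh*h_{t-1} + b
        total += h - yt            # same loss as L, without the square
    end
    return total
end

# Closed form for two time steps and h0 = 0: x1 + x2 + Wh*x1
analytic_dWi(Wh, x) = x[1] + x[2] + Wh * x[1]

x = [0.3, 2.5]
y = [0.5, 1.0]
Wi, Wh, b = 0.7, 0.001, 0.85

# Central finite difference with respect to Wi
δ = 1e-6
fd = (unrolled_loss(Wi + δ, Wh, b, x, y) - unrolled_loss(Wi - δ, Wh, b, x, y)) / (2δ)

@show unrolled_loss(Wi, Wh, b, x, y)  # ≈ 2.16106
@show fd                              # ≈ 2.8003
@show analytic_dWi(Wh, x)             # ≈ 2.8003
```

Since this loss is linear in Wi, the central difference should agree with the closed form up to rounding, so any disagreement with gradient would come from the differentiation step and not from the recurrence itself.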

amb83 commented 3 years ago

I want to bump this because I feel pretty sure gradient is giving a bad result, or otherwise somehow gets confused about stack or RNN or their combination. I tried another sanity check:

using ForwardDiff

h = [0.0]
function rnn(x, Wx)
    global h = Wx .* x .+ [0.001] .* h .+ [0.85]
    return h
end

loss2(Wx) = sum(Flux.stack(rnn.(x, Wx), 1) .- y)
g2(Wx) = ForwardDiff.gradient(loss2, Wx)

g2([0.7]) # = 2.8003

Here, rnn should behave the same as Flux.RNN. ForwardDiff.gradient gives the result I expect, but Flux.gradient does not. What is going on?

DhairyaLGandhi commented 3 years ago

We have recently switched backends for RNNs, so if you could run the same tests on the previous tag of Flux, it should help isolate what is happening.

ToucheSir commented 3 years ago

Doesn't look like that applies here because no GPU acceleration is in play, though I second the recommendation to try on a prior version.

amb83 commented 3 years ago

Well, this actually turned up an interesting result. I tried this:

using Flux

x = [[rand()] for i = 1:2]
y = rand(2)

m = Flux.RNN(1, 1, x -> x)
L(x, y) = sum((Flux.stack(m.(x), 1) .- y))

g = Flux.gradient(() -> L(x, y), params(m.cell.Wi))
@show g[m.cell.Wi] ≈ (x[1] + x[2] + m.cell.Wh*x[1]) #should be true

m.state *= 0
@show g[m.cell.Wi] ≈ (x[1] + x[2] + m.cell.Wh*x[2]) #should be false

Using Flux v0.9 I get the result I expect, but using Flux v1.1, I get the wrong result (the true and false outputs are reversed).

ToucheSir commented 3 years ago

Can you test on v0.10? 0.9 -> 0.10 was a complete overhaul of Flux's internals, so it's not surprising to see that the results are different. Also, can you verify that e.g. PyTorch does give the results you expect? I don't think it's possible to rule out that the RNN formulation used by these libraries is different from what you'd expect.

amb83 commented 3 years ago

I want to update that with v0.10, I get the incorrect result. Here's comparing v0.9 and v0.10:

# v0.9

using Flux

using Pkg
Pkg.status("Flux") #Flux v0.9.0

x = [[rand()] for i = 1:2]
y = rand(2)

m = Flux.RNN(1, 1, x -> x)
L(x, y) = sum((Flux.stack(m.(x), 1) .- y))

m.state *= 0
g = Flux.gradient(() -> L(x, y), params(m.cell.Wi))

m.state *= 0
@show g[m.cell.Wi] ≈ (x[1] + x[2] + m.cell.Wh*x[1]) #true (should be true)

m.state *= 0
@show g[m.cell.Wi] ≈ (x[1] + x[2] + m.cell.Wh*x[2]) #false (should be false)

# v0.10

using Flux

using Pkg
Pkg.status("Flux") #Flux v0.10.4

x = [[rand()] for i = 1:2]
y = rand(2)

m = Flux.RNN(1, 1, x -> x)
L(x, y) = sum((Flux.stack(m.(x), 1) .- y))

m.state *= 0
g = Flux.gradient(() -> L(x, y), params(m.cell.Wi))

m.state *= 0
@show g[m.cell.Wi] ≈ (x[1] + x[2] + m.cell.Wh*x[1]) #false (should be true)

m.state *= 0
@show g[m.cell.Wi] ≈ (x[1] + x[2] + m.cell.Wh*x[2]) #true (should be false)

As soon as I figure out how to use PyTorch I'll report those results.

gabrevaya commented 3 years ago

Hi! After rewriting the loss in different ways I found that the problem is related to the broadcasting:

using Flux

h = [0.0]
function m(x)
    y = Wx .* x + Wh .* h .+ b
    global h = y
    return y
end

x = [[0.3], [2.5]]
y = [0.5, 1.0]

Wx = [0.5]
Wh = [0.001]
b = [0.85]

# Theoretical derivative
x[1] + x[2] + Wh.*x[1] # = 2.8003

loss(x, y) = sum((Flux.stack(m.(x), 1) .- y))
gs = gradient(() -> loss(x, y), params(Wx, Wh, b))
gs[Wx] # = 2.8025

loss2(x, y) = sum((Flux.stack([m(xᵢ) for xᵢ in x], 1) .- y))
gs2 = gradient(() -> loss2(x, y), params(Wx, Wh, b))
gs2[Wx] # = 2.8003

I tested it with an RNN layer and in this case the gradients match:

using Random

m2 = RNN(1, 1, x -> x)
x = [rand(Float32, 1) for i = 1:2]
y = rand(Float32, 2)

p = params(m2)
loss(x, y) = sum((Flux.stack(m2.(x), 1) .- y))
gs = gradient(() -> loss(x, y), p)

loss2(x, y) = sum((Flux.stack([m2(xᵢ) for xᵢ in x], 1) .- y))
gs2 = gradient(() -> loss2(x, y), p)

[gs[pᵢ] .== gs2[pᵢ] for pᵢ in p]
# 4-element Vector{BitArray}:
#  [1]
#  [1]
#  [1]
#  [1]

However, if I add a trivial Chain, the gradients don't match anymore:

m2 = Chain(RNN(1, 1, x -> x))
x = [rand(Float32, 1) for i = 1:2]
y = rand(Float32, 2)

p = params(m2)
loss(x, y) = sum((Flux.stack(m2.(x), 1) .- y))
gs = gradient(() -> loss(x, y), p)

loss2(x, y) = sum((Flux.stack([m2(xᵢ) for xᵢ in x], 1) .- y))
gs2 = gradient(() -> loss2(x, y), p)

[gs[pᵢ] .== gs2[pᵢ] for pᵢ in p]
# 4-element Vector{BitArray}:
#  [0]
#  [1]
#  [0]
#  [0]

For Dense layers, the gradients match in both cases.

Maybe I'm doing something wrong but if not, this could point to a serious bug, since most recurrent models are trained with broadcasting. I hope that this can be helpful and I take the chance to thank you for this amazing (and elegant) package!!

Julia Version 1.6.0
Commit f9720dc2eb (2021-03-24 12:55 UTC)
Platform Info:
  OS: macOS (x86_64-apple-darwin19.6.0)
  CPU: Intel(R) Core(TM) i5-7360U CPU @ 2.30GHz
  LIBM: libopenlibm
  LLVM: libLLVM-11.0.1 (ORCJIT, skylake)

(Flux_gradient_test) pkg> st
      Status `~/Documents/github issues/Flux_gradient_test/Project.toml`
  [587475ba] Flux v0.12.3

ToucheSir commented 3 years ago

@gabrevaya wrapping with a Chain doesn't work because of https://github.com/FluxML/Flux.jl/issues/1209. This was fixed for broadcasting Recur in https://github.com/FluxML/Flux.jl/pull/1358, but as you can see from the signature, the fix won't apply to a Chain or any other layer wrapper.

An easy workaround that doesn't require an array comprehension is to use map (whose adjoint https://github.com/FluxML/Flux.jl/pull/1358 uses), like so:

loss2(x, y) = sum((Flux.stack(map(m2, x), 1) .- y))

For a proper fix, I think it's worth revisiting https://github.com/FluxML/Zygote.jl/pull/807.
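
For anyone wondering where x1 + x2 + Wh*x2 could come from, here is a hand-written model of the backward pass. It rests on an assumption: that the per-step backward passes share a gradient stored on the mutable hidden state and, under broadcasting, run in the same order as the forward pass instead of in reverse. It is not Zygote or Flux code, and manual_dWi and state_grad are names invented for the sketch.

```julia
# Manual backward pass for h_t = Wi*x_t + Wh*h_{t-1} + b with loss sum(h_t - y_t).
# `state_grad` models a gradient stored on the shared mutable hidden state.
# All names here are illustrative; this is not Zygote or Flux code.
function manual_dWi(x, Wh, order)
    dWi = 0.0
    state_grad = 0.0              # gradient currently pending on the hidden state
    for t in order
        dh = 1.0 + state_grad     # dL/dh_t: direct term plus whatever is pending
        dWi += dh * x[t]          # contribution through Wi * x_t
        state_grad = Wh * dh      # gradient handed back to the previous state
    end
    return dWi
end

x, Wh = [0.3, 2.5], 0.001

reverse_order = manual_dWi(x, Wh, 2:-1:1)  # steps visited last to first
forward_order = manual_dWi(x, Wh, 1:2)     # steps visited first to last

@show reverse_order  # ≈ 2.8003, i.e. x1 + x2 + Wh*x1
@show forward_order  # ≈ 2.8025, i.e. x1 + x2 + Wh*x2
```

Under that assumption, visiting the steps in forward order attaches the Wh term to x2 instead of x1, which matches the 2.8025 reported earlier in the thread.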