FluxML / Zygote.jl

https://fluxml.ai/Zygote.jl/

gradients with aliased variables #991

Open CarloLucibello opened 3 years ago

CarloLucibello commented 3 years ago

I was trying to figure out how to properly handle and update Flux layers with tied weights (https://github.com/FluxML/Flux.jl/issues/1592).

So first of all I wanted to check how Zygote handles aliased objects. Here are 6 examples. Maybe it's all expected and intended, but I find the last 3 in particular a bit surprising. @oxinabox is this what we want?

julia> using Zygote

julia> x = [1]
1-element Vector{Int64}:
 1

julia> xt = x'
1×1 adjoint(::Vector{Int64}) with eltype Int64:
 1

# 1.
julia> gradient(() -> sum(x' .* x), Params([x])).grads
IdDict{Any, Any} with 2 entries:
  :(Main.x) => [2]
  [1]       => [2]

# 2.
julia> gradient(() -> sum(xt .* x), Params([x])).grads
IdDict{Any, Any} with 3 entries:
  :(Main.x)  => [1]
  [1]        => [1]
  :(Main.xt) => [1]

# 3.
julia> gradient(() -> sum(xt .* x), Params([x,xt])).grads
IdDict{Any, Any} with 4 entries:
  [1]        => [1]
  :(Main.x)  => [1]
  [1]        => [1]
  :(Main.xt) => [1]

# 4.
julia> gradient(() -> sum(xt.parent .* x), Params([x])).grads
IdDict{Any, Any} with 2 entries:
  :(Main.x) => [1]
  [1]       => [2]

# 5.
julia> gradient(() -> sum(xt.parent .* x), Params([x, xt])).grads
IdDict{Any, Any} with 3 entries:
  [1]       => nothing  # this is xt
  :(Main.x) => [1]
  [1]       => [2]           # this is x

# 6.
julia> gradient(() -> sum(xt.parent .* x), Params([xt])).grads
IdDict{Any, Any} with 3 entries:
  [1]        => (parent = [1],)
  :(Main.x)  => [1]
  :(Main.xt) => (parent = [1],)
CarloLucibello commented 3 years ago

I guess the most disturbing one is 5. Shouldn't it return

  [1]       => (parent = [1],)  # this is xt
  :(Main.x) => [1]
  [1]       => [1]           # this is x

instead?
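For reference, a quick explicit check (not part of the original examples; exact element types may differ by Zygote version): treating x as the only independent variable, xt.parent .* x is just x .* x, so the two per-alias contributions above should sum to the explicit gradient of sum(x .* x).

using Zygote

# Hypothetical sanity check for example 5: x is the only independent
# variable, so the tied contributions must add up to 2 .* x.
gradient(x -> sum(x .* x), [1])   # expect ([2],), i.e. the sum of the two [1] pieces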

oxinabox commented 3 years ago

Putting aliased memory in Params feels like it's not going to be OK. I would need a fair bit of time to think about these.

darsnack commented 3 years ago

(never mind, I had scribbled the wrong variables on my napkin)

CarloLucibello commented 3 years ago

For a user-defined struct we have

julia> struct A; x; end

julia> x = rand(2); a = A(x);

julia> Base.sum(a) = sum(a.x)

julia> gradient(() -> sum(a), Params([x])).grads
IdDict{Any, Any} with 1 entry:
  [0.573261, 0.457937] => 2-element Fill{Float64}: entries equal to 1.0

while for Adjoint something is wrong

julia> xt = Adjoint(x)
1×2 adjoint(::Vector{Float64}) with eltype Float64:
 0.573261  0.457937

julia> gradient(() -> sum(xt), Params([x])).grads
IdDict{Any, Any} with 2 entries:
  [0.573261, 0.457937] => nothing
  :(Main.xt)           => 1×2 Fill{Float64}: entries equal to 1.0
DhairyaLGandhi commented 3 years ago

This seems expected... The grads actually also track global params as a GlobalRef to capture tied variables.

darsnack commented 3 years ago

They make sense, but that doesn't make them right/useful.

I tried creating similar problems with explicit params yesterday, and I just could not find an example that didn't work. So rather than spend time fixing this issue, we could transition to explicit params across the ecosystem.
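As an illustration of the explicit style (a minimal sketch, not taken from the thread; the numbers assume x = [1.0, 2.0]): the alias x' is created inside the differentiated function, so there is a single independent variable and the tied contributions are summed automatically.

using Zygote

# Minimal sketch of the explicit-params style: the transpose is built
# inside the closure, so Zygote folds both contributions into one
# gradient with respect to x.
x = [1.0, 2.0]
gradient(x -> sum(x' .* x), x)[1]   # elementwise equal to 2 * sum(x), i.e. [6.0, 6.0]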

CarloLucibello commented 3 years ago

It seems hard not to consider the last example in https://github.com/FluxML/Zygote.jl/issues/991#issuecomment-864375988 a bug. I don't even know precisely why it happens; probably when we hit an AbstractArray{<:Number} in Zygote we don't look at its internal structure, is that the case?

> I tried creating similar problems with explicit params yesterday, and I just could not find an example that didn't work. So rather than spend time fixing this issue, we could transition to explicit params across the ecosystem.

I'm not totally sure explicit gradients are a convenient fit for every situation; I'd like to see a diverse set of use cases where they replace Params. In the last example, the explicit gradient is at least consistent, although not quite useful:

julia> gradient(x -> sum(a), x)
(nothing,)

julia> gradient(x -> sum(xt), x)
(nothing,)
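By contrast, if the differentiated function actually uses its argument, the explicit gradient is well-defined (a quick hedged check; the exact container type of the result may vary):

# The closures above ignore their argument (they close over the globals a
# and xt), hence the (nothing,). Using the argument restores the gradient:
gradient(x -> sum(x'), x)[1]   # elementwise equal to ones(length(x))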
darsnack commented 3 years ago

I think this illustrates why I consider explicit params better. It's obvious why the last example returned nothing. For the same reason, the Adjoint case returns nothing, but it is less obvious because we expect implicit params to pick up connections that aren't there in the function being differentiated.

One option is to add some kind of post-processing step where Params finds these connections and applies a fix, roughly as sketched below. But I feel it is hard to do correctly in the generic case.
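To make that concrete, a purely hypothetical sketch of such a pass (the function name and the restriction to LinearAlgebra's lazy wrappers are my assumptions, not an existing Zygote API; conjugation for complex arrays and structural gradients like (parent = ...,) are ignored):

using LinearAlgebra

# Hypothetical post-processing pass over an implicit-params gradient dict:
# for any parameter that is a lazy wrapper whose parent is also tracked,
# fold the wrapper's array gradient back into the parent's gradient.
function fold_tied_gradients!(grads::IdDict, ps)
    for p in ps
        p isa Union{Adjoint, Transpose} || continue
        haskey(grads, parent(p)) || continue
        gwrap = get(grads, p, nothing)
        gwrap isa AbstractArray || continue
        # transpose back to the parent's shape before accumulating
        contrib = reshape(transpose(gwrap), size(parent(p)))
        gparent = get(grads, parent(p), nothing)
        grads[parent(p)] = gparent === nothing ? collect(contrib) : gparent .+ contrib
    end
    return grads
end

Applied to example 3 above, this would combine the separate x and xt entries into a single [2] for x, matching the result of example 1.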

darsnack commented 3 years ago

For example, something like https://github.com/FluxML/Flux.jl/issues/1592 works out nicely. Similar to the examples above, if we have

using Flux

m1 = Dense(5, 2)
m2 = Dense(transpose(m1.weight))
m = Chain(m1, m2)
dm = gradient(m -> sum(m(ones(Float32, 5))), m)[1]

Zygote will see the weight of m1 as w1 = w and the weight of m2 as w2 = transpose(w). It returns gradients w.r.t. w1 and w2 (as if they were not tied). But when we account for the part that Zygote doesn't see (the tie through w), we have

from multivariate chain rule
dL/dw = dL/dw1 * dw1/dw + dL/dw2 * dw2/dw
dw1/dw = 1
dw2/dw = 1 (up to transpose)

=> dL/dw = dL/dw1 + dL/dw2

The accumulation in the last equation happens automatically with simple optimizers like gradient descent, provided you use lazy wrappers like transpose or views; see the sketch below. (@oxinabox can correct me if I am wrong here, my AD knowledge is very limited.)
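A rough illustration of that point (a standalone sketch with plain arrays standing in for the layer weights): because the transpose shares memory with its parent, in-place gradient-descent updates applied through both views accumulate in the one underlying array.

using LinearAlgebra

# Plain SGD steps on a weight and its lazy transpose: both updates write
# into the same memory, so the net step is η .* (dW1 .+ transpose(dW2)).
W  = [1.0 2.0; 3.0 4.0]
Wt = transpose(W)          # lazy wrapper, aliases W

dW1 = ones(2, 2)           # stand-in for dL/dw1
dW2 = 2 .* ones(2, 2)      # stand-in for dL/dw2
η   = 0.1

W  .-= η .* dW1            # update through the parent
Wt .-= η .* dW2            # update through the wrapper also mutates W
# W now equals the original W minus η .* (dW1 .+ transpose(dW2))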

I guess it isn't automatic for complex optimizers that track momentum, etc. But it seems like we should then handle it on the optimizer side, not in the AD. This is where I think explicit params are nicer. What I wrote above is true for implicit params as well (e.g. Example 3 in the main issue) when Params contains x and xt. The trouble with implicit params is that you get all these other cases, issues with hashing, etc., that make dealing with the final equation I wrote above harder on the optimizer side.