FluxML / Flux.jl

Relax! Flux is the ML library that doesn't make you tensor
https://fluxml.ai/

Zygote gives extra gradient entries for BatchNorm #1018

Open xukai92 opened 4 years ago

xukai92 commented 4 years ago
using Flux

X = rand(2, 5)

layer = BatchNorm(2)
ps = params(layer)

gs = gradient(ps) do
    sum(layer(X))
end

gs.grads

gives

IdDict{Any,Any} with 5 entries:
  RefValue{typeof(^)}(^)     => RefValue{Any}((x = nothing,))
  BatchNorm(2)               => RefValue{Any}((λ = nothing, β = [5.0, 5.0], γ =…
  Float32[1.0, 1.0]          => [-9.22873e-16, 4.44089e-16]
  RefValue{Val{2}}(Val{2}()) => RefValue{Any}((x = nothing,))
  Float32[0.0, 0.0]          => [5.0, 5.0]

FYI, ps is as expected:

Params([Float32[0.0, 0.0], Float32[1.0, 1.0]])
haampie commented 4 years ago

I've been looking into this a bit. With the following change to the variance computation in BatchNorm,

--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -153,7 +153,7 @@ function (BN::BatchNorm)(x)
     T = eltype(x)
     axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
     μ = mean(x, dims = axes)
-    σ² = sum((x .- μ) .^ 2, dims = axes) ./ m
+    σ² = sum((x .- μ) .* (x .- μ), dims = axes) ./ m
     ϵ = convert(T, BN.ϵ)
     # update moving mean/std
     mtm = BN.momentum

the following two entries are no longer present:

  RefValue{typeof(^)}(^)     => RefValue{Any}((x = nothing,))
  BatchNorm(2)               => RefValue{Any}((λ = nothing, β = [5.0, 5.0], γ =…
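A dependency-free sketch of the same idea, assuming a 2-D (features, batch) input as in the original example: abs2 squares each deviation without going through ^, so no literal_pow broadcast is created. The helper name batch_var is hypothetical and not part of Flux, and this snippet only checks that the numbers match; it does not run Zygote.

```julia
using Statistics

# Hypothetical helper (not part of Flux): per-feature variance as BatchNorm
# computes it for a 2-D input of size (features, batch).
# abs2.(d) squares each entry without calling `^`.
function batch_var(x::AbstractMatrix)
    μ = mean(x, dims = 2)          # per-feature mean over the batch dimension
    m = size(x, 2)                 # number of samples in the batch
    return sum(abs2.(x .- μ), dims = 2) ./ m
end

x = [1.0 2.0 3.0; 4.0 4.0 4.0]
σ² = batch_var(x)

# Same result as the original `.^ 2` formulation and the uncorrected variance
@assert σ² ≈ sum((x .- mean(x, dims = 2)) .^ 2, dims = 2) ./ size(x, 2)
@assert σ² ≈ var(x, dims = 2, corrected = false)
```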

If I instead change .^2 to .^2.0f0 (a floating-point exponent, which Julia does not lower to literal_pow), the gradient call throws:

julia> example = gradient(ps) do
           sum(layer(X))
       end
ERROR: DomainError with -0.14608962594708508:
log will only return a complex result if called with a complex argument. Try log(Complex(x)).
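A likely explanation for the DomainError: with a non-literal exponent, the gradient of x^p is presumably taken with respect to the exponent as well, and that partial derivative involves the logarithm of the base:

$$
\frac{\partial}{\partial x}\,x^{p} = p\,x^{p-1}, \qquad \frac{\partial}{\partial p}\,x^{p} = x^{p}\ln x
$$

Since x .- μ contains negative deviations by construction (the -0.146... in the error is presumably one of them), ln x has no real value and Julia's log throws, even though the exponent 2.0f0 is a constant whose gradient is never needed.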
haampie commented 4 years ago

And the BatchNorm(2) entry seems to appear because BatchNorm is a mutable struct?

julia> mutable struct MutableLayer{T}; a::T; end;

julia> struct ImmutableLayer{T}; a::T; end;

julia> (layer::ImmutableLayer)() = sum(layer.a);

julia> (layer::MutableLayer)() = sum(layer.a);

julia> Flux.trainable(a::ImmutableLayer) = (a.a,);

julia> Flux.trainable(a::MutableLayer) = (a.a,);

julia> Flux.@functor ImmutableLayer;

julia> Flux.@functor MutableLayer;

julia> mutable_layer = MutableLayer(rand(1));

julia> immutable_layer = ImmutableLayer(rand(1));

julia> gradient(mutable_layer, params(mutable_layer)).grads
IdDict{Any,Any} with 2 entries:
  MutableLayer{Array{Float64,1}}([0.500088]) => RefValue{Any}((a = [1.0],))
  [0.500088]                                 => [1.0]

julia> gradient(immutable_layer, params(immutable_layer)).grads
IdDict{Any,Any} with 1 entry:
  [0.619024] => [1.0]
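A plausible reading of this, sketched in plain Julia: the gs.grads shown above is an IdDict, which compares keys by identity (===). Two immutable structs with identical fields are the same value, whereas every mutable struct instance has its own identity, so a mutable layer can be tracked as an object in its own right. The struct definitions mirror the ones above; nothing here calls Zygote.

```julia
mutable struct MutableLayer{T}; a::T; end
struct ImmutableLayer{T}; a::T; end

v = [0.5]

# Immutable structs are compared by value: same field (same array) means ===.
@assert ImmutableLayer(v) === ImmutableLayer(v)

# Each mutable struct instance has its own identity, even with identical fields.
@assert MutableLayer(v) !== MutableLayer(v)

# An identity-keyed IdDict therefore sees two distinct mutable keys...
d = IdDict{Any,Any}()
d[MutableLayer(v)] = 1
d[MutableLayer(v)] = 2

# ...but only one key for the two equal immutable wrappers.
e = IdDict{Any,Any}()
e[ImmutableLayer(v)] = 1
e[ImmutableLayer(v)] = 2
```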
haampie commented 4 years ago

MWE for the ^ issue:

julia> f() = 1 .^ 2;

julia> gradient(f, params()).grads
IdDict{Any,Any} with 2 entries:
  RefValue{typeof(^)}(^)     => RefValue{Any}((x = nothing,))
  RefValue{Val{2}}(Val{2}()) => RefValue{Any}((x = nothing,))
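The two entries in this MWE line up with how Julia lowers a literal integer power: 1 .^ 2 becomes a broadcast of Base.literal_pow over ^, the value, and Val(2), and broadcasting wraps function and Val arguments in Ref so they act as scalars. A small Base-only check of both steps (it does not involve Zygote):

```julia
x = [1.0, 2.0, 3.0]

# A literal integer exponent is lowered to Base.literal_pow(^, x, Val(2)) per element.
@assert (x .^ 2) == Base.literal_pow.(^, x, Val(2))

# Broadcasting treats functions and Val instances as scalars by wrapping them in Ref,
# matching the RefValue{typeof(^)} and RefValue{Val{2}} keys reported above.
r_pow = Base.Broadcast.broadcastable(^)
r_val = Base.Broadcast.broadcastable(Val(2))
@assert r_pow isa Base.RefValue{typeof(^)}
@assert r_val isa Base.RefValue{Val{2}}
```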
DhairyaLGandhi commented 4 years ago

I think this is known?

haampie commented 4 years ago

https://github.com/FluxML/Zygote.jl/pull/518 fixes the f() = 1 .^ 2 example and solves this issue for the most part.

The only entry left from @xukai92's example would be BatchNorm(2) => ..., but I don't know how to solve that one.