FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

Behavior of real changes when broadcasted #356

Closed sethaxen closed 5 years ago

sethaxen commented 5 years ago

Calling real on a matrix and broadcasting real over the same matrix produce different adjoints when a complex adjoint is pulled back:

julia> using Zygote, LinearAlgebra, Random
julia> Random.seed!(42);
julia> A = randn(4, 4);
julia> R̄ = randn(ComplexF64, 4, 4);
julia> y, back1 = Zygote.pullback(real, A)
([-0.5560268761463861 1.7778610980573246 -2.641991008076796 0.5181487878771377; -0.444383357109696 -1.14490153172882 1.0033099014594844 1.4913791170403063; 0.027155338009193845 -0.46860588216767457 1.0823812056084292 0.3675627461748204; -0.29948409035891055 0.15614346264074028 0.18702790710363 -0.8862052960481365], getfield(Zygote, Symbol("##28#29")){typeof(∂(real))}(∂(real)))
julia> y, back2 = Zygote.pullback(x->real.(x), A)
([-0.5560268761463861 1.7778610980573246 -2.641991008076796 0.5181487878771377; -0.444383357109696 -1.14490153172882 1.0033099014594844 1.4913791170403063; 0.027155338009193845 -0.46860588216767457 1.0823812056084292 0.3675627461748204; -0.29948409035891055 0.15614346264074028 0.18702790710363 -0.8862052960481365], getfield(Zygote, Symbol("##28#29")){typeof(∂(#3))}(∂(#3)))
julia> back1(R̄)[1]
4×4 Array{Complex{Float64},2}:
   0.48406-1.12471im   -3.06705+0.40293im      0.402668+0.4816im      1.28647-0.259692im 
  0.290376-0.605531im  -1.00555-0.263327im    -0.946904-0.168492im   0.534975+0.0615302im
 -0.743161+0.355024im   0.26093-0.00538319im   0.720796+0.496227im  -0.601885+0.653535im 
  -0.15291-0.499517im  0.397867+0.0755679im   -0.102679+0.454596im    1.13992+0.347891im 
julia> back2(R̄)[1]
4×4 Array{Float64,2}:
  0.48406   -3.06705    0.402668   1.28647 
  0.290376  -1.00555   -0.946904   0.534975
 -0.743161   0.26093    0.720796  -0.601885
 -0.15291    0.397867  -0.102679   1.13992 

The second (broadcasted) result is consistent with how real behaves on scalars. Is there a reason why the two calls to real give different results, or is this a bug?

MikeInnes commented 5 years ago

It's not intentional, no. Not sure what would cause this, but it should be an easy fix if someone can dig into it.

sethaxen commented 5 years ago

I can give it a try. The broadcast case is the correct one, right?
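
One way to sanity-check which result to expect, independent of Zygote: if cotangents are paired with perturbations through the real inner product Re⟨R̄, dY⟩ (an assumption here, though a common convention for complex AD), then for a real perturbation of A the pairing only ever sees the real part of R̄. The sketch below checks that numerically using Base and LinearAlgebra only; `dA` is a hypothetical perturbation introduced for illustration, not something from the report above.

```julia
using LinearAlgebra, Random

Random.seed!(42)
A  = randn(4, 4)               # real input, as in the report
R̄  = randn(ComplexF64, 4, 4)   # complex cotangent
dA = randn(4, 4)               # hypothetical real perturbation of A

# `real` is linear and A is real, so the output perturbation is dA (up to rounding).
dY = real.(A .+ dA) .- real.(A)

# Pairing with the full complex cotangent (dot conjugates its first argument) ...
lhs = real(dot(R̄, dY))
# ... agrees with pairing against only its real part.
rhs = dot(real.(R̄), dA)

isapprox(lhs, rhs; atol=1e-10)   # true
```

Under that assumption, the imaginary part of R̄ cannot influence a real input, which matches the Float64 result from the broadcasted call.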

mcabbott commented 5 years ago

It looks like there’s an explicit gradient definition for real acting on numbers, or broadcasted, but not acting on arrays:

https://github.com/FluxML/Zygote.jl/blob/d86b07ba0039bada4bb34d2eeb4940e23744c75e/src/lib/number.jl#L59

https://github.com/FluxML/Zygote.jl/blob/ea4d1e894af775a6783f683d851bcfb5c10e25b0/src/lib/broadcast.jl#L88

sethaxen commented 5 years ago

This looks like it might be a broadcast issue. On Julia 1.2, real(x) just calls broadcast(real, x):

julia> @code_lowered real(A)
CodeInfo(
1 ─ %1 = Base.broadcast(Base.real, A)
└──      return %1
)

julia> @code_lowered Base.broadcast(Base.real, A)
CodeInfo(
1 ─ %1 = Core.tuple(f)
│   %2 = Core._apply(Base.Broadcast.broadcasted, %1, As)
│   %3 = Base.Broadcast.materialize(%2)
└──      return %3
)

So in principle the custom adjoints for broadcasted should be used. On the other hand, broadcasting with real.() gives you:

julia> g(x) = real.(x)
g (generic function with 1 method)

julia> @code_lowered g(A)
CodeInfo(
1 ─ %1 = Base.broadcasted(Main.real, x)
│   %2 = Base.materialize(%1)
└──      return %2
)

As far as I can tell, the only difference is that the first one uses _apply. I couldn't find much documentation on _apply, but it seems to be called because of the splatting in broadcast's default method.

julia> myfun(f, x) = f(x);

julia> myfun2(f, x...) = f(x...);

julia> @code_lowered myfun(identity, 3.0)
CodeInfo(
1 ─ %1 = (f)(x)
└──      return %1
)

julia> @code_lowered myfun2(identity, 3.0)
CodeInfo(
1 ─ %1 = Core._apply(f, x)
└──      return %1
)

Could it be that _apply is somehow circumventing the custom adjoints for broadcasted?

MikeInnes commented 5 years ago

Thanks for digging into this. It's an interesting case. We end up calling

https://github.com/JuliaLang/julia/blob/9d98131c4fd460bee1e79b14a2598b7d8d8289d8/base/abstractarraymath.jl#L94

which means real is the identity (and gets the appropriate adjoint).
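
The dispatch described above can be checked without Zygote. A minimal sketch using only Base, assuming a current Julia version: real applied to an array of reals returns the very same array, while real.(A) builds a new one through the broadcast machinery, so the two spellings take different code paths even though the values agree.

```julia
A = randn(4, 4)

# For an array of reals, `real` is the identity: the same object comes back.
real(A) === A        # true

# Broadcasting goes through `broadcasted`/`materialize` and allocates a new array.
B = real.(A)
B == A               # true: same values
B === A              # false: different array

# A complex matrix dispatches to a different `real` method.
m_real    = which(real, (Matrix{Float64},))
m_complex = which(real, (Matrix{ComplexF64},))
m_real == m_complex  # false
```

Because the real-array method never reaches broadcasted, any adjoint defined only for the broadcast path is not involved when real(A) is called directly.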

For now we should just fix this with a custom adjoint. It's concerning that code like this can lead to bad gradients, though, since presumably there are other cases where people might specialise on the reals rather than writing code generically across all numbers.
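
As a rough illustration of what such a custom adjoint has to do (a sketch, not the fix that was actually merged), the pullback for real on an array can be written by hand as a forward value plus a closure that drops the imaginary part of the incoming cotangent. `real_with_pullback` is a hypothetical helper name, and the sketch has the real-input case from this issue in mind; in Zygote the same pair would be registered through its `@adjoint` macro rather than called directly.

```julia
# Hypothetical helper: returns real(A) together with a pullback closure.
function real_with_pullback(A::AbstractArray)
    y = real(A)
    # Keep only the real part of the cotangent, matching the broadcasted case.
    back(R̄) = (real.(R̄),)
    return y, back
end

A = randn(4, 4)
R̄ = randn(ComplexF64, 4, 4)

y, back = real_with_pullback(A)
Ā = back(R̄)[1]

eltype(Ā)   # Float64
```

With this pair, a complex cotangent pulled back through real on a real matrix yields a Float64 matrix, the same element type the broadcasted call produced in the report.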