FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

gradient() fails on array mutation for `mean(f, x; dims)` #1128

Closed: staticfloat closed this issue 1 year ago

staticfloat commented 2 years ago

If you provide both an element-wise function f and a dims keyword, mean() apparently mutates an array internally, which breaks Zygote's ability to differentiate it:

julia> using Zygote, Statistics
       x = randn(3, 3)
       Zygote.gradient(Params([x])) do
           sum(mean(abs2, x, dims=1))
       end
ERROR: Mutating arrays is not supported -- called copyto!(::Matrix{Float64}, _...)
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.var"#441#442"{Matrix{Float64}})(#unused#::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/lib/array.jl:74
  [3] (::Zygote.var"#2330#back#443"{Zygote.var"#441#442"{Matrix{Float64}}})(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
  [4] Pullback
    @ ./broadcast.jl:894 [inlined]
  [5] Pullback
    @ ./broadcast.jl:891 [inlined]
  [6] Pullback
    @ ./broadcast.jl:887 [inlined]
  [7] (::typeof(∂(materialize!)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
  [8] Pullback
    @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/Statistics/src/Statistics.jl:181 [inlined]
  [9] (::typeof(∂(_mean)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [10] Pullback
    @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/Statistics/src/Statistics.jl:104 [inlined]
 [11] (::typeof(∂(#mean#1)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [12] Pullback
    @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/Statistics/src/Statistics.jl:104 [inlined]
 [13] (::typeof(∂(mean##kw)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [14] Pullback
    @ ./REPL[14]:4 [inlined]
 [15] (::typeof(∂(#17)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [16] (::Zygote.var"#89#90"{Params, typeof(∂(#17)), Zygote.Context})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface.jl:356
 [17] gradient(f::Function, args::Params)
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface.jl:76
 [18] top-level scope
    @ REPL[14]:3

Looking through the adjoints for mean() defined in lib/array.jl, my guess is that passing abs2 as f means none of Zygote's own mean adjoints apply, so Zygote differentiates through Statistics' implementation instead. With the dims keyword, that path goes through an in-place broadcast (materialize!, then copyto!, in the stacktrace above), which is array mutation. I was going to submit a PR adding a new @adjoint definition for the method that takes f, but I don't know how to get the adjoint of a user-provided function.

mzgubic commented 2 years ago

> but I don't know how to get the adjoint of a user-provided function

You could use Zygote.pullback to AD through it; that will pick up an adjoint for f if one exists. A PR to ChainRules would also be well received; see https://github.com/JuliaDiff/ChainRules.jl/issues/85
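A minimal sketch of the pullback suggestion, assuming only that `Zygote.pullback(f, xi)` returns a value and a pullback for a scalar argument. `mean_with_pullback` is a hypothetical helper, not Zygote or ChainRules API: it returns the value of `mean(f, x; dims)` and a pullback with respect to `x` only, and it does not show how to register this as an `@adjoint` or `rrule`. Calling `pullback` once per element is simple but slow; a real rule would likely want something more efficient.

```julia
using Zygote, Statistics

# Hypothetical helper (not Zygote or ChainRules API): value of mean(f, x; dims)
# plus a pullback with respect to x, built from Zygote.pullback on each element.
function mean_with_pullback(f, x::AbstractArray; dims = :)
    # One (value, pullback) pair per element; Zygote uses a defined rule for f
    # when there is one, and otherwise differentiates through f's code.
    pairs = map(xi -> Zygote.pullback(f, xi), x)
    fx = map(first, pairs)               # f applied element-wise
    y = mean(fx; dims = dims)
    n = length(x) ÷ length(y)            # elements averaged into each output
    function back(Δ)
        # Spread each output cotangent evenly over the elements it averaged...
        Δfx = zero.(fx) .+ Δ ./ n
        # ...then push it through each element's pullback for f.
        return map((p, δ) -> only(last(p)(δ)), pairs, Δfx)
    end
    return y, back
end

x = randn(3, 3)
y, back = mean_with_pullback(abs2, x; dims = 1)
Δx = back(ones(size(y)))   # cotangent of sum(y) is all ones
```

For `abs2` with `dims = 1` on a 3×3 matrix, each column mean is a third of the column's sum of squares, so `Δx` should equal `2 .* x ./ 3`.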

mcabbott commented 2 years ago

The low-tech way to implement this is to turn it into broadcasting, as is currently done for sum(::Function, ::CuArray) here:

https://github.com/FluxML/Zygote.jl/blob/master/src/lib/broadcast.jl#L280-L283
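The same broadcasting idea can be applied by hand at the call site, without defining any adjoint. This is a minimal sketch, assuming a Zygote version that can differentiate both broadcasting and the plain `mean(x; dims)` without tracing into Statistics' in-place code: `mean(abs2.(x); dims = 1)` computes the same values as `mean(abs2, x; dims = 1)`, at the cost of allocating the intermediate array `abs2.(x)`. It uses the explicit-argument form of `Zygote.gradient` rather than `Params`.

```julia
using Zygote, Statistics

x = randn(3, 3)

# Broadcast f first, then take the plain mean over dims. Same values as
# mean(abs2, x; dims = 1), but with an extra intermediate array abs2.(x).
g = only(Zygote.gradient(x -> sum(mean(abs2.(x); dims = 1)), x))

# Each column mean is (1/3) * the column's sum of squares, so g should be 2x / 3.
```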

oxinabox commented 2 years ago

Reopened, as I had to revert the fix.