FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

fix broadcasting into buffers #1488

Closed lxvm closed 1 week ago

lxvm commented 9 months ago

Hi,

I'm using Zygote as an AD backend in Integrals.jl, and while writing tests I noticed that I couldn't assign a number to a length-1 Buffer using broadcasting. I think this is because the method signature for the pullback on materialize! is too restrictive, since copyto! allows arbitrary iterators on the rhs of buf .= itr. I also added a test based on an MWE.

Update: A more complete MWE brings up a second issue:

MWE 2

```julia
julia> using Zygote

julia> gradient(1) do p
           b = Zygote.Buffer([1,2,3])
           b .= p
           return sum(copy(b))
       end
ERROR: MethodError: no method matching (::ChainRulesCore.ProjectTo{Float64, @NamedTuple{}})(::Vector{Float64})

Closest candidates are:
  (::ChainRulesCore.ProjectTo{T})(::ChainRulesCore.NotImplemented) where T
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/zoCjl/src/projection.jl:121
  (::ChainRulesCore.ProjectTo{<:Number})(::ChainRulesCore.Tangent{<:Complex})
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/zoCjl/src/projection.jl:192
  (::ChainRulesCore.ProjectTo{<:Number})(::ChainRulesCore.Tangent{<:Number})
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/zoCjl/src/projection.jl:193
  ...

Stacktrace:
 [1] _project
   @ Zygote ~/.julia/dev/Zygote/src/compiler/chainrules.jl:189 [inlined]
 [2] map(f::typeof(Zygote._project), t::Tuple{Int64}, s::Tuple{Vector{Float64}})
   @ Base ./tuple.jl:318
 [3] gradient(::Function, ::Int64, ::Vararg{Int64})
   @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:98
 [4] top-level scope
   @ REPL[11]:1
```

Update 2: I added an adjoint for copyto! that fixes the MWE, however I'll try to add a generic adjoint for copyto!(buffer, itr) next

Update 3: I went about this the wrong way. The manual has details here on how to bypass the broadcasting machinery, so the adjoint for Base.materialize! has to be discarded and replaced by an adjoint for copyto!(buffer, broadcasted)

Update 4: I finished writing an adjoint and added a test for broadcasted assignment to a buffer from a generator. I'll happily incorporate any feedback and improvements.

lxvm commented 9 months ago

Fixes #254

lxvm commented 8 months ago

Thank you for the review! I will address the first two points soon, and for the last one I don't seem to be able to reproduce any error, so I'll remove it. Perhaps it came up due to inadvertently breaking something while debugging.

I also wanted to ask, since I removed @adjoint! copyto!(x::Buffer, y::Any) if we should restore a generic method, and if that should use indexing or iteration to get the elements of y?

lxvm commented 8 months ago

I also blindly removed this line

grad .= eltype(grad) <: Number ? 0 : nothing

Should I restore it?

ToucheSir commented 8 months ago

I also wanted to ask, since I removed @adjoint! copyto!(x::Buffer, y::Any) if we should restore a generic method, and if that should use indexing or iteration to get the elements of y?

That would be best. What worries me more is that removing that copyto! rule didn't lead to any failures. We probably need to add in tests to fill that gap. As for indexing vs iteration, the rule probably needs to rely on iteration because the copyto! overload for Buffer is so generic.
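
To make the indexing vs iteration distinction concrete, here is a minimal sketch in plain Base Julia (no Zygote involved). The helper name `fill_by_iteration!` is hypothetical and only illustrates the idea; it is not the rule from this PR. A generator can be iterated but has no `getindex`, so a rule that wants to accept any `src` that `copyto!` accepts cannot assume indexing.

```julia
# An array source supports both indexing and iteration;
# a generator source supports iteration only.
src_array = [2, 4, 6]
src_gen = (2i for i in 1:3)

dest1 = zeros(Int, 3)
dest2 = zeros(Int, 3)

# Base.copyto! accepts either kind of source.
copyto!(dest1, src_array)
copyto!(dest2, src_gen)

# Hypothetical helper: fill `dest` using only the iteration protocol,
# so it works for arrays and generators alike.
function fill_by_iteration!(dest, itr)
    for (i, x) in zip(eachindex(dest), itr)
        dest[i] = x
    end
    return dest
end

dest3 = fill_by_iteration!(zeros(Int, 3), (2i for i in 1:3))
```

All three destinations end up with the same contents, which is why an iteration-based rule is the more general choice here.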

lxvm commented 8 months ago

As for indexing vs iteration, the rule probably needs to rely on iteration

Zygote doesn't yet have an adjoint for collect(itr), so instead I've just reinstated the previous adjoint for copyto!(buffer, array). Presumably very few people write f!(out, args...) in a situation where a buffer would be substituted for out, although this is what I encountered in Integrals.jl. As for broadcasting, before this PR there was an adjoint for Base.materialize!(buffer, array) that called the one for copyto!.

I've gone back and added some broken tests to point out where adjoints are missing for collect(itr) and copyto!(buffer, itr)

Having done a fair amount of work on this, I'm not happy with the implementation yet. I think a lot of shortcomings of missing adjoints could be fixed by the following approach:

Would this approach be sound? Any ideas?

ToucheSir commented 8 months ago

I would prefer to even ditch the trait and just check for a set of known good types like the ones you listed to determine if x can be indexed. Otherwise the overall plan sounds reasonable.

lxvm commented 8 months ago

I actually don't have time to work on the improvements to this PR that I suggested, but in order to wrap up the changes I made, there are two points to address:

  1. the copyto!(buffer, src) adjoint had this line grad .= eltype(grad) <: Number ? 0 : nothing and I removed it but no tests failed. What does this do exactly?
  2. The copyto! adjoint w.r.t. buffer is always nothing, but I suppose this is the intended behavior of the buffer? I would have tried the following but it seems prohibited:
```julia
using Zygote
Zygote.gradient(collect(1:10)) do x
    b = Zygote.Buffer(x)
    tmp1 = sum(copy(b))
    copyto!(b, fill(30))
    tmp2 = sum(copy(b))
    copyto!(b, [2i for i in 1:5])
    tmp3 = sum(copy(b))
    return tmp1 + tmp2 + tmp3
end # ERROR: Buffer is frozen
```

ToucheSir commented 8 months ago

  • the copyto!(buffer, src) adjoint had this line grad .= eltype(grad) <: Number ? 0 : nothing and I removed it but no tests failed. What does this do exactly?

I think it's for the case where you have a buffer of non-numbers. Examples would be a buffer of differentiable structs, or a buffer of arrays. That neither of these cases was tested is bad, but also not uncommon for Zygote (which historically has poor test coverage in general).
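
As a rough illustration of what the quoted line does on its own, here is a sketch in plain Base Julia. The wrapper name `reset_grad!` is hypothetical and the arrays are made-up stand-ins for a gradient cache; this is not Zygote's actual adjoint, only the reset expression applied to two kinds of element types.

```julia
# Hypothetical wrapper around the line quoted above.
function reset_grad!(grad)
    # Numeric element types are zeroed; anything else is reset to `nothing`,
    # which requires the element type to admit `Nothing`.
    grad .= eltype(grad) <: Number ? 0 : nothing
    return grad
end

# Stand-in for the gradient of a buffer of numbers.
numeric_grad = reset_grad!([1.0, 2.0])

# Stand-in for the gradient of a buffer of arrays.
nested_grad = reset_grad!(Union{Nothing, Vector{Float64}}[[1.0], [2.0]])
```

For the numeric case every entry becomes zero, and for the nested case every slot becomes `nothing`, since `0` would not be a valid gradient for an array-valued or struct-valued element.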

  • The copyto! adjoint w.r.t. buffer is always nothing, but I suppose this is the intended behavior of the buffer? I would have tried the following but it seems prohibited: ...

My understanding is that differentiable arguments which are mutated should not have gradients returned for correctness reasons. Instead, a copy is kept in the mutable gradient cache managed by grad_mut and returned at either the top level or the point at which the mutable value was constructed.

lxvm commented 7 months ago

Thanks! I added back the zeroing out of grad and kept the nothing gradient of the buffer. I hope this is enough to complete the pr.

lxvm commented 7 months ago

I added tests for Iterators.take adjoints and rebased on the main branch

lxvm commented 7 months ago

I said that this pr would fix #254 and I just wanted to say that its MWE is slightly broken

```julia
using Zygote
f = (du, u, p, t) -> du .= 0
(y, p, λ, t) = ([1.02634, 0.909691], [1.5, 1.0, 3.0, 1.0], [0.973655, 1.09031], 10.0)
_dy, back = Zygote.pullback(y) do u
  out_ = Zygote.Buffer(u)
  f(out_, u, p, t)
  copy(out_)
end
dλ[:] = vec(back(λ)[1]) # ERROR: MethodError: no method matching vec(::Nothing)
```

The good news is that the gradient through the overwritten buffer is nothing, but the MWE wasn't written to handle that.