Closed lxvm closed 1 week ago
Fixes #254
Thank you for the review! I will address the first two points soon, and for the last one I don't seem to be able to reproduce any error, so I'll remove it. Perhaps it came up due to inadvertently breaking something while debugging.
I also wanted to ask, since I removed `@adjoint! copyto!(x::Buffer, y::Any)`, if we should restore a generic method, and if that should use indexing or iteration to get the elements of `y`?
I also blindly removed this line
grad .= eltype(grad) <: Number ? 0 : nothing
Should I restore it?
> I also wanted to ask, since I removed `@adjoint! copyto!(x::Buffer, y::Any)`, if we should restore a generic method, and if that should use indexing or iteration to get the elements of `y`?
That would be best. What worries me more is that removing that `copyto!` rule didn't lead to any failures. We probably need to add tests to fill that gap. As for indexing vs iteration, the rule probably needs to rely on iteration because the `copyto!` overload for `Buffer` is so generic.
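To illustrate why iteration is the safer assumption, here is a minimal sketch in plain Julia (no Zygote involved; the variable names are illustrative only). It shows a source that can be iterated but not indexed, which the generic `copyto!(dest, src)` method in Base still accepts.

```julia
# A generator supports iteration but has no getindex method.
itr = (2i for i in 1:5)
has_getindex = hasmethod(getindex, Tuple{typeof(itr), Int})

# The generic copyto!(dest, src) method in Base walks `src` with iterate,
# so it can still fill the destination array.
dest = zeros(Int, 5)
copyto!(dest, itr)
```

A rule written in terms of `y[i]` would reject this kind of source, while a rule written in terms of iteration covers arrays and generators alike.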
> As for indexing vs iteration, the rule probably needs to rely on iteration
Zygote doesn't yet have an adjoint for `collect(itr)`, so instead I've just reinstated the previous adjoint for `copyto!(buffer, array)`. Presumably very few people write `f!(out, args...)` in a situation where a buffer would be substituted for `out`, although this is what I encountered in Integrals.jl. As for broadcasting, before this PR there was an adjoint for `Base.materialize!(buffer, array)` that called the one for `copyto!`.
I've gone back and added some broken tests to point out where adjoints are missing for `collect(itr)` and `copyto!(buffer, itr)`.
Having done a fair amount of work on this, I'm not happy with the implementation yet. I think a lot of the shortcomings caused by missing adjoints could be fixed by the following approach:
- Define `copyto!(buffer, x)` in terms of `copyto!(buffer, collect(x))`.
- `collect(x)` should use a trait `IsIndexable(x)` to determine whether it can use `∇map` to compute the adjoint. (Types like `AbstractArray`, `Number`, and `Broadcasted` that define `getindex` would be indexable.)
- If `IsIndexable(x)` is false, then assume `x` is iterable and use a (currently non-existent) `collect` pullback for generic iterators. (E.g. types from `Base.Iterators` and `Generator`s would be iterable.)

Would this approach be sound? Any ideas?
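The dispatch part of this proposal could be sketched as follows. This is only an illustration of the trait pattern: `Indexable`, `IterableOnly`, `collectstyle`, and `collect_rule` are hypothetical names that do not exist in Zygote, and the returned symbols stand in for the two pullback implementations.

```julia
# Hypothetical trait sketch; none of these names exist in Zygote.
struct Indexable end
struct IterableOnly end

# Types that define getindex are marked as indexable.
collectstyle(::AbstractArray) = Indexable()
collectstyle(::Number) = Indexable()
collectstyle(::Base.Broadcast.Broadcasted) = Indexable()
# Everything else is assumed to support only iteration.
collectstyle(::Any) = IterableOnly()

# Pick the differentiation strategy from the trait value.
collect_rule(x) = collect_rule(collectstyle(x), x)
collect_rule(::Indexable, x) = :map_based           # stand-in for a ∇map-based pullback
collect_rule(::IterableOnly, x) = :generic_iterator # stand-in for the missing generic pullback
```

With these definitions, `collect_rule([1, 2, 3])` selects the map-based branch and `collect_rule(Iterators.take(1:10, 3))` selects the generic iterator branch.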
I would prefer to even ditch the trait and just check for a set of known good types like the ones you listed to determine if `x` can be indexed. Otherwise the overall plan sounds reasonable.
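A check against a fixed set of types might look like the sketch below. `KnownIndexable` and `can_index` are hypothetical names, and the union only contains the types mentioned in this thread.

```julia
# Hypothetical allow-list of types known to support getindex.
const KnownIndexable = Union{AbstractArray, Number, Base.Broadcast.Broadcasted}

# Anything outside the allow-list is treated as a plain iterator.
can_index(x) = x isa KnownIndexable
```

This trades extensibility for predictability: a new indexable type has to be added to the union explicitly.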
I actually don't have time to work on the improvements to this PR that I suggested, but in order to wrap up the changes I made there are two points to address:
- The `copyto!(buffer, src)` adjoint had this line `grad .= eltype(grad) <: Number ? 0 : nothing` and I removed it but no tests failed. What does this do exactly?
- The `copyto!` adjoint w.r.t. buffer is always nothing, but I suppose this is the intended behavior of the buffer? I would have tried the following but it seems prohibited:

```julia
using Zygote
Zygote.gradient(collect(1:10)) do x
    b = Zygote.Buffer(x)
    tmp1 = sum(copy(b))
    copyto!(b, fill(30))
    tmp2 = sum(copy(b))
    copyto!(b, [2i for i in 1:5])
    tmp3 = sum(copy(b))
    return tmp1 + tmp2 + tmp3
end # ERROR: Buffer is frozen
```
> - The `copyto!(buffer, src)` adjoint had this line `grad .= eltype(grad) <: Number ? 0 : nothing` and I removed it but no tests failed. What does this do exactly?
I think it's for the case where you have a buffer of non-numbers. Examples would be a buffer of differentiable structs, or a buffer of arrays. That neither of these cases was tested is bad, but also not uncommon for Zygote (which historically has poor test coverage in general).
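To make the two branches of that line concrete, here is a small sketch in plain Julia. The wrapper `reset_grad!` is a hypothetical name used only for illustration, and the `Any`-typed vector is an assumption about what a gradient cache for non-numeric elements could look like.

```julia
# Hypothetical wrapper around the line under discussion.
function reset_grad!(grad)
    # Numeric caches are zeroed; anything else is filled with `nothing`.
    grad .= eltype(grad) <: Number ? 0 : nothing
    return grad
end

numeric = reset_grad!([1.0, 2.0])           # element type Float64
generic = reset_grad!(Any[[1.0], [2.0]])    # element type Any, e.g. a buffer of arrays
```

Note that the non-numeric branch only works when the element type of `grad` can hold `nothing`, which is why the second call uses an `Any` vector.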
> - The `copyto!` adjoint w.r.t. buffer is always nothing, but I suppose this is the intended behavior of the buffer? I would have tried the following but it seems prohibited: ...
My understanding is that differentiable arguments which are mutated should not have gradients returned for correctness reasons. Instead, a copy is kept in the mutable gradient cache managed by `grad_mut` and returned at either the top level or the point at which the mutable value was constructed.
Thanks! I added back the zeroing out of `grad` and kept the `nothing` gradient of the buffer. I hope this is enough to complete the PR.
I added tests for `Iterators.take` adjoints and rebased on the main branch.
I said that this PR would fix #254 and I just wanted to say that its MWE is slightly broken:

```julia
using Zygote
f = (du, u, p, t) -> du .= 0
(y, p, λ, t) = ([1.02634, 0.909691], [1.5, 1.0, 3.0, 1.0], [0.973655, 1.09031], 10.0)
_dy, back = Zygote.pullback(y) do u
    out_ = Zygote.Buffer(u)
    f(out_, u, p, t)
    copy(out_)
end
dλ[:] = vec(back(λ)[1]) # ERROR: MethodError: no method matching vec(::Nothing)
```

The good news is that the gradient through the overwritten buffer is `nothing`, but the MWE wasn't written to handle that.
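One way the MWE could guard against a `nothing` gradient is sketched below. `dense_gradient` is a hypothetical helper, not part of Zygote, and the literal `nothing` stands in for the value of `back(λ)[1]` so that the snippet runs without Zygote.

```julia
# Hypothetical helper: replace a `nothing` gradient with zeros shaped like `like`.
dense_gradient(g, like) = g === nothing ? zero(like) : g

λ = [0.973655, 1.09031]
dλ = similar(λ)

# `nothing` stands in for the result of `back(λ)[1]` in the MWE.
dλ[:] = vec(dense_gradient(nothing, λ))
```

Whether zeros are the right substitute depends on how the caller interprets a missing gradient, so this is a convention rather than a fix.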
Hi,
I'm using Zygote as an AD backend in Integrals.jl and while I was writing tests I noticed I couldn't assign a number to a length-1 `Buffer` using broadcasting. I think this is because the method signature for the pullback on `materialize!` is too restrictive, since `copyto!` allows for arbitrary iterators on the rhs of `buf .= itr`. I also added a test for a MWE.

PR Checklist
Update: A more complete MWE brings up a second issue:

MWE 2

```julia
julia> using Zygote

julia> gradient(1) do p
           b = Zygote.Buffer([1,2,3])
           b .= p
           return sum(copy(b))
       end
ERROR: MethodError: no method matching (::ChainRulesCore.ProjectTo{Float64, @NamedTuple{}})(::Vector{Float64})

Closest candidates are:
  (::ChainRulesCore.ProjectTo{T})(::ChainRulesCore.NotImplemented) where T
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/zoCjl/src/projection.jl:121
  (::ChainRulesCore.ProjectTo{<:Number})(::ChainRulesCore.Tangent{<:Complex})
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/zoCjl/src/projection.jl:192
  (::ChainRulesCore.ProjectTo{<:Number})(::ChainRulesCore.Tangent{<:Number})
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/zoCjl/src/projection.jl:193
  ...

Stacktrace:
 [1] _project
   @ Zygote ~/.julia/dev/Zygote/src/compiler/chainrules.jl:189 [inlined]
 [2] map(f::typeof(Zygote._project), t::Tuple{Int64}, s::Tuple{Vector{Float64}})
   @ Base ./tuple.jl:318
 [3] gradient(::Function, ::Int64, ::Vararg{Int64})
   @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:98
 [4] top-level scope
   @ REPL[11]:1
```

Update 2: I added an adjoint for `copyto!` that fixes the MWE, however I'll try to add a generic adjoint for `copyto!(buffer, itr)` next.

Update 3: I went about this the wrong way and the manual has details here on how to bypass broadcasting machinery, so the adjoint for `Base.materialize!` will have to be discarded and replaced by an adjoint for `copyto!(buffer, broadcasted)`.
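For context on why `copyto!` is the hook of interest, the sketch below spells out by hand what a dotted assignment on ordinary arrays does, using only Base functions. It does not involve `Zygote.Buffer`, and the variable names are illustrative.

```julia
dest = zeros(Int, 3)

# `dest .= [1, 2, 3] .+ 10` is lowered to roughly this pair of calls.
bc = Base.broadcasted(+, [1, 2, 3], 10)  # lazy Broadcasted object, nothing computed yet
Base.materialize!(dest, bc)              # for arrays this ends up in copyto!(dest, ::Broadcasted)

# Calling copyto! on the instantiated broadcast object gives the same result.
dest2 = zeros(Int, 3)
copyto!(dest2, Base.Broadcast.instantiate(bc))
```

Because the work happens in `copyto!(dest, bc)`, a destination type that overloads that method can take part in `.=` without touching `materialize!`.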
Update 4: I finished writing an adjoint and added a test for broadcasted assignment to a buffer from a generator. I'll happily incorporate any feedback and improvements.