sethaxen opened this issue 5 years ago
I think it might make sense to make Buffers.jl a separate package. If you're willing to put it together, we can move it into FluxML, tag it, etc.
Sounds good! Happy to do it.
Here's an initial repo with a failing test: https://github.com/sethaxen/Buffers.jl. Buffer's adjoints require Zygote.grad_mut, which depends on Zygote.Context and Zygote.cache, not just ZygoteRules.AContext. I don't really understand how grad_mut works. Any advice on how to replace it?
grad_mut is just a convenience that checks for a cached gradient for the given object, and creates one if it needs to. You could just copy-paste whatever definitions you need to make it work.
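A minimal, dependency-free sketch of the get-or-create behaviour described above, assuming a context object that lazily owns an `IdDict` cache keyed by object identity. `SketchContext`, `sketch_cache`, `new_grad` and `sketch_grad_mut` are hypothetical names, not Zygote's API, and the real `Zygote.grad_mut` may differ in detail.

```julia
# Hypothetical stand-in for an AD context; NOT Zygote.Context.
mutable struct SketchContext
    cache::Union{Nothing,IdDict{Any,Any}}   # created lazily on first use
end
SketchContext() = SketchContext(nothing)

# Return the context's gradient cache, creating it the first time it is needed.
function sketch_cache(cx::SketchContext)
    if cx.cache === nothing
        cx.cache = IdDict{Any,Any}()
    end
    return cx.cache
end

# Fresh gradient accumulator for a mutable array: one slot per element,
# each initialised to `nothing` ("no gradient accumulated yet").
new_grad(x::AbstractArray) = fill!(similar(x, Any), nothing)

# Get-or-create, keyed by object identity: return the accumulator already
# cached for `x`, or build one, store it and return it.
sketch_grad_mut(cx::SketchContext, x) = get!(() -> new_grad(x), sketch_cache(cx), x)

# Usage
cx = SketchContext()
buf = zeros(3)
g = sketch_grad_mut(cx, buf)      # first call creates [nothing, nothing, nothing]
g[2] = 1.0                        # accumulate into the shared slot
sketch_grad_mut(cx, buf) === g    # true: later calls see the same accumulator
```

Keying on identity rather than value is the point: two distinct arrays with equal contents must not share a gradient accumulator.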
Sorry it took a while to get back to this. I've finished porting the Buffer-related code and tests to https://github.com/sethaxen/Buffers.jl, and the tests pass. However, after loading Buffers, Buffer seems to get pulled into adjoints that have nothing to do with it. For example, with these functions from Zygote's tests:
function pow_mut(x, n)
    r = Ref(one(x))
    while n > 0
        n -= 1
        r[] = r[] * x
    end
    return r[]
end

struct Foo{T}
    a::T
    b::T
end

function mul_struct(a, b)
    c = Foo(a, b)
    c.a * c.b
end

kwmul(; a = 1, b) = a*b
mul_kw(a, b) = kwmul(a = a, b = b)
julia> using Zygote
julia> gradient(pow_mut, 2, 3)
(nothing, 12)
julia> gradient(mul_struct, 2, 3)
(3, 2)
julia> gradient(mul_kw, 2, 3)
(3, 2)
julia> using Buffers
julia> gradient(pow_mut, 2, 3)
ERROR: MethodError: no method matching Buffer(::Int64)
Closest candidates are:
Buffer(::A, ::Bool) where {T, A<:(AbstractArray{T,N} where N)} at /Users/saxen/projects/Buffers.jl/src/buffer.jl:38
Buffer(::AbstractArray, ::Any...) at /Users/saxen/projects/Buffers.jl/src/buffer.jl:42
Stacktrace:
[1] _pullback(::Zygote.Context, ::UnionAll, ::Int64) at /Users/saxen/projects/Buffers.jl/src/buffer.jl:91
[2] pow_mut at ./REPL[1]:2 [inlined]
[3] _pullback(::Zygote.Context, ::typeof(pow_mut), ::Int64, ::Int64) at /Users/saxen/.julia/packages/Zygote/4tJp5/src/compiler/interface2.jl:0
[4] _pullback(::Function, ::Int64, ::Int64) at /Users/saxen/.julia/packages/Zygote/4tJp5/src/compiler/interface.jl:29
[5] pullback(::Function, ::Int64, ::Int64) at /Users/saxen/.julia/packages/Zygote/4tJp5/src/compiler/interface.jl:35
[6] gradient(::Function, ::Int64, ::Vararg{Int64,N} where N) at /Users/saxen/.julia/packages/Zygote/4tJp5/src/compiler/interface.jl:44
[7] top-level scope at REPL[11]:1
julia> gradient(mul_struct, 2, 3)
ERROR: MethodError: no method matching Buffer(::Int64, ::Int64)
Closest candidates are:
Buffer(::AbstractArray, ::Any...) at /Users/saxen/projects/Buffers.jl/src/buffer.jl:42
Stacktrace:
[1] _pullback(::Zygote.Context, ::UnionAll, ::Int64, ::Int64) at /Users/saxen/projects/Buffers.jl/src/buffer.jl:91
[2] mul_struct at ./REPL[3]:2 [inlined]
[3] _pullback(::Zygote.Context, ::typeof(mul_struct), ::Int64, ::Int64) at /Users/saxen/.julia/packages/Zygote/4tJp5/src/compiler/interface2.jl:0
[4] _pullback(::Function, ::Int64, ::Int64) at /Users/saxen/.julia/packages/Zygote/4tJp5/src/compiler/interface.jl:29
[5] pullback(::Function, ::Int64, ::Int64) at /Users/saxen/.julia/packages/Zygote/4tJp5/src/compiler/interface.jl:35
[6] gradient(::Function, ::Int64, ::Vararg{Int64,N} where N) at /Users/saxen/.julia/packages/Zygote/4tJp5/src/compiler/interface.jl:44
[7] top-level scope at REPL[12]:1
julia> gradient(mul_kw, 2, 3)
ERROR: MethodError: no method matching Buffer(::Tuple{Int64,Int64})
Closest candidates are:
Buffer(::A, ::Bool) where {T, A<:(AbstractArray{T,N} where N)} at /Users/saxen/projects/Buffers.jl/src/buffer.jl:38
Buffer(::AbstractArray, ::Any...) at /Users/saxen/projects/Buffers.jl/src/buffer.jl:42
Stacktrace:
[1] _pullback(::Zygote.Context, ::UnionAll, ::Tuple{Int64,Int64}) at /Users/saxen/projects/Buffers.jl/src/buffer.jl:91
[2] mul_kw at ./REPL[5]:1 [inlined]
[3] _pullback(::Zygote.Context, ::typeof(mul_kw), ::Int64, ::Int64) at /Users/saxen/.julia/packages/Zygote/4tJp5/src/compiler/interface2.jl:0
[4] _pullback(::Function, ::Int64, ::Int64) at /Users/saxen/.julia/packages/Zygote/4tJp5/src/compiler/interface.jl:29
[5] pullback(::Function, ::Int64, ::Int64) at /Users/saxen/.julia/packages/Zygote/4tJp5/src/compiler/interface.jl:35
[6] gradient(::Function, ::Int64, ::Vararg{Int64,N} where N) at /Users/saxen/.julia/packages/Zygote/4tJp5/src/compiler/interface.jl:44
[7] top-level scope at REPL[13]:1
Do you have any idea what can be going wrong?
These methods are all type-piratical. Not sure if that's the direct cause but would be worth fixing.
Those are all created in Zygote, but this package only depends on ZygoteRules, so I don't see how they can be type-pirating Zygote. Besides, renaming them entirely still produces the same errors.
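A hedged guess at the mechanism, based only on the traces: all three errors come from `_pullback(::Zygote.Context, ::UnionAll, ...)` at buffer.jl:91, i.e. from a rule that dispatches on `UnionAll`. In Julia, the type of a parametric type such as `Buffer` is `UnionAll`, so a method keyed on `typeof(Buffer)` would match calls on every parametric type constructor (`Ref`, `Foo`, and presumably the `NamedTuple{(:a, :b)}` that lowering builds for the keyword call), whereas `Type{<:Buffer}` matches only `Buffer`. That would also explain why renaming changes nothing. The plain-Julia sketch below uses a hypothetical `MyBuffer` type and hypothetical `rule_broad`/`rule_narrow` functions as stand-ins for the adjoint; whether buffer.jl:91 actually ends up keyed on `typeof(Buffer)` is an assumption.

```julia
# Hypothetical parametric type standing in for Buffer (never instantiated here).
struct MyBuffer{T,A<:AbstractArray{T}}
    data::A
end

# The type of a parametric type is UnionAll (the same holds for Ref, Complex, ...).
typeof(MyBuffer) === UnionAll   # true

# Too broad: `typeof(MyBuffer)` evaluates to `UnionAll`, so this method
# catches calls on *every* parametric type, not just MyBuffer.
rule_broad(::typeof(MyBuffer), args...) = :buffer_rule
rule_broad(::Any, args...) = :generic

# Narrow: `Type{<:MyBuffer}` only matches MyBuffer and its instantiations.
rule_narrow(::Type{<:MyBuffer}, args...) = :buffer_rule
rule_narrow(::Any, args...) = :generic

rule_broad(Ref, 1)             # :buffer_rule  (wrongly captured)
rule_broad(Complex, 2, 3)      # :buffer_rule  (wrongly captured)
rule_narrow(Ref, 1)            # :generic
rule_narrow(MyBuffer, [1.0])   # :buffer_rule
```

Because the broad method is really a method on `UnionAll`, a type owned by Base, it is also a form of type piracy even though no Zygote type appears in the signature.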
Is Zygote.Buffer too hefty to move to ZygoteRules? I work with several packages that use mutation extensively, and it would be nice to add custom rules using Buffer without adding Zygote as a full dependency.
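For a sense of scale, here is a hedged, dependency-free sketch of only the primal side of a Buffer-like type. It is modeled on the two constructor signatures visible in the traces above (`Buffer(::AbstractArray, ::Any...)` and `Buffer(::A, ::Bool)`) and on the documented behaviour that a `Buffer` becomes immutable once `copy`ed; `SketchBuffer` is a hypothetical name and details may differ from Zygote's type. The adjoints, which need `grad_mut` and the context cache, are deliberately left out, since they are the part tied to Zygote internals.

```julia
# Hypothetical Buffer-like type: primal behaviour only, no adjoints.
mutable struct SketchBuffer{T,A<:AbstractArray{T}}
    data::A
    freeze::Bool   # set by `copy`; a frozen buffer rejects writes
    SketchBuffer(data::AbstractArray{T}, freeze::Bool) where {T} =
        new{T,typeof(data)}(data, freeze)
end

# `similar`-style constructor: SketchBuffer(xs, dims...) allocates like similar(xs, dims...).
SketchBuffer(xs::AbstractArray, args...) = SketchBuffer(similar(xs, args...), false)

Base.size(b::SketchBuffer) = size(b.data)
Base.getindex(b::SketchBuffer, i...) = b.data[i...]

function Base.setindex!(b::SketchBuffer, v, i...)
    b.freeze && error("SketchBuffer is frozen")
    b.data[i...] = v
    return b
end

# Freeze the buffer and hand back the underlying array without copying it.
function Base.copy(b::SketchBuffer)
    b.freeze = true
    return b.data
end

# Usage
xs = [1.0, 2.0, 3.0]
buf = SketchBuffer(xs, 3)   # allocates like similar(xs, 3)
for i in 1:3
    buf[i] = 2 * xs[i]
end
ys = copy(buf)              # [2.0, 4.0, 6.0]; any later buf[i] = ... throws
```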