FluxML / ZygoteRules.jl

Moving Buffer to ZygoteRules #6

Open sethaxen opened 5 years ago

sethaxen commented 5 years ago

Is Zygote.Buffer too hefty to move to ZygoteRules? I work with several packages that use mutation extensively, and it would be nice to add custom rules using Buffer without adding Zygote as a full dependency.

MikeInnes commented 4 years ago

I think it might make sense to make Buffers.jl a separate package. If you're willing to put it together, we can move it into FluxML, tag etc.

sethaxen commented 4 years ago

Sounds good! Happy to do it.

sethaxen commented 4 years ago

Here's an initial repo with a failing test: https://github.com/sethaxen/Buffers.jl. Buffer's adjoints require Zygote.grad_mut, which depends on Zygote.Context and Zygote.cache, not just ZygoteRules.AContext. I don't really understand how grad_mut works. Any advice on how to replace it?

MikeInnes commented 4 years ago

grad_mut is just a convenience that checks for a cached gradient for the given object, and creates one if it needs to. You could just copy-paste whatever definitions you need to make it work.
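The "look up a cached gradient, create one if missing" pattern described above can be sketched roughly as follows. This is only an illustration under stated assumptions, not Zygote's actual code: `SketchContext`, `new_grad`, and `cached_grad` are hypothetical names, and the accumulator is assumed to be a plain `Ref`.

```julia
# Hypothetical stand-in for a differentiation context; this is NOT Zygote.Context.
struct SketchContext
    cache::IdDict{Any,Any}   # keyed by object identity, not by value
end
SketchContext() = SketchContext(IdDict{Any,Any}())

# Assumed shape of a fresh gradient accumulator for a mutable object.
new_grad(x) = Ref{Any}(nothing)

# Return the cached accumulator for `x`, creating and storing one on first use.
cached_grad(cx::SketchContext, x) = get!(() -> new_grad(x), cx.cache, x)

r = Ref(1.0)
cx = SketchContext()
g1 = cached_grad(cx, r)   # first call creates the accumulator
g2 = cached_grad(cx, r)   # second call returns the same object
@assert g1 === g2
```

Keying on identity (an `IdDict`) matters for this pattern: two distinct mutable objects holding equal values each get their own accumulator.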

sethaxen commented 4 years ago

Sorry it took a while to get back to this. I've finished porting the Buffer-related code and tests to https://github.com/sethaxen/Buffers.jl, and the tests pass. However, after loading Buffers, Buffer seems to get pulled into adjoints that have nothing to do with it. For example, take these functions from Zygote's tests:

function pow_mut(x, n)
  r = Ref(one(x))
  while n > 0
    n -= 1
    r[] = r[] * x
  end
  return r[]
end

struct Foo{T}
  a::T
  b::T
end

function mul_struct(a, b)
  c = Foo(a, b)
  c.a * c.b
end

kwmul(; a = 1, b) = a*b

mul_kw(a, b) = kwmul(a = a, b = b)

julia> using Zygote

julia> gradient(pow_mut, 2, 3)
(nothing, 12)

julia> gradient(mul_struct, 2, 3)
(3, 2)

julia> gradient(mul_kw, 2, 3)
(3, 2)

julia> using Buffers

julia> gradient(pow_mut, 2, 3)
ERROR: MethodError: no method matching Buffer(::Int64)
Closest candidates are:
  Buffer(::A, ::Bool) where {T, A<:(AbstractArray{T,N} where N)} at /Users/saxen/projects/Buffers.jl/src/buffer.jl:38
  Buffer(::AbstractArray, ::Any...) at /Users/saxen/projects/Buffers.jl/src/buffer.jl:42
Stacktrace:
 [1] _pullback(::Zygote.Context, ::UnionAll, ::Int64) at /Users/saxen/projects/Buffers.jl/src/buffer.jl:91
 [2] pow_mut at ./REPL[1]:2 [inlined]
 [3] _pullback(::Zygote.Context, ::typeof(pow_mut), ::Int64, ::Int64) at /Users/saxen/.julia/packages/Zygote/4tJp5/src/compiler/interface2.jl:0
 [4] _pullback(::Function, ::Int64, ::Int64) at /Users/saxen/.julia/packages/Zygote/4tJp5/src/compiler/interface.jl:29
 [5] pullback(::Function, ::Int64, ::Int64) at /Users/saxen/.julia/packages/Zygote/4tJp5/src/compiler/interface.jl:35
 [6] gradient(::Function, ::Int64, ::Vararg{Int64,N} where N) at /Users/saxen/.julia/packages/Zygote/4tJp5/src/compiler/interface.jl:44
 [7] top-level scope at REPL[11]:1

julia> gradient(mul_struct, 2, 3)
ERROR: MethodError: no method matching Buffer(::Int64, ::Int64)
Closest candidates are:
  Buffer(::AbstractArray, ::Any...) at /Users/saxen/projects/Buffers.jl/src/buffer.jl:42
Stacktrace:
 [1] _pullback(::Zygote.Context, ::UnionAll, ::Int64, ::Int64) at /Users/saxen/projects/Buffers.jl/src/buffer.jl:91
 [2] mul_struct at ./REPL[3]:2 [inlined]
 [3] _pullback(::Zygote.Context, ::typeof(mul_struct), ::Int64, ::Int64) at /Users/saxen/.julia/packages/Zygote/4tJp5/src/compiler/interface2.jl:0
 [4] _pullback(::Function, ::Int64, ::Int64) at /Users/saxen/.julia/packages/Zygote/4tJp5/src/compiler/interface.jl:29
 [5] pullback(::Function, ::Int64, ::Int64) at /Users/saxen/.julia/packages/Zygote/4tJp5/src/compiler/interface.jl:35
 [6] gradient(::Function, ::Int64, ::Vararg{Int64,N} where N) at /Users/saxen/.julia/packages/Zygote/4tJp5/src/compiler/interface.jl:44
 [7] top-level scope at REPL[12]:1

julia> gradient(mul_kw, 2, 3)
ERROR: MethodError: no method matching Buffer(::Tuple{Int64,Int64})
Closest candidates are:
  Buffer(::A, ::Bool) where {T, A<:(AbstractArray{T,N} where N)} at /Users/saxen/projects/Buffers.jl/src/buffer.jl:38
  Buffer(::AbstractArray, ::Any...) at /Users/saxen/projects/Buffers.jl/src/buffer.jl:42
Stacktrace:
 [1] _pullback(::Zygote.Context, ::UnionAll, ::Tuple{Int64,Int64}) at /Users/saxen/projects/Buffers.jl/src/buffer.jl:91
 [2] mul_kw at ./REPL[5]:1 [inlined]
 [3] _pullback(::Zygote.Context, ::typeof(mul_kw), ::Int64, ::Int64) at /Users/saxen/.julia/packages/Zygote/4tJp5/src/compiler/interface2.jl:0
 [4] _pullback(::Function, ::Int64, ::Int64) at /Users/saxen/.julia/packages/Zygote/4tJp5/src/compiler/interface.jl:29
 [5] pullback(::Function, ::Int64, ::Int64) at /Users/saxen/.julia/packages/Zygote/4tJp5/src/compiler/interface.jl:35
 [6] gradient(::Function, ::Int64, ::Vararg{Int64,N} where N) at /Users/saxen/.julia/packages/Zygote/4tJp5/src/compiler/interface.jl:44
 [7] top-level scope at REPL[13]:1

Do you have any idea what could be going wrong?

MikeInnes commented 4 years ago

These methods are all type-piratical. Not sure if that's the direct cause but would be worth fixing.
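For context, type piracy in Julia means adding a method to a function you do not own for argument types you do not own. A minimal sketch, using hypothetical modules `Owner` and `Pirate` that are not part of Zygote or Buffers.jl:

```julia
module Owner
describe(x) = "generic"
end

module Pirate
import ..Owner

# Piracy: Pirate owns neither the function `Owner.describe` nor the type `Int`,
# so this method changes behavior for every caller of `Owner.describe`.
Owner.describe(x::Int) = "hijacked"

# Not piracy: Pirate owns `Mine`, so this method cannot affect unrelated code.
struct Mine end
Owner.describe(x::Mine) = "mine"
end

println(Owner.describe(1))     # prints "hijacked"
println(Owner.describe(2.0))   # prints "generic"
```

Merely loading `Pirate` changes what `Owner.describe(1)` returns. That is the kind of action at a distance the comment above warns about, although the sketch says nothing about whether it is the cause of the errors in this thread.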

sethaxen commented 4 years ago

> These methods are all type-piratical. Not sure if that's the direct cause but would be worth fixing.

Those are all created in Zygote, but this package only depends on ZygoteRules, so I don't see how they can be type-pirating Zygote. Besides, changing the names completely produces the same errors.