JuliaManifolds / Manifolds.jl

Manifolds.jl provides a library of manifolds aiming for an easy-to-use and fast implementation.
https://juliamanifolds.github.io/Manifolds.jl
MIT License

Adding Zygote compatibility #42

Closed: sethaxen closed this issue 3 years ago

sethaxen commented 5 years ago

It would be nice to add compatibility for Zygote. I know I personally need Zygote for its complex number support, and I'd also like to use Manifolds in the same code.

For the most part, we can probably expect Zygote to just work, with one major blocker: Zygote doesn't support array mutation. I don't think this will be a big problem, though. Zygote offers Zygote.Buffer, which behaves just like similar and allows mutation. All we should need to do is add something like this:

_similar(args...) = similar(args...)

function similar_result(M::Manifold, f, x...)
    T = similar_result_type(M, f, x)
    return _similar(x[1], T)
end

finalize_result(M::Manifold, f, y, x...) = y

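# Requires.jl expects this block inside the package's __init__ function.
@require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin
    using Zygote
    # Under Zygote, allocate a Buffer instead of a plain array; the pullback returns no gradient for the arguments.
    Zygote.@adjoint _similar(args...) = Zygote.Buffer(args...), _ -> nothing
    # Convert the Buffer back into a regular array before returning it.
    finalize_result(M::Manifold, f, y::Zygote.Buffer, x...) = copy(y)
end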

Every non-mutating function that uses similar_result then returns the output of finalize_result, and Zygote should just work.
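To make the pattern concrete, here is a minimal sketch of one non-mutating function built on similar_result and finalize_result. The Manifold type, ToyEuclidean, the default similar_result_type, and exp! below are simplified stand-ins defined only for this example; they are assumptions, not the Manifolds.jl definitions.

```julia
# Stand-in definitions so the sketch runs on its own (assumptions, not Manifolds.jl code).
abstract type Manifold end
struct ToyEuclidean <: Manifold end

_similar(args...) = similar(args...)

# Assumed default: the result has the element type of the first argument.
similar_result_type(M::Manifold, f, args) = eltype(args[1])

function similar_result(M::Manifold, f, x...)
    T = similar_result_type(M, f, x)
    return _similar(x[1], T)
end

finalize_result(M::Manifold, f, y, x...) = y

# Mutating variant: the only write to `y` is a single copyto! at the end.
exp!(M::ToyEuclidean, y, x, v) = copyto!(y, x + v)

# Non-mutating variant: allocate, fill, then pass the result through finalize_result.
function Base.exp(M::Manifold, x, v)
    y = similar_result(M, exp, x, v)
    exp!(M, y, x, v)
    return finalize_result(M, exp, y, x, v)
end

x = [1.0, 2.0]
v = [0.5, -0.5]
y = exp(ToyEuclidean(), x, v)
```

Without Zygote loaded, finalize_result is a no-op and the allocation is an ordinary similar call. With the @require block in place, the same function body would get a Zygote.Buffer from _similar while differentiating and copy it back into an array at the end.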

Of course, this just brings Zygote support to the same level as ReverseDiff. It doesn't handle the issues raised in #17, although it should be easier to define custom behavior for embedded manifolds with Zygote than with ReverseDiff.

mateuszbaran commented 5 years ago

I'd very much like to see Zygote support too. As far as I understand Zygote, the first choice should generally be not mutating at all and writing in a more functional style which would be quite possible here: we would just have to explicitly write non-mutating methods for all manifolds instead of relying on mutating variants and similar_result. Anyway, what you suggest looks like a good first step.

sethaxen commented 5 years ago

As far as I understand Zygote, the first choice should generally be not mutating at all

I don't know if that's the case long term. I think in the short term Buffer is meant to be the solution.

My thought is to start by adding a macro for defining new functions that have mutating and non-mutating versions, since we have a lot of duplication there. That would use the above pattern with finalize_result, so we would get Zygote compatibility. Then we could either modify the macro to independently create mutating and non-mutating or just drop it altogether to get functional.

For custom adjoints, in the short term we can define them using ZygoteRules.@adjoint, but once it uses ChainRules, we can migrate to using ChainRulesCore.rrule.
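For reference, a custom rule in the ChainRulesCore style looks roughly like the sketch below. The function sqnorm is hypothetical and exists only for this example, and the code assumes a ChainRulesCore 1.x release, where NoTangent is the marker for "no derivative with respect to this argument".

```julia
using ChainRulesCore

# Hypothetical function used only for illustration.
sqnorm(x) = sum(abs2, x)

function ChainRulesCore.rrule(::typeof(sqnorm), x::AbstractVector{<:Real})
    y = sqnorm(x)
    # The pullback returns one tangent for the function itself and one for `x`.
    sqnorm_pullback(Δ) = (NoTangent(), 2 .* Δ .* x)
    return y, sqnorm_pullback
end

# Calling the rule directly returns the primal value and the pullback.
y, back = rrule(sqnorm, [1.0, 2.0])
```

A ZygoteRules.@adjoint definition carries the same information, the primal value plus a closure mapping the output sensitivity to input sensitivities, so moving from one form to the other should mostly be a mechanical rewrite.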

mateuszbaran commented 5 years ago

I don't know if that's the case long term. I think in the short term Buffer is meant to be the solution.

For one thing, Buffer isn't a StaticArray and therefore performance on small arrays will suffer. I wonder if we can effectively wrap Buffer in SizedAbstractArray or HybridArray.

My thought is to start by adding a macro for defining new functions that have mutating and non-mutating versions, since we have a lot of duplication there. That would use the above pattern with finalize_result, so we would get Zygote compatibility. Then we could either modify the macro to independently create mutating and non-mutating or just drop it altogether to get functional.

Making such a macro may be quite complicated. But it could work.

For custom adjoints, in the short term we can define them using ZygoteRules.@adjoint, but once it uses ChainRules, we can migrate to using ChainRulesCore.rrule.

Sounds reasonable, the more AD backends we can support the better.

sethaxen commented 5 years ago

For one thing, Buffer isn't a StaticArray and therefore performance on small arrays will suffer. I wonder if we can effectively wrap Buffer in SizedAbstractArray or HybridArray.

I haven't mixed Zygote with StaticArrays extensively, but in my simple tests, it was slower than using a standard Array. May have just been a quirk of the function I was using.

If passed a static array, Buffer will just wrap the output of similar, and forward most operations to it, so I don't see why it shouldn't get the same speed-ups. Here's a simple test where it seems to do so.

julia> using Zygote, StaticArrays, BenchmarkTools

julia> function foo(x)
           y = similar(x)
           copyto!(y, x^2)
           return y
       end
foo (generic function with 1 method)

julia> function bar(x)
           y = Zygote.Buffer(x)
           copyto!(y, x^2)
           return copy(y)
       end
bar (generic function with 1 method)

julia> x = @SMatrix randn(3,3);

julia> @benchmark foo(x)
BenchmarkTools.Trial: 
  memory estimate:  80 bytes
  allocs estimate:  1
  --------------
  minimum time:     41.659 ns (0.00% GC)
  median time:      43.811 ns (0.00% GC)
  mean time:        58.513 ns (12.49% GC)
  maximum time:     7.474 μs (98.40% GC)
  --------------
  samples:          10000
  evals/sample:     990

julia> @benchmark bar(x)
BenchmarkTools.Trial: 
  memory estimate:  80 bytes
  allocs estimate:  1
  --------------
  minimum time:     42.524 ns (0.00% GC)
  median time:      44.118 ns (0.00% GC)
  mean time:        59.761 ns (12.47% GC)
  maximum time:     8.164 μs (98.92% GC)
  --------------
  samples:          10000
  evals/sample:     990

julia> x2 = collect(x);

julia> @benchmark foo(x2)
BenchmarkTools.Trial: 
  memory estimate:  320 bytes
  allocs estimate:  2
  --------------
  minimum time:     127.207 ns (0.00% GC)
  median time:      145.000 ns (0.00% GC)
  mean time:        184.259 ns (7.84% GC)
  maximum time:     5.387 μs (94.93% GC)
  --------------
  samples:          10000
  evals/sample:     900

Making such a macro may be quite complicated. But it could work.

I looked at MacroTools and instantly regretted this idea. I'll just change all the functions with similar_result to add finalize_result. Long term though, I do think we need to work out which pattern is the default, mutating or non-mutating, since that determines what we recommend users overload.

mateuszbaran commented 5 years ago

I've taken a look at Zygote.Buffer and I have a better understanding of what it does now. One case where it wouldn't work (at least for now) is in broadcasting, as broadcasting of StaticArrays is not designed to handle Buffer. That could be fixed though.

Do you know why Buffer is not a subtype of AbstractArray?

sethaxen commented 5 years ago

One case where it wouldn't work (at least for now) is in broadcasting, as broadcasting of StaticArrays is not designed to handle Buffer. That could be fixed though.

I have a local branch where I've begun work on this, and I'll open a breaking WIP PR.

Do you know why Buffer is not a subtype of AbstractArray?

Not specifically. It seems to be its main point though. I think the idea is to sidestep AbstractArray's pullbacks, which requires a different type. See https://github.com/FluxML/Zygote.jl/blob/a533d9954341725b7df7b10722a9503238767917/src/lib/buffer.jl#L30-L32.

mateuszbaran commented 5 years ago

One case where it wouldn't work (at least for now) is in broadcasting, as broadcasting of StaticArrays is not designed to handle Buffer. That could be fixed though.

I have a local branch where I've begun work on this, and I'll open a breaking WIP PR.

Great! I have some experience with broadcasting in StaticArrays (I've duplicated a part of it in HybridArrays), could you let me know when you open that PR?

Do you know why Buffer is not a subtype of AbstractArray?

Not specifically. It seems to be its main point though. I think the idea is to sidestep AbstractArray's pullbacks, which requires a different type. See https://github.com/FluxML/Zygote.jl/blob/a533d9954341725b7df7b10722a9503238767917/src/lib/buffer.jl#L30-L32.

I don't fully understand that but that apparently doesn't work with broadcasting at all? So you have to do quite a lot in that PR 🙂 .

julia> a = Zygote.Buffer([1.0 2.0; 3. 4.])
Zygote.Buffer{Float64,Array{Float64,2}}([1.390671161567e-309 0.0; 6.9043377050637e-310 0.0], false)

julia> a .*= 2
ERROR: MethodError: no method matching iterate(::Zygote.Buffer{Float64,Array{Float64,2}})
Closest candidates are:
  iterate(::Core.SimpleVector) at essentials.jl:604
  iterate(::Core.SimpleVector, ::Any) at essentials.jl:604
  iterate(::ExponentialBackOff) at error.jl:214
  ...
Stacktrace:
 [1] copyto!(::Array{Float64,1}, ::Zygote.Buffer{Float64,Array{Float64,2}}) at ./abstractarray.jl:722
 [2] _collect(::UnitRange{Int64}, ::Zygote.Buffer{Float64,Array{Float64,2}}, ::Base.HasEltype, ::Base.HasLength) at ./array.jl:550
 [3] collect(::Zygote.Buffer{Float64,Array{Float64,2}}) at ./array.jl:544
 [4] broadcastable(::Zygote.Buffer{Float64,Array{Float64,2}}) at ./broadcast.jl:659
 [5] broadcasted(::Function, ::Zygote.Buffer{Float64,Array{Float64,2}}, ::Int64) at ./broadcast.jl:1213
 [6] top-level scope at REPL[16]:1

sethaxen commented 5 years ago

Great! I have some experience with broadcasting in StaticArrays (I've duplicated a part of it in HybridArrays), could you let me know when you open that PR?

Sure, though I think you'll get the notification anyways. I'm going to start here, and if it looks like changes are needed to StaticArrays, we can go that route.

I don't fully understand that but that apparently doesn't work with broadcasting at all? So you have to do quite a lot in that PR 🙂 .

😧 well this will be fun.

Buffer isn't meant to do any math, just to provide an interface for mutating mainly via setindex! and copyto!. In most of our mutating functions, the mutated array is only mutated at the very end of the function in what is essentially a copyto! (making this an explicit copyto! will be necessary), so that's easy to support. The project_XXX! functions used at the end of exp! and log! are one of the exceptions. But I think we can make it work, even if it adds 1 extra allocation.
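A minimal sketch of that usage, assuming a recent Zygote release: all arithmetic, including the broadcast, happens on ordinary arrays, and the Buffer only receives the finished result through a single copyto!. The function scale_entries is a hypothetical example, not part of Manifolds.jl.

```julia
using Zygote

# Hypothetical helper: scales every entry, doing the math on ordinary arrays.
function scale_entries(x, c)
    y = Zygote.Buffer(x)   # scratch space with the same shape as x
    copyto!(y, c .* x)     # broadcast on a plain array, then one copyto! into the Buffer
    return copy(y)         # turn the Buffer back into a regular array
end

# Differentiate through the Buffer-based function.
grad = Zygote.gradient(x -> sum(scale_entries(x, 2.0)), [1.0, 2.0, 3.0])[1]
```

The temporary array produced by c .* x is the kind of extra allocation mentioned above: broadcasting directly into the Buffer is not supported, so the result is materialized first and then copied.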

The tricky thing is making sure we've considered the edge cases with our special array-wrapping types like FVector and ProductRepr.

sethaxen commented 5 years ago

After putting some time into this, I now think that the functional approach will be cleaner and easier to support going forward. A few downsides: 1) Some loss in efficiency, but how much? 2) Optim.jl will require manifolds to have two mutating functions: project_tangent! and retract!. Adopting a functional pattern would likely have these mutate only after allocating. I don't know how our SizedStaticArrays or the new HybridArrays being used in #40 would be impacted by not mutating. I don't think these downsides are extreme, but I'd like to open it up to discussion.
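As a sketch of what "mutate only after allocating" could look like, the example below keeps a functional core and wraps it in thin mutating methods. ToySphere, retract, and project_tangent are hypothetical names for illustration; the argument order of retract! and project_tangent! follows my understanding of Optim.jl's manifold interface and should be treated as an assumption.

```julia
using LinearAlgebra

# Hypothetical unit-sphere type for illustration only.
struct ToySphere end

# Functional core: no argument is modified.
retract(::ToySphere, x) = x / norm(x)
project_tangent(::ToySphere, g, x) = g - dot(x, g) * x

# Mutating wrappers: compute the result into a fresh array, then copy it back once.
retract!(M::ToySphere, x) = copyto!(x, retract(M, x))
project_tangent!(M::ToySphere, g, x) = copyto!(g, project_tangent(M, g, x))

x = [3.0, 4.0]
retract!(ToySphere(), x)             # x now lies on the unit sphere
g = [1.0, 0.0]
project_tangent!(ToySphere(), g, x)  # g is now orthogonal to x
```

The cost of this layout is one temporary array per call, which is the efficiency question raised in point 1.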

mateuszbaran commented 5 years ago

One possible way forward is to have both mutating and non-mutating versions of all functions. Having only non-mutating versions is not acceptable for me as this would very significantly slow down my computations. I've originally written my FunManifolds.jl code in a non-mutating style and then I was slowly replacing it with mutating code which resulted in massive speed-ups.

sethaxen commented 5 years ago

One possible way forward is to have both mutating and non-mutating versions of all functions.

Ugh, that will just be so frustrating to keep synchronized. It's too bad we don't have better tools for writing general interfaces that support mutating and functional styles in base Julia. I've heard rumors that some such features might be coming for 2.0.

In the meantime, maybe getting Buffer working and keeping the base mutating style is the best way to go. I suspect mutating code will always have some issues, but the idea would be that when using our non-mutating functions, Zygote should just work as though they didn't internally mutate.

mateuszbaran commented 5 years ago

Ugh, that will just be so frustrating to keep synchronized.

Generic tests can take care of checking that everything works right. Having both variants might be less frustrating than dealing with Buffer though, especially for product and power manifolds.
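A generic check of this kind could be sketched as follows. The helper test_variants_agree and the double/double! pair are hypothetical and serve only to show the shape of such a test.

```julia
using Test

# Hypothetical pair of variants for illustration.
double!(y, x) = copyto!(y, 2 .* x)
double(x) = 2 .* x

# Generic check: the mutating variant must reproduce the non-mutating result
# and must leave its inputs untouched.
function test_variants_agree(f, f!, args...)
    expected = f(args...)
    inputs = deepcopy(args)
    y = similar(expected)
    f!(y, args...)
    @test y ≈ expected
    @test all(args .== inputs)
    return nothing
end

@testset "mutating vs non-mutating" begin
    test_variants_agree(double, double!, [1.0, 2.0, 3.0])
end
```

The same helper could be called for each function pair on each manifold, so keeping the two variants in sync becomes a matter of adding one line per pair to the test suite.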

It's too bad we don't have better tools for writing general interfaces that support mutating and functional styles in base Julia. I've heard rumors that some such features might be coming for 2.0.

I didn't know that 2.0 could have a better support for that. I wonder how that would work. Anyway, what we have looks more like a Zygote problem than a Julia problem.

In the meantime, maybe getting Buffer working and keeping the base mutating style is the best way to go. I suspect mutating code will always have some issues, but the idea would be that when using our non-mutating functions, Zygote should just work as though they didn't internally mutate.

Right, and mutation might not even be the biggest problem on the road to Zygote compatibility. We'll see.