sethaxen closed this issue 3 years ago.
I'd very much like to see Zygote support too. As far as I understand Zygote, the first choice should generally be not mutating at all, writing in a more functional style instead, which would be quite possible here: we would just have to explicitly write non-mutating methods for all manifolds instead of relying on mutating variants and `similar_result`. Anyway, what you suggest looks like a good first step.
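To make the contrast concrete, here is a minimal sketch using a toy operation (normalizing a point onto the unit sphere). The names `normalize_point!`, `normalize_point_via_mutation` and `normalize_point` are hypothetical and not part of Manifolds.jl. The first non-mutating method relies on the mutating variant, while the second is written independently in a functional style.

```julia
using LinearAlgebra

# Mutating kernel for a toy operation: normalize `p` onto the unit sphere,
# writing the result into the caller-provided output `q`.
function normalize_point!(q, p)
    q .= p ./ norm(p)
    return q
end

# Current style: the non-mutating method allocates with `similar`
# and delegates to the mutating kernel.
normalize_point_via_mutation(p) = normalize_point!(similar(p), p)

# Functional style: an independent method that never writes into an array.
normalize_point(p) = p ./ norm(p)
```

The functional method contains no array writes, which is the form Zygote handles most readily; the cost is a second implementation to maintain for every manifold.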
> As far as I understand Zygote, the first choice should generally be not mutating at all
I don't know if that's the case long term. I think in the short term `Buffer` is meant to be the solution.
My thought is to start by adding a macro for defining new functions that have mutating and non-mutating versions, since we have a lot of duplication there. That would use the above pattern with `finalize_result`, so we would get Zygote compatibility. Then we could either modify the macro to independently create mutating and non-mutating versions, or just drop it altogether to get a functional style.
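A rough sketch of such a macro, assuming the convention that the mutating function is named `f!` and takes its output as the first argument. The macro name `@nonmutating_from` and the example function `sq!` are hypothetical, and this simplified version allocates with plain `similar` rather than going through `finalize_result`.

```julia
# Hypothetical macro: given a name `f`, define `f(x, args...)` that allocates
# its output with `similar(x)` and forwards to the mutating `f!`.
macro nonmutating_from(name)
    mutating_name = Symbol(name, "!")
    return quote
        $(esc(name))(x, args...) = $(esc(mutating_name))(similar(x), x, args...)
    end
end

# Example mutating function following the output-first convention.
function sq!(y, x)
    y .= x .^ 2
    return y
end

# Generates the allocating method `sq(x)`.
@nonmutating_from sq

sq([1.0, 2.0, 3.0])
```

Only the function name is escaped; the argument names are left to macro hygiene, so they cannot clash with anything in the caller's scope.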
For custom adjoints, in the short term we can define them using `ZygoteRules.@adjoint`, but once Zygote uses ChainRules, we can migrate to using `ChainRulesCore.rrule`.
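For reference, a minimal reverse rule in the `ChainRulesCore.rrule` style, using a toy Euclidean norm in place of a real manifold function (`norm2` is hypothetical). An `rrule` returns the primal value together with a pullback that maps the output cotangent to cotangents for the function itself and for each argument. This sketch assumes a recent ChainRulesCore version that provides `NoTangent` and `unthunk`.

```julia
using ChainRulesCore

# Toy function standing in for a manifold operation.
norm2(x::AbstractVector) = sqrt(sum(abs2, x))

# Reverse rule: primal value plus a pullback. The first cotangent belongs to
# `norm2` itself (it has no fields, hence NoTangent); the second belongs to `x`.
function ChainRulesCore.rrule(::typeof(norm2), x::AbstractVector)
    n = norm2(x)
    norm2_pullback(Δ) = (NoTangent(), unthunk(Δ) .* x ./ n)
    return n, norm2_pullback
end
```

Zygote picks up rules like this automatically once it is built on ChainRules, which is what makes migrating away from `@adjoint` attractive.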
> I don't know if that's the case long term. I think in the short term `Buffer` is meant to be the solution.
For one thing, `Buffer` isn't a `StaticArray`, and therefore performance on small arrays will suffer. I wonder if we can effectively wrap `Buffer` in `SizedAbstractArray` or `HybridArray`.
> My thought is to start by adding a macro for defining new functions that have mutating and non-mutating versions, since we have a lot of duplication there. That would use the above pattern with `finalize_result`, so we would get Zygote compatibility. Then we could either modify the macro to independently create mutating and non-mutating or just drop it altogether to get functional.
Making such a macro may be quite complicated. But it could work.
> For custom adjoints, in the short term we can define them using `ZygoteRules.@adjoint`, but once it uses ChainRules, we can migrate to using `ChainRulesCore.rrule`.
Sounds reasonable, the more AD backends we can support the better.
> For one thing, `Buffer` isn't a `StaticArray` and therefore performance on small arrays will suffer. I wonder if we can effectively wrap `Buffer` in `SizedAbstractArray` or `HybridArray`.
I haven't mixed Zygote with StaticArrays extensively, but in my simple tests, it was slower than using a standard `Array`. That may have just been a quirk of the function I was using.
If passed a static array, `Buffer` will just wrap the output of `similar` and forward most operations to it, so I don't see why it shouldn't get the same speed-ups. Here's a simple test where it seems to do so:
julia> using Zygote, StaticArrays, BenchmarkTools
julia> function foo(x)
           y = similar(x)
           copyto!(y, x^2)
           return y
       end
foo (generic function with 1 method)
julia> function bar(x)
           y = Zygote.Buffer(x)
           copyto!(y, x^2)
           return copy(y)
       end
bar (generic function with 1 method)
bar (generic function with 1 method)
julia> x = @SMatrix randn(3,3);
julia> @benchmark foo(x)
BenchmarkTools.Trial:
memory estimate: 80 bytes
allocs estimate: 1
--------------
minimum time: 41.659 ns (0.00% GC)
median time: 43.811 ns (0.00% GC)
mean time: 58.513 ns (12.49% GC)
maximum time: 7.474 μs (98.40% GC)
--------------
samples: 10000
evals/sample: 990
julia> @benchmark bar(x)
BenchmarkTools.Trial:
memory estimate: 80 bytes
allocs estimate: 1
--------------
minimum time: 42.524 ns (0.00% GC)
median time: 44.118 ns (0.00% GC)
mean time: 59.761 ns (12.47% GC)
maximum time: 8.164 μs (98.92% GC)
--------------
samples: 10000
evals/sample: 990
julia> x2 = collect(x);
julia> @benchmark foo(x2)
BenchmarkTools.Trial:
memory estimate: 320 bytes
allocs estimate: 2
--------------
minimum time: 127.207 ns (0.00% GC)
median time: 145.000 ns (0.00% GC)
mean time: 184.259 ns (7.84% GC)
maximum time: 5.387 μs (94.93% GC)
--------------
samples: 10000
evals/sample: 900
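Beyond speed, the point of the `Buffer` version is that it can be differentiated. A small check, assuming Zygote is installed, compares the gradient through `bar` with the gradient of the equivalent non-mutating expression `x^2` on a plain `Matrix` (`baz` is a hypothetical helper name).

```julia
using Zygote

# Same pattern as `bar` above: write into a Buffer, then freeze it with `copy`.
function bar(x)
    y = Zygote.Buffer(x)
    copyto!(y, x^2)
    return copy(y)
end

# Reference version with no mutation at all.
baz(x) = x^2

x = [1.0 2.0; 3.0 4.0]
g_buffer, = Zygote.gradient(v -> sum(bar(v)), x)
g_plain, = Zygote.gradient(v -> sum(baz(v)), x)
```

Both gradients should agree. For this `x` they equal `[7.0 11.0; 9.0 13.0]`, since the derivative of `sum(x^2)` with respect to `x[a, b]` is the sum of row `b` plus the sum of column `a`.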
> Making such a macro may be quite complicated. But it could work.
I looked at MacroTools and instantly regretted this idea. I'll just change all the functions that use `similar_result` to add `finalize_result`. Long term, though, I do think we need to work out which pattern is the default, mutating or non-mutating, since that determines what we recommend users overload.
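One way the "mutating is the default" choice could look, sketched with hypothetical names (`ToyManifold`, `ToySphere`, `exp_map`, `exp_map!`): a generic non-mutating fallback allocates with `similar` and calls the mutating method, so a new manifold only needs to overload `exp_map!`. The unit-sphere exponential map used here is `cos(‖X‖) p + sin(‖X‖) X / ‖X‖`.

```julia
using LinearAlgebra

abstract type ToyManifold end
struct ToySphere <: ToyManifold end

# Generic fallback: the non-mutating method allocates and calls the mutating
# one, so a new manifold only has to implement `exp_map!`.
exp_map(M::ToyManifold, p, X) = exp_map!(M, similar(p), p, X)

# Exponential map on the unit sphere, written into `q`.
function exp_map!(::ToySphere, q, p, X)
    θ = norm(X)
    if iszero(θ)
        q .= p                      # zero tangent vector: stay at p
    else
        q .= cos(θ) .* p .+ sin(θ) .* X ./ θ
    end
    return q
end
```

With the opposite default the roles flip: users would overload `exp_map`, and the mutating method would become a wrapper that copies the result into its output.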
I've taken a look at `Zygote.Buffer` and I have a better understanding of what it does now. One case where it wouldn't work (at least for now) is in broadcasting, as broadcasting of `StaticArrays` is not designed to handle `Buffer`. That could be fixed though.

Do you know why `Buffer` is not a subtype of `AbstractArray`?
> One case where it wouldn't work (at least for now) is in broadcasting, as broadcasting of `StaticArrays` is not designed to handle `Buffer`. That could be fixed though.
I have a local branch where I've begun work on this, and I'll open a breaking WIP PR.
> Do you know why `Buffer` is not a subtype of `AbstractArray`?
Not specifically. It seems to be its main point, though. I think the idea is to sidestep `AbstractArray`'s pullbacks, which requires a different type. See https://github.com/FluxML/Zygote.jl/blob/a533d9954341725b7df7b10722a9503238767917/src/lib/buffer.jl#L30-L32.
> > One case where it wouldn't work (at least for now) is in broadcasting, as broadcasting of `StaticArrays` is not designed to handle `Buffer`. That could be fixed though.
>
> I have a local branch where I've begun work on this, and I'll open a breaking WIP PR.
Great! I have some experience with broadcasting in `StaticArrays` (I've duplicated a part of it in `HybridArrays`). Could you let me know when you open that PR?
> > Do you know why `Buffer` is not a subtype of `AbstractArray`?
>
> Not specifically. It seems to be its main point though. I think the idea is to sidestep `AbstractArray`'s pullbacks, which requires a different type. See https://github.com/FluxML/Zygote.jl/blob/a533d9954341725b7df7b10722a9503238767917/src/lib/buffer.jl#L30-L32.
I don't fully understand that, but apparently it doesn't work with broadcasting at all? So you'll have to do quite a lot in that PR 🙂.
julia> a = Zygote.Buffer([1.0 2.0; 3. 4.])
Zygote.Buffer{Float64,Array{Float64,2}}([1.390671161567e-309 0.0; 6.9043377050637e-310 0.0], false)
julia> a .*= 2
ERROR: MethodError: no method matching iterate(::Zygote.Buffer{Float64,Array{Float64,2}})
Closest candidates are:
iterate(::Core.SimpleVector) at essentials.jl:604
iterate(::Core.SimpleVector, ::Any) at essentials.jl:604
iterate(::ExponentialBackOff) at error.jl:214
...
Stacktrace:
[1] copyto!(::Array{Float64,1}, ::Zygote.Buffer{Float64,Array{Float64,2}}) at ./abstractarray.jl:722
[2] _collect(::UnitRange{Int64}, ::Zygote.Buffer{Float64,Array{Float64,2}}, ::Base.HasEltype, ::Base.HasLength) at ./array.jl:550
[3] collect(::Zygote.Buffer{Float64,Array{Float64,2}}) at ./array.jl:544
[4] broadcastable(::Zygote.Buffer{Float64,Array{Float64,2}}) at ./broadcast.jl:659
[5] broadcasted(::Function, ::Zygote.Buffer{Float64,Array{Float64,2}}, ::Int64) at ./broadcast.jl:1213
[6] top-level scope at REPL[16]:1
> Great! I have some experience with broadcasting in `StaticArrays` (I've duplicated a part of it in `HybridArrays`), could you let me know when you open that PR?
Sure, though I think you'll get the notification anyways. I'm going to start here, and if it looks like changes are needed to StaticArrays, we can go that route.
> I don't fully understand that but that apparently doesn't work with broadcasting at all? So you have to do quite a lot in that PR 🙂
😧 well this will be fun.
`Buffer` isn't meant to do any math, just to provide an interface for mutation, mainly via `setindex!` and `copyto!`. In most of our mutating functions, the mutated array is only mutated at the very end of the function in what is essentially a `copyto!` (making this an explicit `copyto!` will be necessary), so that's easy to support. The `project_XXX!` functions used at the end of `exp!` and `log!` are exceptions. But I think we can make it work, even if it adds 1 extra allocation.
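A minimal sketch of that style, assuming Zygote is installed; `scale_and_shift!` and `scale_and_shift` are hypothetical. All arithmetic, including broadcasting, happens on ordinary arrays, and the output is written once with an explicit `copyto!`, which sidesteps the missing broadcasting support on `Buffer`.

```julia
using Zygote

# Hypothetical mutating function: the math (including broadcasting) runs on
# plain arrays, and the output `y` is touched exactly once, by copyto!.
function scale_and_shift!(y, x, a, b)
    tmp = a .* x .+ b      # broadcasting never involves `y`
    copyto!(y, tmp)        # the only mutation of the output
    return y
end

# Non-mutating wrapper: allocate a Buffer, fill it, then freeze it with copy.
scale_and_shift(x, a, b) = copy(scale_and_shift!(Zygote.Buffer(x), x, a, b))

g, = Zygote.gradient(v -> sum(scale_and_shift(v, 3.0, 1.0)), [1.0, 2.0])
```

The temporary `tmp` is the kind of extra allocation mentioned above; the same `scale_and_shift!` still works unchanged on an ordinary output array.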
The tricky thing is making sure we've considered the edge cases with our special array-wrapping types like `FVector` and `ProductRepr`.
After putting some time into this, I now think that the functional approach will be cleaner and easier to support going forward. There are a few downsides: 1) some loss in efficiency, but how much? 2) Optim.jl will require manifolds to have 2 functions, `project_tangent!` and `retract!`; under a functional pattern, these would likely mutate only after allocating. I don't know how our `SizedStaticArray`s or the new `HybridArray`s being used in #40 would be impacted by not mutating. I don't think these downsides are extreme, but I'd like to open it up to discussion.
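A sketch of what "mutate only after allocating" could mean for such entry points, with hypothetical names `retract_point` and `retract_point!` (they do not follow Optim.jl's actual manifold interface). The retraction is the projection retraction on the unit sphere, `(p + X) / ‖p + X‖`; the mutating method calls the functional one and then copies the result into `q`.

```julia
using LinearAlgebra

# Functional implementation: projection retraction on the unit sphere.
retract_point(p, X) = (p .+ X) ./ norm(p .+ X)

# Mutating entry point for callers that require one: it allocates through
# the functional version first and only then writes into `q`.
function retract_point!(q, p, X)
    copyto!(q, retract_point(p, X))
    return q
end
```

This keeps a single implementation but always pays for the allocation inside `retract_point`, which is the efficiency loss in question.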
One possible way forward is to have both mutating and non-mutating versions of all functions. Having only non-mutating versions is not acceptable for me, as this would very significantly slow down my computations. I originally wrote my FunManifolds.jl code in a non-mutating style and then slowly replaced it with mutating code, which resulted in massive speed-ups.
> One possible way forward is to have both mutating and non-mutating versions of all functions.
Ugh, that will just be so frustrating to keep synchronized. It's too bad we don't have better tools for writing general interfaces that support mutating and functional styles in base Julia. I've heard rumors that some such features might be coming for 2.0.
In the meantime, maybe getting `Buffer` working and keeping the base mutating style is the best way to go. I suspect mutating code will always have some issues, but the idea would be that when using our non-mutating functions, Zygote should just work as though they didn't internally mutate.
> Ugh, that will just be so frustrating to keep synchronized.
Generic tests can take care of checking that everything works right. Having both variants might be less frustrating than dealing with `Buffer`, though, especially for product and power manifolds.
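A sketch of such a generic test using the `Test` standard library. `project_tangent`, `project_tangent!` (tangent projection on the unit sphere, `X - ⟨p, X⟩ p`) and `variants_agree` are hypothetical names; the helper only checks that a hand-written mutating variant reproduces the non-mutating one.

```julia
using Test, LinearAlgebra

# Two hand-written variants of the same toy operation.
project_tangent(p, X) = X .- dot(p, X) .* p
function project_tangent!(Y, p, X)
    Y .= X .- dot(p, X) .* p
    return Y
end

# Generic check: run both variants on the same inputs and compare.
function variants_agree(f, f!, x, args...)
    expected = f(x, args...)
    out = f!(similar(expected), x, args...)
    return isapprox(out, expected)
end

@testset "mutating and non-mutating variants agree" begin
    p = [1.0, 0.0, 0.0]
    for X in ([0.5, 1.0, -2.0], [3.0, 0.0, 0.0])
        @test variants_agree(project_tangent, project_tangent!, p, X)
    end
end
```

Because the helper takes the two functions as arguments, one loop over a table of `(f, f!)` pairs could cover every manifold operation.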
> It's too bad we don't have better tools for writing general interfaces that support mutating and functional styles in base Julia. I've heard rumors that some such features might be coming for 2.0.
I didn't know that 2.0 could have better support for that. I wonder how that would work. Anyway, what we have looks more like a Zygote problem than a Julia problem.
> In the meantime, maybe getting `Buffer` working and keeping the base mutating style is the best way to go. I suspect mutating code will always have some issues, but the idea would be that when using our non-mutating functions, Zygote should just work as though they didn't internally mutate.
Right, and mutation might not even be the biggest problem on the road to Zygote compatibility. We'll see.
It would be nice to add compatibility for Zygote. I know I personally need Zygote for its complex number support, and I'd also like to use `Manifolds` in the same code.

For the most part, we can probably expect Zygote to just work, with one major blocker: Zygote doesn't support mutation. I don't think this will be a big problem, though. Zygote offers `Zygote.Buffer`, which behaves just like `similar` and allows mutation. All we should need to do is add something like this:

Every non-mutating function that uses `similar_result` then returns the output of `finalize_result`, and Zygote should just work.

Of course, this just brings Zygote support to the same level as ReverseDiff. It doesn't handle the issues raised in #17, although it should be easier to define custom behavior for embedded manifolds with Zygote than with ReverseDiff.
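A minimal sketch of the pattern described in this comment, assuming Zygote is installed. The two `finalize_result` methods are a guess at how such a helper could behave (pass plain arrays through, freeze a `Buffer` with `copy`), not a definition from Manifolds.jl, and `double!`/`double` are hypothetical. For simplicity the `Buffer` is allocated directly instead of through `similar_result`.

```julia
using Zygote

# Guessed behaviour of the proposed helper (not the package's definition).
finalize_result(y) = y                       # plain arrays pass through
finalize_result(y::Zygote.Buffer) = copy(y)  # freeze a Buffer into an array

# Hypothetical mutating kernel that writes its output with setindex!.
function double!(y, x)
    for i in eachindex(x)
        y[i] = 2 * x[i]
    end
    return y
end

# Non-mutating version: a Buffer stands in for similar(x), the unchanged
# mutating kernel fills it, and finalize_result turns it back into an array.
double(x) = finalize_result(double!(Zygote.Buffer(x), x))

g, = Zygote.gradient(v -> sum(double(v)), [1.0, 2.0, 3.0])
```

Dispatching `finalize_result` on the output type means the mutating kernel never needs to know whether it is writing into a `Buffer` or a plain array.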