FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

Facilitate array construction without mutation #377

Closed: sdewaele closed this issue 4 years ago

sdewaele commented 4 years ago

Is there a recommended (and preferably generic) mechanism to construct an array y that is a function of a variable x without running into the fact that array mutation is not supported? Conceptually, this array is constructed and not changed afterwards, i.e. it is immutable. In practice, however, many functions for constructing/populating an array mutate it. Of course, one can resort to defining custom adjoints for all of these, but that is not a desirable situation for an autodiff package.

As a concrete example, I will consider a function f that constructs the matrix A out of a vector a. As background, this function computes the matrix A of the first-order vector representation of a stochastic differential equation (SDE).

using LinearAlgebra, Random, Zygote  # provides I, MersenneTwister, Zygote.forward

# Desired function - autodiff does not work
function f(a::AbstractArray{T}) where {T}
  p = length(a)
  A = [zeros(T,p-1,1) Matrix{T}(I,p-1,p-1)
       -a']
end
x = [0.2,0.1,-0.3]
@show f(x)
# y = [0.0 1.0 0.0; 0.0 0.0 1.0; -0.2 -0.1 0.3]

rng = MersenneTwister(58634)
ȳ = randn(rng,3,3)

y,back = Zygote.forward(f,x)
@show back(ȳ)[1]
# ERROR: LoadError: Mutating arrays is not supported

In this case, I have been able to hack a working solution, but it ain't pretty, nor generalizable:

Atop(p) = hcat(zeros(p-1,1),Matrix(I,p-1,p-1)) 
Zygote.@nograd Atop
function f_working(a::AbstractArray{T}) where {T}
  p = length(a)
  A = [Atop(p);-a']
end
y,back = Zygote.forward(f_working,x)
@show back(ȳ)[1]
# (back(ȳ))[1] = [-0.579952; -0.51652; -0.520227]
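Why the gradient has this shape: only the last row of A depends on a, through A[p, j] = -a[j], so the pullback with respect to a is simply -ȳ[p, :]. The standard-library sketch below checks that hand-derived rule against central finite differences. `f_sde`, `pullback_sde` and `fd_grad` are hypothetical names, and a fixed matrix `ybar` stands in for the random ȳ above.

```julia
using LinearAlgebra

# Same construction as f above: shifted identity block on top, -a' as the last row.
function f_sde(a::AbstractVector{T}) where {T}
    p = length(a)
    return [zeros(T, p-1, 1) Matrix{T}(I, p-1, p-1); -a']
end

# Hypothetical hand-derived pullback: only row p depends on a, via A[p, j] = -a[j].
pullback_sde(ybar::AbstractMatrix) = -ybar[end, :]

# Hypothetical helper: central finite differences of a -> sum(ybar .* f_sde(a)).
function fd_grad(ybar, a; h = 1e-6)
    g = similar(a)
    for j in eachindex(a)
        e = zeros(eltype(a), length(a)); e[j] = h
        g[j] = (sum(ybar .* f_sde(a .+ e)) - sum(ybar .* f_sde(a .- e))) / (2h)
    end
    return g
end

x = [0.2, 0.1, -0.3]
ybar = [1.0 4.0 7.0; 2.0 5.0 8.0; 3.0 6.0 9.0]
pullback_sde(ybar)    # [-3.0, -6.0, -9.0]
fd_grad(ybar, x)      # ≈ [-3.0, -6.0, -9.0]
```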

Here, the pattern of where a elements occur in y is not too complicated, so this workaround is feasible, but the pattern can be more complex.

I hope that there is a methodology using the existing code base. If not, perhaps one idea is to define a setindex that returns a copy, instead of setindex!, thus preventing mutation. Possibly there could even be a macro that would replace all setindex! by the copying version automatically when the function is passed through Zygote.forward. Not terribly efficient perhaps, but at least a working solution, and probably not too bad for small arrays.

sdewaele commented 4 years ago

Here is another, more minimal example. A workaround solution is easier in this case; that is not the point of course.

# Desired function - autodiff does not work
function f(x::T) where {T}
  Y = zeros(T,2)
  Y[2] = x
  return Y
end
x = 0.3
@show f(x)
# f(x) = [0.0, 0.3]

ȳ = [0.5,-1]
y,back = Zygote.forward(f,x)
@show back(ȳ)[1]
# ERROR: LoadError: Mutating arrays is not supported
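For reference, the easy workaround in this case is to build the vector in a single expression, and the pullback can be read off directly: y = [0, x], so only ȳ[2] flows back to x. A standard-library sketch with hypothetical names `f_nomut` and `f_nomut_pullback`:

```julia
# Hypothetical non-mutating version of f: the vector is built in one expression.
f_nomut(x::T) where {T} = [zero(T), x]

# Hypothetical hand-written pullback: y = [0, x], so d(ȳ ⋅ y)/dx = ȳ[2].
f_nomut_pullback(ybar::AbstractVector) = ybar[2]

f_nomut(0.3)                     # [0.0, 0.3]
f_nomut_pullback([0.5, -1.0])    # -1.0
```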

mcabbott commented 4 years ago

This is the core of your first issue — it just happens that Matrix{T}(I, size…) is mutating internally, which ought to be hidden from Zygote:

julia> h(a) = Matrix{Float64}(I,2,2);

julia> gradient(x -> sum(h(x)), rand(2))
ERROR: Mutating arrays is not supported

In the second case you are honestly mutating Y, which isn’t supported, but as you say seems trivial to work around here.

sdewaele commented 4 years ago

Thanks for your comments!

The fact that the construction of the identity matrix alone fails is an additional example of this issue. It is, however, not the only problem with this example: if I define an eye function, mark it with Zygote.@nograd eye, and then plug it into the desired function f, it still fails.

I hope it is clear that the fact that the second example can be worked around easily is not a solution. I have provided this example only as a minimal example of the issue (although it seems there are also additional issues with concatenation). If you want to construct a more complex matrix, there is no longer a trivial workaround, as the original SDE example begins to show.

I hope this issue can be resolved, because it would cover many cases where Zygote currently fails because of array mutation, which are actually cases where all we need is a way to construct an immutable matrix. This would be a big step towards enabling the use of Zygote in Bayesian inference/probabilistic programming as well, rather than being limited to neural networks.

MikeInnes commented 4 years ago

There are two ways to support functions that are semantically non-mutating, but mutate things internally: (1) support (efficient) differentiation of mutation or (2) add custom adjoints for those functions to hide the mutation. (1) would obviously be ideal, but failing that, (2) seems like it would fix your issues here; so long as we have coverage of the standard libraries, your SDE example will work fine, so that's what our next step should be.
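Option (2) can be sketched as follows, assuming Zygote is installed. The hypothetical helper `embed` mutates internally (it fills a zero vector, like the minimal `Y[2] = x` example earlier in the thread), and a custom `@adjoint` hands Zygote the pullback directly, so the mutation is never traced. This illustrates the technique only; it is not an adjoint that Zygote ships.

```julia
using Zygote
using Zygote: @adjoint

# Hypothetical helper: semantically pure, but mutates a fresh array internally.
function embed(x::T, n::Integer, i::Integer) where {T}
    Y = zeros(T, n)
    Y[i] = x
    return Y
end

# Custom adjoint hides the mutation: only entry i of the output depends on x,
# and the integer arguments get no gradient.
@adjoint embed(x, n, i) = embed(x, n, i), Δ -> (Δ[i], nothing, nothing)

gradient(x -> sum([0.5, -1.0] .* embed(x, 2, 2)), 0.3)   # (-1.0,)
```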

It seems like you're advocating for some third option here, but I don't think there's a simpler way of solving this just for array constructors that isn't equivalent to solving mutation generally.

sdewaele commented 4 years ago

Perhaps option (3) is to create a few Zygote helper functions for array construction, such as the (copying) setindex that I mentioned. It still requires custom-written code, but I think that is much better than requiring custom-written adjoints!

mcabbott commented 4 years ago

You may like BangBang.jl, which defines functions like setindex!! that only mutate when they can. It looks like it aims to take Zygote into account, although strangely this example fails with a vector instead of a tuple:

julia> using BangBang

julia> @! gradient(x -> sum(push!(x, 1)), (1,2,3))
((1, 1, 1),)

sdewaele commented 4 years ago

Okay, construction of the matrix works with the following code, and it does not deviate too much from the original:

eye(T,p) = Matrix{T}(I,p,p)
Zygote.@nograd eye
function f(a::AbstractArray{T}) where {T}
  p = length(a)
  A = [hcat(zeros(T,p-1,1),eye(T,p-1)); -a']
end
# works!

@mcabbott - here I used your observation that definition of the unit matrix in itself was causing a problem. I am happy with this solution for the SDE problem at this point.

Note that concatenation as in the original code still fails:

function f(a::AbstractArray{T}) where {T}
  p = length(a)
  A = [zeros(T,p-1,1) eye(T,p-1); -a']
end
# ERROR: LoadError: Mutating arrays is not supported

If there is no intention to define e.g. setindex as described, or fix the concatenation problem, it is fine to close this issue as far as I am concerned. Note, though, that a setindex type of solution would still be beneficial for cases where the parameters populate a matrix in a more complex pattern than my SDE example.

MikeInnes commented 4 years ago

Having an out-of-place setindex would be useful, but it's strange to think of it as an alternative to having adjoints for the stdlib. The experience for users is ultimately the same: you have to write your code in terms of non-mutating functions that Zygote supports (possibly including setindex).

We can't automatically rewrite setindex! to setindex, which again would be equivalent to solving the mutation problem in general. But I'm sure we could figure out where a setindex function should live.

mcabbott commented 4 years ago

According to @which [[1] [2]; [3,4]'] it looks like your example calls hvcat, which has a gradient only handling the case of numbers, unlike the simpler cat functions.
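This can be checked directly: a literal with both spaces and semicolons lowers to `hvcat` with the per-row block counts, and the same matrix can be written with nested `vcat`/`hcat`, as in the working version of f earlier in the thread. The sketch below only checks that the three forms produce equal matrices (the blocks are arbitrary examples); which of them a given Zygote release can differentiate is not tested here.

```julia
A = zeros(2, 1)
B = [1.0 0.0; 0.0 1.0]
C = [-0.2 -0.1 0.3]

# The literal [A B; C] is lowered to hvcat with row block counts (2, 1).
M1 = [A B; C]
M2 = hvcat((2, 1), A, B, C)

# The same matrix built only from the simpler cat functions.
M3 = vcat(hcat(A, B), C)

M1 == M2 == M3   # true
```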

sdewaele commented 4 years ago

@mcabbott - good to know how to analyze that concatenation code! I was not sure how to do it.

@MikeInnes - It would probably be useful to have the non-mutating setindex (with an adjoint defined of course). As said, rewriting a part of your code to use it is still easier than writing the adjoint for the same.

MikeInnes commented 4 years ago

Right; I'm definitely not advocating that users should have to write adjoints for their own code. Where it's not easy to write a function using non-mutating stdlib functions, adding new functions to fill those gaps is reasonable. But that wasn't the case in your SDE example; there we just need to expand support for the stdlib.

NMUrban commented 4 years ago

What are the implications here for arrays constructed through comprehensions? For example, when working with Gaussian processes I'd like to construct a covariance function like Σ = σ^2 * [exp(-((x-x′)/λ)^2) for x in X, x′ in X], and differentiate through computations involving Σ with respect to λ. Is there a way for this to work? (I know there are workarounds for this particular function; I'm asking generally.)

mcabbott commented 4 years ago

I’m sure you know this, but since I wanted to see what goes wrong: Comprehensions with Iterators.ProductIterator don’t seem to work right now, but surely they could be made to.

julia> using Zygote
julia> X = rand(4);
julia> Σ(λ) = [exp(-((x-x′)/λ)^2) for x in X, x′ in X]
Σ (generic function with 1 method)

julia> gradient(λ -> sum(Σ(λ)), 0.2)
ERROR: Need an adjoint for constructor Base.Iterators.ProductIterator{Tuple{Array{Float64,1},Array{Float64,1}}}. Gradient is of type Array{Tuple{Float64,Float64},2}

julia> gradient(λ -> sum([x^2/λ for x in X]), 0.2) # simple generator OK
(-46.45199007830025,)
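For the covariance example, one way to avoid `Iterators.product` entirely is to broadcast over X and its transpose. Whether a given Zygote release differentiates the broadcast version is an assumption to verify, but the closed-form derivative ∂Σ/∂λ = Σ .* 2(x - x′)^2 / λ^3 (elementwise) can be checked with plain Julia. The data X, σ = 1.5 and all function names below are hypothetical.

```julia
X = [0.1, 0.4, 0.7, 0.9]
σ = 1.5

# Comprehension form from the question: entry [i, j] uses X[i] and X[j].
Σ_comp(λ) = σ^2 * [exp(-((x - x′) / λ)^2) for x in X, x′ in X]

# Broadcast form: X .- X' gives the matrix of pairwise differences X[i] - X[j].
Σ_bcast(λ) = σ^2 .* exp.(-((X .- X') ./ λ) .^ 2)

# Closed-form elementwise derivative with respect to λ.
dΣ_dλ(λ) = Σ_bcast(λ) .* (2 .* (X .- X') .^ 2 ./ λ^3)

λ = 0.2
h = 1e-6
fd = (Σ_bcast(λ + h) .- Σ_bcast(λ - h)) ./ (2h)   # central finite difference

Σ_comp(λ) ≈ Σ_bcast(λ)                 # true
isapprox(fd, dΣ_dλ(λ); atol = 1e-5)    # true
```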

NMUrban commented 4 years ago

Yes, that's the error. Should I open a separate issue for comprehensions/ProductIterator, or is this a good place to track it?

sdewaele commented 4 years ago

I wrote a copying setindex with adjoint:

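using Zygote
using Zygote: @adjoint

# Copying setindex: writes into a copy, so A itself is never mutated
setindex(A,X,inds...) = setindex!(copy(A),X,inds...)

@adjoint setindex(A,X,inds...) = begin
  B = setindex(A,X,inds...)
  adj = function(Δ)
    # gradient w.r.t. A: Δ everywhere except the overwritten entries
    bA = copy(Δ)
    fill!(view(bA,inds...), zero(eltype(bA)))
    # gradient w.r.t. X: Δ at the overwritten entries (scalar or array, like X)
    bX = X isa AbstractArray ? reshape(Δ[inds...], size(X)) : Δ[inds...]
    # indices are not differentiable
    binds = fill(nothing,length(inds))
    return bA,bX,binds...
  end
  B,adj
end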

See this gist for extended code with successful finite difference comparisons.

sethaxen commented 4 years ago

Perhaps to support higher order differentiation, it should use Zygote.Buffer instead. See e.g. https://github.com/FluxML/Zygote.jl/issues/376#issuecomment-546749672.
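Following that suggestion, the original SDE matrix can be filled entry by entry through a `Zygote.Buffer` and turned into an ordinary array with `copy` before it is returned. This is a sketch assuming a Zygote version that provides `Buffer`; `f_buffer` and `ybar` are hypothetical names. The expected gradient is -ybar[p, :], since only the last row depends on a.

```julia
using Zygote

# Hypothetical rewrite of the SDE example with a Buffer instead of concatenation.
# Buffer(x) allocates like similar(x), so every entry is written explicitly.
function f_buffer(a::AbstractVector{T}) where {T}
    p = length(a)
    A = Zygote.Buffer(zeros(T, p, p))
    for i in 1:p-1, j in 1:p
        A[i, j] = j == i + 1 ? one(T) : zero(T)   # shifted identity block
    end
    for j in 1:p
        A[p, j] = -a[j]                           # last row: -a'
    end
    return copy(A)   # returns a plain Matrix and freezes the Buffer
end

x = [0.2, 0.1, -0.3]
ybar = [1.0 4.0 7.0; 2.0 5.0 8.0; 3.0 6.0 9.0]

f_buffer(x)                                   # [0.0 1.0 0.0; 0.0 0.0 1.0; -0.2 -0.1 0.3]
gradient(a -> sum(ybar .* f_buffer(a)), x)    # ([-3.0, -6.0, -9.0],)
```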

sdewaele commented 4 years ago

Thanks!!! I was not aware of Zygote.Buffer. It is probably the solution for this issue. I will test it for my use cases.

baggepinnen commented 4 years ago

> There are two ways to support functions that are semantically non-mutating, but mutate things internally: (1) support (efficient) differentiation of mutation or (2) add custom adjoints for those functions to hide the mutation. (1) would obviously be ideal, but failing that, (2) seems like it would fix your issues here; so long as we have coverage of the standard libraries, your SDE example will work fine, so that's what our next step should be.

Unfortunately, option (2) may help users make their own code differentiable, but using any ecosystem package implies a close to zero chance that the code will differentiate (in my own experience at least), unless that package is Zygote-aware.

I've come to understand that supporting mutation is a difficult thing, but if it is at all doable, I think it would be the only chance of reaching a state where f'(x) just works (actually "differentiate all the things").

MikeInnes commented 4 years ago

Yes, I think that's uncontroversial; it's something we'd really like to have, but unfortunately very difficult to do in a robust way with Julia's current semantics, so there's ongoing discussion about how we might be able to improve that.

FWIW, it's not necessarily a Zygote-specific issue; code that doesn't support immutable arrays also won't work with e.g. StaticArrays or in general with CuArrays for example. Zygote works with DiffEq only because it has support for those other kinds of arrays as well. So it doesn't seem too unreasonable to ask polished Julia packages to provide that option.

baggepinnen commented 4 years ago

> FWIW, it's not necessarily a Zygote-specific issue; code that doesn't support immutable arrays also won't work with e.g. StaticArrays or in general with CuArrays for example. Zygote works with DiffEq only because it has support for those other kinds of arrays as well. So it doesn't seem too unreasonable to ask polished Julia packages to provide that option.

That's a good point! There's one case where this actually works for StaticArrays but currently does not work for Zygote. A function which is semantically non-mutating but internally mutating is likely to work with StaticArrays if it uses the similar function, which creates a mutable container, whereas if it creates the container using, e.g., zeros, it will fail. Would it be possible to make this important subset of functions work automatically? Automatically in the sense that a compiler pass would swap out similar for Buffer and perform the copy at the right place, or something like that.

MikeInnes commented 4 years ago

Overriding similar to return a Buffer is pretty easy. The hard part is doing the copy in the right place. This isn't possible in general, but I could imagine having heuristics that catch a lot of important cases (e.g. if similar is called once and its return value only escapes via indexing and function return; add copy before returning). It could still technically break code, but might fix enough cases to be worthwhile.
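The heuristic can be applied by hand to see what it would produce. Below, the hypothetical `squares` allocates with `similar`, fills the result only by indexing, and returns it; `squares_buf` is the rewrite the heuristic suggests (`similar` swapped for `Zygote.Buffer`, `copy` inserted before returning). This is a manual sketch assuming Zygote is installed, not a compiler pass.

```julia
using Zygote

# Hypothetical example: semantically pure, internally mutating via `similar`.
function squares(x::AbstractVector)
    y = similar(x)
    for i in 1:length(x)
        y[i] = x[i]^2
    end
    return y
end

# Hand-applied heuristic: `similar` -> `Zygote.Buffer`, `copy` at the return.
function squares_buf(x::AbstractVector)
    y = Zygote.Buffer(x)
    for i in 1:length(x)
        y[i] = x[i]^2
    end
    return copy(y)
end

x = [1.0, 2.0, 3.0]
squares(x) == squares_buf(x)            # true
gradient(v -> sum(squares_buf(v)), x)   # ([2.0, 4.0, 6.0],)
```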

baggepinnen commented 4 years ago

I just learned that Base does define an out-of-place setindex

help?> Base.setindex
  # 4 methods for generic function "setindex":
  [1] setindex(x::Tuple, v, i::Integer) in Base at tuple.jl:31
  [2] setindex(a::SHermitianCompact{N,T,L}, x, i::Int64) where {N, T, L} in StaticArrays at /home/fredrikb/.julia/packages/StaticArrays/DBECI/src/SHermitianCompact.jl:116
  [3] setindex(a::StaticArray, x, index::Int64) in StaticArrays at /home/fredrikb/.julia/packages/StaticArrays/DBECI/src/deque.jl:76
  [4] setindex(a::StaticArray, x, inds::Int64...) in StaticArrays at /home/fredrikb/.julia/packages/StaticArrays/DBECI/src/deque.jl:89

which has methods for static arrays etc. but not for Base.Array. It would be nice to have the adjoint for this, and possibly pirate the function to define a method for Base.Array as well like suggested above.
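A quick standard-library illustration of the tuple method listed above: `Base.setindex` returns a new tuple and leaves the original untouched, which is enough to write the minimal `Y[2] = x` example without mutation for small fixed sizes. `f_tuple` is a hypothetical name, and no Zygote adjoint for `Base.setindex` is assumed here.

```julia
# Base.setindex on a tuple returns a new tuple; the original is unchanged.
t = (0.0, 0.0, 0.0)
t2 = Base.setindex(t, 0.3, 2)   # (0.0, 0.3, 0.0)

# Hypothetical: the minimal Y[2] = x example written without mutation via a tuple.
f_tuple(x::T) where {T} = collect(Base.setindex(ntuple(_ -> zero(T), 2), x, 2))
f_tuple(0.3)                    # [0.0, 0.3]
```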

sdewaele commented 4 years ago

I am curious how Zygote's technology compares to Stan's autodiff library. Notably, I wonder how Stan's technology should be categorized according to the distinction made in differentiable programming between static graphs and dynamic graphs.

More specific to this issue, I know that Stan does handle mutable arrays. Is there something that can be learned from this library?

chriselrod commented 4 years ago

Stan uses a dual-number approach, where each var is a pointer to a dual (value, adjoint), so you could think of it as working similarly to ForwardDiff.jl, except that it is reverse mode and thus better suited to all the N-to-1 functions like log probability density functions.

Operating on the vars pushes onto a global call stack, which is used for the reverse pass.

What's the difference between that and Flux's tracked vars? Could Zygote pullbacks, upon detecting mutation not supported by the source to source AD, automatically switch to such an approach as a fallback?
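To make the operator-overloading idea concrete, here is a toy reverse-mode sketch in plain Julia: each hypothetical `Var` holds a value and an adjoint, every overloaded operation pushes a closure onto a global tape, and the reverse pass replays the tape backwards. It illustrates the general technique described above, not Stan's implementation. Because only scalar operations are recorded, overwriting an entry of a `Vector{Var}` needs no special handling.

```julia
# Toy operator-overloading reverse mode with a global tape (illustrative only).
# Var, TAPE and grad! are hypothetical names, not part of any package.
mutable struct Var
    val::Float64
    adj::Float64
end
Var(x::Real) = Var(float(x), 0.0)

const TAPE = Function[]   # reverse-pass closures, in execution order

function Base.:+(a::Var, b::Var)
    c = Var(a.val + b.val)
    push!(TAPE, () -> (a.adj += c.adj; b.adj += c.adj))
    return c
end

function Base.:*(a::Var, b::Var)
    c = Var(a.val * b.val)
    push!(TAPE, () -> (a.adj += b.val * c.adj; b.adj += a.val * c.adj))
    return c
end

function Base.sin(a::Var)
    c = Var(sin(a.val))
    push!(TAPE, () -> (a.adj += cos(a.val) * c.adj))
    return c
end

# Seed the output adjoint with 1, replay the tape backwards, then clear it.
function grad!(out::Var)
    out.adj = 1.0
    for back in Iterators.reverse(TAPE)
        back()
    end
    empty!(TAPE)
    return nothing
end

x = Var(2.0); y = Var(3.0)
z = x * y + sin(x)
grad!(z)
x.adj, y.adj   # (3 + cos(2), 2.0)

# Mutating an array of Vars is harmless: only the scalar operations are taped.
a = Var(0.3)
Y = [Var(0.0), Var(0.0)]
Y[2] = a * Var(1.0)                        # overwrite an entry in place
loss = Var(0.5) * Y[1] + Var(-1.0) * Y[2]
grad!(loss)
a.adj          # -1.0
```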

sdewaele commented 4 years ago

Thanks @chriselrod! I am still not sure in which category I should place it:

1) Static, compiled graph: Stan generates C++ code that is subsequently compiled for a given Stan script. That seems to put Stan solidly in this category?

However, there is also evidence for option 2:

2) Operator overloading, dynamic graph

Perhaps the conclusion is that Stan takes an approach similar to Zygote's, being compiled while still allowing control flow?

chriselrod commented 4 years ago

I think Stan is solidly in the operator overloading, dynamic graph category.

They compile the Stan language into C++, but (AFAIK) don't do any sort of source to source for AD, but rely entirely on the operator overloading. Calling a Julia function using ForwardDiff causes LLVM to statically compile the appropriate methods, but the approach is no less dynamic.

sdewaele commented 4 years ago

Yes, you are right. Thanks for helping me to connect the dots. I am starting to understand the landscape of AD technology.

sdewaele commented 4 years ago

Closing this issue because it is addressed by using Buffer.

@NMUrban please open a separate issue for your comprehension problem. I think the core devs are best helped by targeted issues.

arnavs commented 4 years ago

@sdewaele Would it be possible to produce a documentation example for the Buffer solution? Having some trouble with this bug myself...

For my specific use-case, we have something like

loss() = begin 
   z = rand(50) # produce a batch of random points 
   ∂(x) = Zygote.gradient(ϕ,x)
   ∂∂(x) = Zygote.hessian(ϕ, x) 
   return sum(abs2, (f(x, ∂(x), ∂∂(x)) for x in z))
end

I think Buffer might be the solution. I tried it with a for loop, but that also didn't work.

sdewaele commented 4 years ago

Could you simplify your code to isolate the error? Do you take the derivative of this loss function with respect to some variables that are not shown? Right now your problem could arise in many places. BTW, I think that support for Hessians is experimental at this time.

arnavs commented 4 years ago

Thanks @sdewaele, will give that a go.

Currently we've decided to just go with finite differencing to sidestep some of these issues, but I expect we'll come back to this soon. I'll type up a MWE or something at that point.