I think this is because the output of r * h is dense, since the adjoint for * is defined as

@adjoint function(a::AbstractVecOrMat * b::AbstractVecOrMat)
  return a * b, function(Δ)
    return (reshape(Δ * b', size(a)), reshape(a' * Δ, size(b)))
  end
end
This returns a dense matrix, since a dense Δ times a sparse matrix is dense, which makes sense. I don't see the benefit of converting it back and forth...
Yes, I'm looking for something like:
@adjoint function(a::SparseMatrixCSC * b::AbstractVecOrMat)
  return a * b, function(Δ)
    return (
      # dense-dense matrix multiplication with sparse output!
      mat_mul_with_sparse_output_in_shape_of(Δ, b', a),
      # same as before
      reshape(a' * Δ, size(b)))
  end
end
The new thing here is mat_mul_with_sparse_output_in_shape_of(). It will calculate matrix elements of Δ * b' only if those are present in the sparsity structure of a. This is the dual operation to a * b.
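For concreteness, a minimal sketch of what such a helper could look like (purely illustrative, not something Zygote provides; it simply evaluates the needed entries one by one via findnz):

```julia
using SparseArrays

# Illustrative only: gather just those entries of Δ * bt that lie on the
# sparsity pattern of `a`, and return them as a SparseMatrixCSC with that pattern.
function mat_mul_with_sparse_output_in_shape_of(Δ::AbstractMatrix, bt::AbstractMatrix,
                                                a::SparseMatrixCSC)
    rows, cols, _ = findnz(a)   # stored (row, col) positions of `a`
    vals = [sum(view(Δ, i, :) .* view(bt, :, j)) for (i, j) in zip(rows, cols)]
    return sparse(rows, cols, vals, size(a)...)
end
```

A real implementation would want to exploit the CSC layout rather than allocating a view per entry, but this shows the intended contract.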
The motivation would be efficiency. I'm trying to explore whether Zygote would be able to simplify calculations of the type in this method https://arxiv.org/abs/1711.10570 (see Appendix B for the "reverse-mode differentiation" transformation that I performed manually).
I don't think we can do this by default; a sparse matrix is just an efficient encoding for the equivalent dense one, and the two should give equivalent gradients. Otherwise we'll get incorrect gradients if e.g. someone sparsifies a dense matrix during a forward pass.
I think the best way to handle this would be through the optimiser / update step (this would be at the Flux level rather than Zygote). In cases where the dense matmul would be really inefficient, we could add a separate operation that drops gradients explicitly.
> I don't think we can do this by default; a sparse matrix is just an efficient encoding for the equivalent dense one, and the two should give equivalent gradients. Otherwise we'll get incorrect gradients if e.g. someone sparsifies a dense matrix during a forward pass.
I'm not sure this is correct @MikeInnes. I'm pretty sure that you only ever need to keep track of the gradient of the elements of a sparse matrix that are non-zero, i.e. the adjoint should be a named tuple where one of the elements has the gradient info w.r.t. the non-zero elements. It's just the same idea as the adjoint for e.g. a Diagonal matrix.
I agree that the above suggestion isn't quite right, since you need to return a named tuple rather than an AbstractMatrix, but I think the general idea of just implementing a custom adjoint is fundamentally sound and could be made to work.
It seems hard to square that with the fact that the current gradient we calculate is not sparse; if you store less information than that you must be dropping gradients. In particular I'm thinking of cases like
gradient(x -> sum(x), [1, 0]) # [1, 1]
gradient(x -> sum(sparse(x)), [1, 0]) # ?
sparse should behave mathematically like the identity function here, but if we simply drop the gradient of 0 elements then the overall gradient will be [1, 0] instead. (Despite the fact that you've actually asked for the gradient of a dense vector here, and in disagreement with finite differences.)
It would be completely consistent to have a custom sum adjoint that returns a named tuple representing the adjoint w.r.t. the sparse matrix. In the case you've provided above, it happens to be that every element of the sparse matrix is non-zero, but the point stands. This NamedTuple would be propagated back to the (presumably custom) adjoint for sparse, which would know how to reconstruct the appropriate array of ones.
See next comment
It's kind of analogous to
gradient(x -> sum(Diagonal(x)), randn(2)) # ([1.0, 1.0],)
We could define a custom adjoint for sum that knows about Diagonal matrices and returns a NamedTuple (diag=Ones(2),) (or something like that).
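As a rough sketch of what that could look like (this rule is hypothetical, assuming Zygote's @adjoint macro and FillArrays; it is not something Zygote actually defines):

```julia
using Zygote, LinearAlgebra, FillArrays

# Hypothetical structural rule: summing a Diagonal only needs a gradient for the
# stored diagonal, so the pullback returns a NamedTuple that only touches `diag`.
Zygote.@adjoint function sum(D::Diagonal)
    return sum(D), Δ -> ((diag = Fill(Δ, size(D, 1)),),)
end
```

Here Fill(Δ, n) plays the role of the Ones(2) mentioned above.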
Ok here's a case which I think is clearer:
gradient(x -> [2, 3, 4]'sparse(x), [1, 0, 0]) # [2, 3, 4]
What would the named tuple gradient be? Seems to me that if you assume all zeros have the same gradient, you can't get the right result here, and if you don't then you have something equivalent to a dense matrix. But perhaps I'm missing something about the proposal.
Ah, sorry, I forgot how sparse works :grimacing:. I agree with you that it's problematic that the sparsity pattern changes here if you perturb the inputs.
It does, however, seem unfortunate to me that we wouldn't be able to exploit sparsity in AD, which is what your above example suggests that we can't safely do without breaking stuff. As I see it, there are two basic options:

1. Work out which operations involving SparseArrays we really want to support. Is there a "safe" subset where we get to keep the efficiency benefits without risking accidentally dropping gradients that we care about?
2. Accept that AD over structured array types (such as AbstractFill) is broken when broadcasting / mapping impure functions, leading to unintuitive behaviour, which ultimately turns out to be (arguably) the right way to go. There's actually precedent for this already in the behaviour of sparse arrays, e.g. this.

Are there any other options that you can think of?
I don't think it's inherently wrong to view sparse matrices as "fixing" values to zero (rather than just efficiently storing values that happen to be zero), so long as it doesn't result in incorrect non-sparse gradients, which may be doable. That's really the core of the issue here, since we can take advantage of natural sparsity just fine – it happens that in most cases the gradient of a sparse matrix is dense, but if it isn't we can happily return a sparse adjoint (possibly with a different sparsity pattern to the primal).
A couple of other ways we could take advantage of fixed sparsity structure:

1. Handle it in the optimiser / update step (at the Flux level), as suggested above.
2. Add a new array type that behaves like an AbstractArray, but with undefined gradients for the 0 components of the primal. Then all programs either give correct, efficient gradients or error out (see the sketch below).

I kind of like the sound of the new type; if there's a fundamental semantic choice here then it's a natural solution, even if it's a bit strange given that the forward pass semantics are the same. If possible it's better for gradient(f, sparse(x)) not to give different results compared to gradient(f, x).
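To make the new-type option concrete, here is one very rough shape it could take; the name FixedSparse and all details are hypothetical, just to illustrate an array whose sparsity pattern is treated as semantically fixed:

```julia
using SparseArrays

# Hypothetical wrapper: reading behaves exactly like the wrapped SparseMatrixCSC,
# but AD rules for it would only ever produce gradients for the stored entries,
# and would error (rather than silently densify) if anything else were needed.
struct FixedSparse{T} <: AbstractMatrix{T}
    data::SparseMatrixCSC{T,Int}
end

Base.size(A::FixedSparse) = size(A.data)
Base.getindex(A::FixedSparse, i::Int, j::Int) = A.data[i, j]

# Usage sketch:
H = FixedSparse(sprandn(4, 4, 0.3))
```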
Tangent: For reference, it might be helpful to show an equivalent issue for complex numbers: that is, real->real functions which have complex gradients.
julia> f(x) = imag(x*(1+2im))
f (generic function with 1 method)
julia> f(1)
2
julia> f'(1)
2 + 1im
An Int can be thought of as a "sparse" representation of a float, or complex, or quaternion..., in which case this makes sense; but it can also be thought of as a semantic restriction (to integer, in which case the gradient is zero, or to real, in which case the gradient is real). The inherent arbitrariness of that decision is what makes me think it should be part of the optimiser, but of course in the sparse arrays case there are efficiency concerns that do change the situation somewhat.
I just want to bring your attention to the fact that sparsity is key in a lot of large-scale optimization, and packages like JuMP and Convex.jl have, as far as I know, their own AD supporting sparse Hessians. Ref https://github.com/mlubin/ReverseDiffSparse.jl It would be nice if Zygote handled cases where the Hessian has a very clear sparsity structure, e.g., trend filtering or optimal control etc.
Sparse Hessians end up looking very different from sparse gradients: the game there is to calculate the Hessian via the smallest number of (dense) AD passes. Zygote will support this, but it'll do so by leaving the sparsity structure / graph colouring analysis to another tool, which can work out how best to call Zygote and ForwardDiff to calculate the Hessian.
@kbarros The discussion above might come off as negative, but I think we actually have a good path forward, and Zygote both can and should help in cases like yours. We just need an implementation of mat_mul_with_sparse_output_in_shape_of and we can start prototyping something here. Even if it's opt-in to begin with, that helps us stress test it and figure out what the default behaviour should be.
@MikeInnes thanks for the update. I like your thoughts about defining a new type to "opt in" to a more optimized behavior.
I don't feel like this discussion is negative at all. Actually, it highlights an interesting difference in thinking:

1. One view is that a sparse matrix is just an efficient encoding of the equivalent dense matrix, so its gradient should match the dense one.
2. The other view is that the sparsity pattern itself is fixed, so derivatives are only needed for the elements in that pattern. (In my proposal the pattern was fixed by the findnz() of the SparseMatrixCSC input, but I agree this could be generalized. Note that the specified sparsity pattern does not necessarily correspond to nonzero values in the input.)

I now see that this second way of thinking might lead to unexpected results for many users. However, it has the nice property that the flops for the forward and reverse calculations are essentially the same. A new type for the sparse output of the autodiff now makes sense to me. This type could then warn: "you are trying to access a matrix element for which the derivative has not been computed."
By the way, here's more about where I'm coming from. The sparse matrix H(x) might represent a quantum Hamiltonian, parameterized by a collection of numbers x. We will have a scalar function f(H(x)) to calculate a physical observable. It is often useful to calculate the gradient d f(H(x)) / dx, e.g., for optimizing the parameters x. In its backward pass, Zygote would first calculate the gradient df/dH (which, hopefully, we can represent by a sparse matrix) and then apply the chain rule to get df/dx = \sum_{i,j} (df/dH)_{ij} (dH/dx)_{ij}. Importantly, dH/dx is a sparse matrix by construction, which explains why we only need the corresponding sparse set of elements of df/dH.
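To spell out why only the sparse set of entries matters (with toy stand-in arrays, nothing to do with an actual Hamiltonian):

```julia
using SparseArrays

n, p = 6, 3
dfdH = sprandn(n, n, 0.2)                  # stand-in for df/dH (sparse, hopefully)
dHdx = [sprandn(n, n, 0.2) for _ in 1:p]   # stand-in for dH/dx_k, one per parameter

# df/dx_k = Σ_{i,j} (df/dH)_{ij} (dH/dx_k)_{ij}; the elementwise product of two
# sparse matrices is itself sparse, so only overlapping stored entries contribute.
dfdx = [sum(dfdH .* dHdx[k]) for k in 1:p]
```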
None of this is blocking my work. At this point I'm mostly curious about the potential capabilities of Zygote.
Off topic question: What is the best forum to ask more questions about Zygote? I'm excited about what you folks are doing, and would like to explore more.
Yes, spot on and that's a good way of putting it. That use case is a good motivator for this; converting H(x) to "fixed sparse" just before calling f seems like a reasonable way to express this.
Issues here are totally fine for this kind of discussion. For more casual stuff I highly recommend the Julia slack (particularly the #autodiff channel) where a bunch of us hang out most of the time.
Just came across this blog post: http://juliadiffeq.org/2019/06/06/StiffSDEs.html

It's exciting, and what they mention there sounds to me like what's already being discussed in this issue? Looking forward to learning more.
Is this issue still a problem? I tried with the latest Zygote v0.6.32, and the following code correctly returns a SparseMatrixCSC with the same sparsity as B:
using SparseArrays, Zygote, LinearAlgebra
n = 5
A = rand(n, n)
B = sprandn(n, n, 0.1)
gradient(B -> tr(A * B), B)[1]
sparse-sparse matmul also works fine:
C = sprandn(n, n, 0.1)
gradient(B -> tr(C * B), B)[1]
Can this issue be closed? Any remaining issues or edge cases?
None that I know of, though I suspect the reason this is still open is because nobody has done a comprehensive check. Given there are other issues like https://github.com/FluxML/Zygote.jl/issues/931 covering sparse arrays and this one is specific to output types, it should be safe to close.
> Is this issue still a problem?
> Any remaining issues or edge cases?
Since it's not mentioned above, what's solving the issue is that ProjectTo enforces sparsity. The code is here; it's not very efficient and awaits someone with a need to improve it:
https://github.com/JuliaDiff/ChainRulesCore.jl/blob/main/src/projection.jl#L507-L603

This is applied after generic rules such as the one for *, which will often produce a full matrix gradient first. It would be possible to write more specialised rules to preserve sparsity earlier, but again someone has to do it.
Consider the code:

gradient(h -> tr(r * h), h)[1]

where h is a SparseMatrixCSC. Currently this returns a dense Array. Is it possible to have it return an output with the same sparsity structure as the input h? It seems that the gradient with respect to a sparse matrix should generally be a sparse matrix.
Note that users could always revert to the current behavior by explicitly promoting h to a dense matrix:

gradient(h -> tr(r * h), Array(h))[1]
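A self-contained version of the reproduction (the definitions of n, r, and h here are my guesses at the original setup; only the two gradient calls appear verbatim above):

```julia
using SparseArrays, Zygote, LinearAlgebra

n = 5
r = rand(n, n)            # assumed dense factor
h = sprandn(n, n, 0.1)    # assumed sparse input

gradient(h -> tr(r * h), h)[1]         # gradient w.r.t. the sparse matrix;
                                       # on recent Zygote this is a SparseMatrixCSC (see above)
gradient(h -> tr(r * h), Array(h))[1]  # explicit dense fallback mentioned above
```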