Open CarloLucibello opened 2 years ago
Huh, I'm not sure what's happening here, but it seems to be a Zygote issue, i.e. using the rrule directly:
```julia
julia> y, pb = rrule(*, x, a)

julia> dstar, dx, da = unthunk.(pb(ones(3, 3)))
(NoTangent(), [0.29777006193839184 0.6709424987233098 0.7995858130348517; 0.29777006193839184 0.6709424987233098 0.7995858130348517; 0.29777006193839184 0.6709424987233098 0.7995858130348517], sparse([2, 3, 1, 1], [1, 1, 2, 3], [1.2757443358356997, 1.3035633062756793, 1.5875676437041135, 1.5875676437041135], 3, 3))

julia> dx
3×3 Matrix{Float64}:
 0.29777  0.670942  0.799586
 0.29777  0.670942  0.799586
 0.29777  0.670942  0.799586
```
so it looks like Zygote is doing the conversion from Matrix to SparseMatrixCSC for some reason?
That said, there is a reason we don't project sparse arrays to dense arrays. They are valid gradients (i.e. they live in the subspace of dense matrices), and allocating a whole dense matrix when the gradient is actually sparse is quite wasteful.
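For reference, this is also what ChainRulesCore's generic projector does when the primal is dense: it leaves the type of a sparse cotangent alone. A rough check (setup assumed, values arbitrary):

```julia
julia> using ChainRulesCore, SparseArrays

julia> proj = ProjectTo(rand(3, 3));                  # projector built from a dense primal

julia> proj(sprand(3, 3, 0.5)) isa SparseMatrixCSC    # sparse cotangent stays sparse
true
```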
> when the gradient is actually sparse is quite wasteful

Right. But here it's a sparse type with all entries stored.
I think this is coming from the use of FillArrays for the gradient of `sum`. If you broadcast something to make a dense array, then you get the expected behaviour, since `dx` is `dense * sparse` (instead of `Fill * sparse`), and `da` is sparsified by projection:
```julia
julia> dx, da = gradient((x, a) -> sum(abs, x * a), x, a);

julia> dx
3×3 Matrix{Float64}:
 1.2198  0.0890081  0.470331
 1.2198  0.0890081  0.470331
 1.2198  0.0890081  0.470331

julia> da
3×3 SparseMatrixCSC{Float64, Int64} with 4 stored entries:
  ⋅        1.29006  1.29006
 1.65091    ⋅        ⋅
 0.697421   ⋅        ⋅

julia> using FillArrays

julia> Fill(1.0, 3, 3) * a
3×3 SparseMatrixCSC{Float64, Int64} with 9 stored entries:
 0.559339  0.830073  0.389732
 0.559339  0.830073  0.389732
 0.559339  0.830073  0.389732
```
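The "sparsified by projection" step above is just the projector built from the sparse primal; roughly, with a stand-in `a`:

```julia
julia> using ChainRulesCore, SparseArrays

julia> a = sprand(3, 3, 0.4);                    # stand-in for the sparse primal

julia> nnz(ProjectTo(a)(ones(3, 3))) == nnz(a)   # dense cotangent cut down to a's pattern
true
```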
Maybe projection should fix this. But are there cases where you'll get a really sparse gradient for a dense matrix, and save work by keeping it? Those will be broken if it decides based on types alone.
Alternatively, should the rule for `sparse * dense` be smarter? Could it find `da` more efficiently than by projecting the dense result?
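One way such a rule could go, as a hypothetical sketch (`sparse_da` is made up, not the actual ChainRules rule): for `y = x * a` the cotangent of `a` is `x' * dy` restricted to `a`'s sparsity pattern, so only the stored entries need to be computed:

```julia
using SparseArrays, LinearAlgebra

# Hypothetical sketch: cotangent of the sparse factor `a` in `y = x * a`,
# computed entry-by-entry on `a`'s sparsity pattern rather than forming the
# dense `x' * dy` and projecting it afterwards.
function sparse_da(x::AbstractMatrix, a::SparseMatrixCSC, dy::AbstractMatrix)
    I, J, _ = findnz(a)                                        # stored positions of `a`
    vals = [dot(view(x, :, i), view(dy, :, j)) for (i, j) in zip(I, J)]
    return sparse(I, J, vals, size(a)...)                      # (x' * dy)[i, j] on the pattern
end
```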
Also, does this problem happen in real use? The `Fill` gradient for `sum` tends to cause weird effects, but usually only in quick examples you invent to test or time things. There's some chance we should just remove it, as saving one allocation doesn't seem such a big deal. If you `sum` a broadcast (or matmul) then there are other wasteful things just as big being done.
Note BTW that the types look right for sparse * vector:
```julia
julia> da, dv = gradient((a, v) -> sum(a * v), a, v);

julia> da
3×3 SparseMatrixCSC{Float64, Int64} with 4 stored entries:
  ⋅        0.646975  0.863441
 0.304581   ⋅         ⋅
 0.304581   ⋅         ⋅

julia> dv
3-element Vector{Float64}:
 0.559339444124771
 0.8300725860513986
 0.3897323622809169
```
Gradients with respect to dense arrays of functions involving sparse operations often return sparse arrays instead of dense ones (see examples below).
Since these gradients are dense, maybe it is more appropriate to return dense arrays instead (so that they also match the primal type). Should we add some specializations to `ProjectTo` to handle this?
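For concreteness, the kind of specialization being asked about might look roughly like the sketch below. This is purely illustrative, not actual ChainRulesCore code, and it would give up the savings from genuinely sparse gradients discussed above:

```julia
using ChainRulesCore, SparseArrays

# Hypothetical sketch only: when the projector was built from a plain dense
# array, collect a sparse cotangent back to dense before the usual
# element/axes projection. Not how ChainRulesCore currently behaves.
function (project::ProjectTo{AbstractArray})(dx::SparseMatrixCSC)
    return project(collect(dx))   # densify, then fall back to the generic dense path
end
```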