FluxML / Zygote.jl

https://fluxml.ai/Zygote.jl/

sparse arrays returned as gradients of dense arrays #1174

Open CarloLucibello opened 2 years ago

CarloLucibello commented 2 years ago

Gradients with respect to dense arrays of functions involving sparse operations often return sparse arrays instead of dense ones (see examples below).

Since these gradients are dense, maybe it is more appropriate to return dense arrays instead (so that they also match the primal type). Should we add some specializations to ProjectTo to handle this?

julia> using Zygote, SparseArrays

julia> s, t = [1, 1, 2, 3], [2, 3, 1, 1];

julia> e = rand(4);

julia> a = sparse(s, t, e)
3×3 SparseMatrixCSC{Float64, Int64} with 4 stored entries:
  ⋅        0.938323  0.355447
 0.724286   ⋅         ⋅ 
 0.129109   ⋅         ⋅ 

julia> x = rand(3, 3);

julia> x * a   # dense * sparse -> dense 
3×3 Matrix{Float64}:
 0.615101  0.00658076  0.00249286
 0.710266  0.0961539   0.0364241
 0.31373   0.101995    0.0386367

julia> dx, da = gradient((x, a) -> sum(x * a), x, a);

julia> dx # SHOULD BE A DENSE ARRAY INSTEAD?
3×3 SparseMatrixCSC{Float64, Int64} with 9 stored entries:
 1.29377  0.724286  0.129109
 1.29377  0.724286  0.129109
 1.29377  0.724286  0.129109

julia> da  # this is ok
3×3 SparseMatrixCSC{Float64, Int64} with 4 stored entries:
  ⋅       0.218187  0.218187
 1.921     ⋅         ⋅ 
 1.91889   ⋅         ⋅ 

julia> gradient(e -> sum(sin.(sparse(s, t, e))), e)[1]  # SHOULD BE A DENSE ARRAY INSTEAD?
4-element SparseVector{Float64, Int64} with 4 stored entries:
  [1]  =  0.591141
  [2]  =  0.937491
  [3]  =  0.748973
  [4]  =  0.991677
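One possible shape for such a fix, independent of `ProjectTo` internals (a minimal sketch; `densify_like` is a hypothetical helper name, not an existing API):

```julia
using SparseArrays

# Hypothetical helper: match the cotangent's storage to the primal's.
# If the primal is a plain dense Array but the gradient came back sparse,
# collect it into a dense Array; otherwise pass the gradient through.
densify_like(primal::Array, dx::AbstractSparseArray) = collect(dx)
densify_like(primal, dx) = dx
```

Applied to the example above, `densify_like(x, dx)` would turn the `SparseMatrixCSC` gradient back into a `Matrix`, while leaving `da` sparse.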
mzgubic commented 2 years ago

Huh, I'm not sure what's happening here, but it seems to be a Zygote issue.

i.e. using the rrule directly

julia> using ChainRulesCore

julia> y, pb = rrule(*, x, a)

julia> dstar, dx, da = unthunk.(pb(ones(3, 3)))
(NoTangent(), [0.29777006193839184 0.6709424987233098 0.7995858130348517; 0.29777006193839184 0.6709424987233098 0.7995858130348517; 0.29777006193839184 0.6709424987233098 0.7995858130348517], sparse([2, 3, 1, 1], [1, 1, 2, 3], [1.2757443358356997, 1.3035633062756793, 1.5875676437041135, 1.5875676437041135], 3, 3))

julia> dx
3×3 Matrix{Float64}:
 0.29777  0.670942  0.799586
 0.29777  0.670942  0.799586
 0.29777  0.670942  0.799586

so it looks like Zygote is doing the conversion from Matrix to SparseMatrixCSC for some reason?

That said, there is a reason we don't project sparse arrays to dense arrays. They are valid gradients (i.e. they live in the vector space of dense matrices), and allocating a whole dense matrix when the gradient is actually sparse is quite wasteful.
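For example (a sketch; whether the gradient actually comes back sparse depends on which broadcast rule fires): reading a dense matrix only through a sparse mask gives a gradient supported on the mask's pattern, and densifying it would allocate ~10^6 entries to represent ~10^3 nonzeros.

```julia
using Zygote, SparseArrays

x = rand(1000, 1000)           # dense primal
m = sprand(1000, 1000, 0.001)  # sparse mask with ~1000 stored entries
# The pullback only touches entries in m's sparsity pattern, so a sparse
# dx is the economical representation here.
dx = gradient(x -> sum(x .* m), x)[1]
```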

mcabbott commented 2 years ago

> when the gradient is actually sparse is quite wasteful

Right. But here it's a sparse type with all entries stored.

I think this is coming from the use of FillArrays for the gradient of sum. If you broadcast something to make a dense array, then you get the expected behaviour, since dx comes from dense * sparse (instead of Fill * sparse), and da is sparsified by projection:

julia> dx, da = gradient((x, a) -> sum(abs, x * a), x, a);

julia> dx
3×3 Matrix{Float64}:
 1.2198  0.0890081  0.470331
 1.2198  0.0890081  0.470331
 1.2198  0.0890081  0.470331

julia> da
3×3 SparseMatrixCSC{Float64, Int64} with 4 stored entries:
  ⋅        1.29006  1.29006
 1.65091    ⋅        ⋅ 
 0.697421   ⋅        ⋅ 

julia> using FillArrays

julia> Fill(1.0,3,3) * a
3×3 SparseMatrixCSC{Float64, Int64} with 9 stored entries:
 0.559339  0.830073  0.389732
 0.559339  0.830073  0.389732
 0.559339  0.830073  0.389732

Maybe projection should fix this. But are there cases where you get a genuinely sparse gradient for a dense matrix, and save work by keeping it? Those would be broken if projection decides based on types alone.

Alternatively, should the rule for dense * sparse be smarter? Could it find da more efficiently than by projecting the dense result?

Also, does this problem happen in real use? The Fill gradient for sum tends to cause weird effects, but usually only in quick examples you invent to test or time things. There's some chance we should just remove it, as saving 1 allocation doesn't seem such a big deal. If you sum a broadcast (or matmul) then there are other things being done that are just as wasteful.
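To make the Fill connection concrete: the `*` rule computes dx as `ȳ * a'`, and when `ȳ` is a lazy `Fill` of ones, multiplying it by a sparse matrix lands in the sparse multiplication path, as the `Fill(1.0,3,3) * a` example above shows. A minimal reproduction of that step:

```julia
using FillArrays, SparseArrays

a = sparse([1, 1, 2, 3], [2, 3, 1, 1], rand(4))
at = copy(a')        # materialize the adjoint as a SparseMatrixCSC
# Fill * sparse dispatches to sparse multiplication, so the result is a
# SparseMatrixCSC even though every entry ends up stored.
dx_like = Fill(1.0, 3, 3) * at
```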

Note BTW that the types look right for sparse * vector:

julia> v = rand(3);

julia> da, dv = gradient((a, v) -> sum(a * v), a, v);

julia> da
3×3 SparseMatrixCSC{Float64, Int64} with 4 stored entries:
  ⋅        0.646975  0.863441
 0.304581   ⋅         ⋅ 
 0.304581   ⋅         ⋅ 

julia> dv
3-element Vector{Float64}:
 0.559339444124771
 0.8300725860513986
 0.3897323622809169