Open gdalle opened 1 week ago
I think it's a case of local sparsity detection, because the `Ax` array has many coefficients equal to 0 but it is not stored in a sparse format (or even as a `Tridiagonal`).
Made a version with sparse arrays which has a better chance of working, but now it runs into an error in the `SparseMatrixCSC` constructor. I think that one is fixed on `main` though.
Yup, it works on `main` with the sparse conversion, and now we even find far fewer nonzeros than Symbolics. All we have to do is check that they are indeed nonzeros.
Now here's something interesting: if I replace `sparse(Array(Tridiagonal(...` with `sparse(Tridiagonal(...` in the constructor `ReactionDiffusionCache2`, I get the same number of nonzeros as Symbolics. Questions:
Why is the `sparse` / `SparseMatrixCSC` overload of SparseConnectivityTracer behaving differently on `AbstractArray`s such as `Tridiagonal`?

> Why is the `sparse` / `SparseMatrixCSC` overload of SparseConnectivityTracer behaving differently
Since Tracers have no primal value, `sparse` considers non-empty tracers instead of non-zero values. Maybe another method gets called on `Tridiagonal`?
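To make the "non-empty tracers instead of non-zero values" criterion concrete, here is a minimal sketch with a toy type. `ToyTracer`, `isemptytracer` and `count_stored` are hypothetical names invented for illustration; they are not part of SparseConnectivityTracer, whose real tracers and overloads are more involved.

```julia
# Hypothetical stand-in for a tracer: it has no primal value and only
# records which input indices a quantity depends on.
struct ToyTracer
    inputs::Set{Int}
end
ToyTracer() = ToyTracer(Set{Int}())   # a constant depends on no input

isemptytracer(t::ToyTracer) = isempty(t.inputs)

# Number of entries a "keep non-empty tracers" criterion would store.
count_stored(A) = count(!isemptytracer, A)

# A matrix of constants, analogous to one built from `ones(T, ...)`.
consts = [ToyTracer() for _ in 1:3, _ in 1:3]

# The same matrix, except that entry (1, 2) depends on input 2.
mixed = copy(consts)
mixed[1, 2] = ToyTracer(Set([2]))

println(count_stored(consts))
println(count_stored(mixed))
```

Under this criterion a matrix filled only with constants keeps no entries at all, which would be consistent with a conversion that ends up with 0 stored entries when every coefficient is a constant tracer.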
MWE:
julia> using ADTypes, LinearAlgebra, SparseArrays, SparseConnectivityTracer
julia> function f1(x::AbstractVector{T}) where {T}
           n = length(x)
           A = sparse(Tridiagonal(ones(T, n - 1), ones(T, n), ones(T, n - 1)))
           return A * x
       end
f1 (generic function with 1 method)
julia> function f2(x::AbstractVector{T}) where {T}
           n = length(x)
           A = sparse(Matrix(Tridiagonal(ones(T, n - 1), ones(T, n), ones(T, n - 1))))
           return A * x
       end
f2 (generic function with 1 method)
julia> ADTypes.jacobian_sparsity(f1, ones(10), TracerSparsityDetector())
10×10 SparseMatrixCSC{Bool, Int64} with 28 stored entries:
1 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅
1 1 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅
⋅ 1 1 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅
⋅ ⋅ 1 1 1 ⋅ ⋅ ⋅ ⋅ ⋅
⋅ ⋅ ⋅ 1 1 1 ⋅ ⋅ ⋅ ⋅
⋅ ⋅ ⋅ ⋅ 1 1 1 ⋅ ⋅ ⋅
⋅ ⋅ ⋅ ⋅ ⋅ 1 1 1 ⋅ ⋅
⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 1 1 ⋅
⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 1 1
⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 1
julia> ADTypes.jacobian_sparsity(f2, ones(10), TracerSparsityDetector())
10×10 SparseMatrixCSC{Bool, Int64} with 0 stored entries:
⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅
⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅
⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅
⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅
⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅
⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅
⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅
⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅
⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅
⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅
The relevant lines of code: https://github.com/adrhill/SparseConnectivityTracer.jl/blob/ac94586c854726370a0e464ee62074774295a929/src/overloads/arrays.jl#L199-L226
Most notably L208, L217.
I used `StridedMatrix` to stick to the original method.
Maybe another method is required for `Tridiagonal`?
Why not `AbstractMatrix`? I think that would solve it
I think the issue is that our method isn't called. Instead some other more specific method from SparseArrays.jl is called:
SparseMatrixCSC{Tv,Ti}(M::Tridiagonal{Tv})
If this suspicion turns out to be correct, making our signatures less specific would not help.
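One way to check this suspicion is to ask Julia which constructor method dispatch selects for each argument type. The sketch below uses plain `Float64` entries instead of tracers so that it only needs the standard libraries; the variable names are illustrative, and the printed methods depend on the installed SparseArrays version.

```julia
using LinearAlgebra, SparseArrays

n = 3
M_tri = Tridiagonal(ones(n - 1), ones(n), ones(n - 1))
M_dense = Matrix(M_tri)

# `which` returns the method that would be called for these argument types.
m_tri = which(SparseMatrixCSC{Float64,Int}, Tuple{typeof(M_tri)})
m_dense = which(SparseMatrixCSC{Float64,Int}, Tuple{typeof(M_dense)})
println(m_tri)
println(m_dense)

# With ordinary numbers both routes produce the same sparse matrix,
# so a difference in dispatch stays invisible until tracers are involved.
S_tri = sparse(M_tri)
S_dense = sparse(M_dense)
println(nnz(S_tri), " ", nnz(S_dense))
```

If the two printed methods differ, an overload written for dense matrices is never reached on the `Tridiagonal` path. Running the same `which` query with a tracer element type would show whether the package's own method is selected.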
> If this suspicion turns out to be correct, making our signatures less specific would not help.

I think it would at least error instead of silently returning the wrong thing, because a method ambiguity would be triggered. Right now, none of the methods we define applies to `Tridiagonal`.

> I think it would at least error instead of silently returning the wrong thing
No, that would just always call methods from SparseArrays which may succeed. I think we might be seeing such a "random silent success" here.
Yeah you're right, this is not enough:
julia> using LinearAlgebra
julia> struct T end
julia> LinearAlgebra.det(::AbstractMatrix{T}) = 0
julia> det(Diagonal([T(), T()]))
ERROR: MethodError: no method matching det(::T)
Note that on the latest release (not `main`), the second one errors, which is much better:
julia> ADTypes.jacobian_sparsity(f2, ones(10), TracerSparsityDetector())
ERROR: Function iszero requires primal value(s).
A dual-number tracer for local sparsity detection can be used via `local_jacobian_pattern`.
Stacktrace:
⋮ internal @ SparseConnectivityTracer, Base, SparseArrays, Unknown
[18] sparse
@ /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-grannysmith-C07ZM05NJYVY.0/build/default-grannysmith-C07ZM05NJYVY-0/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/SparseArrays/src/sparsematrix.jl:1006 [inlined]
[19] f2(x::Vector{SparseConnectivityTracer.GradientTracer{BitSet}})
@ Main ./REPL[51]:3
⋮ internal @ SparseConnectivityTracer
[22] jacobian_sparsity(f::Function, x::Vector{Float64}, ::TracerSparsityDetector{BitSet, Set{Tuple{Int64, Int64}}})
@ SparseConnectivityTracer ~/.julia/packages/SparseConnectivityTracer/QlV0S/src/adtypes.jl:42
Use `err` to retrieve the full stack trace.
I think we should remove the `SparseMatrixCSC` constructor overload
Per our discussion on Slack, overloading `SparseArrays._iszero` would be an option:
But independently, can we please figure out why this happens with the current version of `main`? I think it would be instructive regardless.
julia> using SparseConnectivityTracer, LinearAlgebra, SparseArrays
julia> T = SparseConnectivityTracer.GradientTracer{SparseConnectivityTracer.IndexSetGradientPattern{Int,BitSet}}
SparseConnectivityTracer.GradientTracer{SparseConnectivityTracer.IndexSetGradientPattern{Int64, BitSet}}
julia> n = 3
3
julia> A1 = sparse(Tridiagonal(ones(T, n - 1), ones(T, n), ones(T, n - 1)))
3×3 SparseMatrixCSC{SparseConnectivityTracer.GradientTracer{SparseConnectivityTracer.IndexSetGradientPattern{Int64, BitSet}}, Int64} with 7 stored entries:
 GradientTracer{IndexSetGradientPattern{Int64, BitSet}}()  GradientTracer{IndexSetGradientPattern{Int64, BitSet}}()  ⋅
 GradientTracer{IndexSetGradientPattern{Int64, BitSet}}()  GradientTracer{IndexSetGradientPattern{Int64, BitSet}}()  GradientTracer{IndexSetGradientPattern{Int64, BitSet}}()
 ⋅  GradientTracer{IndexSetGradientPattern{Int64, BitSet}}()  GradientTracer{IndexSetGradientPattern{Int64, BitSet}}()
julia> A2 = sparse(Matrix(Tridiagonal(ones(T, n - 1), ones(T, n), ones(T, n - 1))))
3×3 SparseMatrixCSC{SparseConnectivityTracer.GradientTracer{SparseConnectivityTracer.IndexSetGradientPattern{Int64, BitSet}}, Int64} with 0 stored entries:
⋅ ⋅ ⋅
⋅ ⋅ ⋅
⋅ ⋅ ⋅
Comes from the SciML benchmark case suggested in https://github.com/SciML/SciMLBenchmarks.jl/pull/989#issuecomment-2195664712