adrhill / SparseConnectivityTracer.jl

Fast operator-overloading Jacobian & Hessian sparsity detection.
MIT License

StackOverflow when creating `SparseMatrixCSC` with tracers as values #108

Closed: gdalle closed this issue 3 months ago

gdalle commented 4 months ago
julia> using ADTypes, SparseConnectivityTracer

julia> using LinearAlgebra, SparseArrays

julia> detector = TracerLocalSparsityDetector()
TracerLocalSparsityDetector{BitSet, Set{Tuple{Int64, Int64}}}()

julia> f(x) = det(spdiagm(x))
f (generic function with 1 method)

julia> g(x) = logdet(spdiagm(x))
g (generic function with 1 method)

julia> ADTypes.hessian_sparsity(f, randn(3), detector)
3×3 SparseMatrixCSC{Bool, Int64} with 6 stored entries:
 ⋅  1  1
 1  ⋅  1
 1  1  ⋅

julia> ADTypes.hessian_sparsity(g, randn(3), detector)
ERROR: StackOverflowError:
Stacktrace:
  [1] Array
    @ ./boot.jl:477 [inlined]
  [2] Array
    @ ./boot.jl:486 [inlined]
  [3] zeros
    @ ./array.jl:636 [inlined]
  [4] zeros
    @ ./array.jl:632 [inlined]
  [5] Dict{Tuple{Int64, Int64}, Nothing}()
    @ Base ./dict.jl:70
  [6] Set
    @ ./set.jl:45 [inlined]
  [7] myempty
    @ ~/Work/GitHub/Julia/SparseConnectivityTracer.jl/src/tracers.jl:5 [inlined]
  [8] myempty
    @ ~/Work/GitHub/Julia/SparseConnectivityTracer.jl/src/tracers.jl:171 [inlined]
  [9] zero
    @ ~/Work/GitHub/Julia/SparseConnectivityTracer.jl/src/conversion.jl:66 [inlined]
 [10] float(A::Vector{SparseConnectivityTracer.Dual{Float64, SparseConnectivityTracer.HessianTracer{BitSet, Set{Tuple{…}}}}})
    @ Base ./float.jl:1120
 [11] float
    @ ~/.julia/juliaup/julia-1.10.3+0.x64.linux.gnu/share/julia/stdlib/v1.10/SparseArrays/src/sparsematrix.jl:983 [inlined]
 [12] lu(A::SparseMatrixCSC{SparseConnectivityTracer.Dual{…}, Int64}; check::Bool) (repeats 18547 times)
    @ SparseArrays.UMFPACK ~/.julia/juliaup/julia-1.10.3+0.x64.linux.gnu/share/julia/stdlib/v1.10/SparseArrays/src/solvers/umfpack.jl:395
 [13] logabsdet(A::SparseMatrixCSC{SparseConnectivityTracer.Dual{Float64, SparseConnectivityTracer.HessianTracer{…}}, Int64})
    @ LinearAlgebra ~/.julia/juliaup/julia-1.10.3+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/generic.jl:1676
 [14] logdet(A::SparseMatrixCSC{SparseConnectivityTracer.Dual{Float64, SparseConnectivityTracer.HessianTracer{…}}, Int64})
    @ LinearAlgebra ~/.julia/juliaup/julia-1.10.3+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/generic.jl:1701
 [15] g(x::Vector{SparseConnectivityTracer.Dual{Float64, SparseConnectivityTracer.HessianTracer{BitSet, Set{Tuple{…}}}}})
    @ Main ./REPL[11]:1
 [16] trace_function(::Type{SparseConnectivityTracer.Dual{…}}, f::typeof(g), x::Vector{Float64})
    @ SparseConnectivityTracer ~/Work/GitHub/Julia/SparseConnectivityTracer.jl/src/pattern.jl:32
 [17] local_hessian_pattern(f::Function, x::Vector{Float64}, ::Type{BitSet}, ::Type{Set{Tuple{Int64, Int64}}})
    @ SparseConnectivityTracer ~/Work/GitHub/Julia/SparseConnectivityTracer.jl/src/pattern.jl:368
 [18] hessian_sparsity(f::Function, x::Vector{Float64}, ::TracerLocalSparsityDetector{BitSet, Set{Tuple{Int64, Int64}}})
    @ SparseConnectivityTracer ~/Work/GitHub/Julia/SparseConnectivityTracer.jl/src/adtypes.jl:112
 [19] top-level scope
    @ REPL[13]:1
Some type information was truncated. Use `show(err)` to see complete types.
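Here is a hand check of the first result. This is a worked derivation, not output from the package. Because diagm(x) is diagonal, its determinant is the product of the entries, so both test functions reduce to scalar expressions whose Hessians can be written down directly. For g, assume all x_i > 0 so that the real logarithm is defined.

```latex
\[
f(x) = \det(\operatorname{diag}(x)) = x_1 x_2 x_3,
\qquad
\frac{\partial^2 f}{\partial x_i \,\partial x_j} =
\begin{cases}
x_k, & i \neq j,\ k \notin \{i, j\}, \\
0,   & i = j,
\end{cases}
\]
\[
g(x) = \log\det(\operatorname{diag}(x)) = \sum_{i=1}^{3} \log x_i,
\qquad
\frac{\partial^2 g}{\partial x_i \,\partial x_j} =
\begin{cases}
-1/x_i^2, & i = j, \\
0,        & i \neq j.
\end{cases}
\]
```

The off-diagonal pattern reported for f matches this. The pattern one would expect for g is purely diagonal.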
adrhill commented 4 months ago

Very odd. I can reproduce the error, down to the same lines of code:

StackOverflowError:
Stacktrace:
  [1] Array
    @ ./boot.jl:477 [inlined]
  [2] BitSet
    @ ./bitset.jl:18 [inlined]
  [3] myempty
    @ ~/Developer/SparseConnectivityTracer.jl/src/tracers.jl:5 [inlined]
  [4] myempty
    @ ~/Developer/SparseConnectivityTracer.jl/src/tracers.jl:171 [inlined]
  [5] zero
    @ ~/Developer/SparseConnectivityTracer.jl/src/conversion.jl:16 [inlined]
  [6] float(A::Vector{SparseConnectivityTracer.HessianTracer{BitSet, Set{Tuple{Int64, Int64}}}})
    @ Base ./float.jl:1120
  [7] float
    @ /Applications/Julia-1.10.app/Contents/Resources/julia/share/julia/stdlib/v1.10/SparseArrays/src/sparsematrix.jl:983 [inlined]
  [8] lu(A::SparseMatrixCSC{SparseConnectivityTracer.HessianTracer{BitSet, Set{Tuple{Int64, Int64}}}, Int64}; check::Bool) (repeats 23607 times)
    @ SparseArrays.UMFPACK /Applications/Julia-1.10.app/Contents/Resources/julia/share/julia/stdlib/v1.10/SparseArrays/src/solvers/umfpack.jl:395

However, calling that float method directly on a vector of tracers doesn't error:

julia> T = SparseConnectivityTracer.HessianTracer{BitSet, Set{Tuple{Int64, Int64}}};

julia> A = rand(T, 3)
3-element Vector{SparseConnectivityTracer.HessianTracer{BitSet, Set{Tuple{Int64, Int64}}}}:
 SparseConnectivityTracer.HessianTracer{BitSet, Set{Tuple{Int64, Int64}}}(
  Gradient: BitSet([]),
  Hessian:  Set{Tuple{Int64, Int64}}()
)
 SparseConnectivityTracer.HessianTracer{BitSet, Set{Tuple{Int64, Int64}}}(
  Gradient: BitSet([]),
  Hessian:  Set{Tuple{Int64, Int64}}()
)
 SparseConnectivityTracer.HessianTracer{BitSet, Set{Tuple{Int64, Int64}}}(
  Gradient: BitSet([]),
  Hessian:  Set{Tuple{Int64, Int64}}()
)

julia> float(A)
3-element Vector{SparseConnectivityTracer.HessianTracer{BitSet, Set{Tuple{Int64, Int64}}}}:
 SparseConnectivityTracer.HessianTracer{BitSet, Set{Tuple{Int64, Int64}}}(
  Gradient: BitSet([]),
  Hessian:  Set{Tuple{Int64, Int64}}()
)
 SparseConnectivityTracer.HessianTracer{BitSet, Set{Tuple{Int64, Int64}}}(
  Gradient: BitSet([]),
  Hessian:  Set{Tuple{Int64, Int64}}()
)
 SparseConnectivityTracer.HessianTracer{BitSet, Set{Tuple{Int64, Int64}}}(
  Gradient: BitSet([]),
  Hessian:  Set{Tuple{Int64, Int64}}()
)

julia> convert(AbstractArray{typeof(float(zero(T)))}, A) # Base ./float.jl:1120
3-element Vector{SparseConnectivityTracer.HessianTracer{BitSet, Set{Tuple{Int64, Int64}}}}:
 SparseConnectivityTracer.HessianTracer{BitSet, Set{Tuple{Int64, Int64}}}(
  Gradient: BitSet([]),
  Hessian:  Set{Tuple{Int64, Int64}}()
)
 SparseConnectivityTracer.HessianTracer{BitSet, Set{Tuple{Int64, Int64}}}(
  Gradient: BitSet([]),
  Hessian:  Set{Tuple{Int64, Int64}}()
)
 SparseConnectivityTracer.HessianTracer{BitSet, Set{Tuple{Int64, Int64}}}(
  Gradient: BitSet([]),
  Hessian:  Set{Tuple{Int64, Int64}}()
)
adrhill commented 4 months ago

Updated error in #113:

StackOverflowError:
Stacktrace:
  [1] SparseMatrixCSC{GradientTracer{BitSet}, Int64}(m::Int64, n::Int64, colptr::Vector{Int64}, rowval::Vector{Int64}, nzval::Vector{GradientTracer{BitSet}})
    @ SparseArrays /Applications/Julia-1.10.app/Contents/Resources/julia/share/julia/stdlib/v1.10/SparseArrays/src/sparsematrix.jl:26
  [2] SparseMatrixCSC(m::Int64, n::Int64, colptr::Vector{Int64}, rowval::Vector{Int64}, nzval::Vector{GradientTracer{BitSet}})
    @ SparseArrays /Applications/Julia-1.10.app/Contents/Resources/julia/share/julia/stdlib/v1.10/SparseArrays/src/sparsematrix.jl:44
  [3] float
    @ /Applications/Julia-1.10.app/Contents/Resources/julia/share/julia/stdlib/v1.10/SparseArrays/src/sparsematrix.jl:983 [inlined]
  [4] lu(A::SparseMatrixCSC{GradientTracer{BitSet}, Int64}; check::Bool) (repeats 21640 times)
    @ SparseArrays.UMFPACK /Applications/Julia-1.10.app/Contents/Resources/julia/share/julia/stdlib/v1.10/SparseArrays/src/solvers/umfpack.jl:395
  [5] logabsdet(A::SparseMatrixCSC{GradientTracer{BitSet}, Int64})
    @ LinearAlgebra /Applications/Julia-1.10.app/Contents/Resources/julia/share/julia/stdlib/v1.10/LinearAlgebra/src/generic.jl:1676
  [6] logdet(A::SparseMatrixCSC{GradientTracer{BitSet}, Int64})
    @ LinearAlgebra /Applications/Julia-1.10.app/Contents/Resources/julia/share/julia/stdlib/v1.10/LinearAlgebra/src/generic.jl:1701
  [7] g(x::Vector{GradientTracer{BitSet}})
    @ Main ./REPL[44]:1
  [8] trace_function(::Type{GradientTracer{BitSet}}, f::typeof(g), x::Vector{Float64})
    @ SparseConnectivityTracer ~/Developer/SparseConnectivityTracer.jl/src/pattern.jl:32
  [9] jacobian_pattern(f::Function, x::Vector{Float64}, ::Type{BitSet})
    @ SparseConnectivityTracer ~/Developer/SparseConnectivityTracer.jl/src/pattern.jl:201
gdalle commented 4 months ago

Also, a StackOverflowError usually has thousands of stack frames, not just a handful.

adrhill commented 4 months ago

Maybe SparseMatrixCSC recursively tries to cast the GradientTracers in nzval to some other type?

It's surprising that the stacktrace ends there.

adrhill commented 4 months ago

The error occurs in the inner constructor:

"""
    SparseMatrixCSC{Tv,Ti<:Integer} <: AbstractSparseMatrixCSC{Tv,Ti}

Matrix type for storing sparse matrices in the
[Compressed Sparse Column](@ref man-csc) format. The standard way
of constructing SparseMatrixCSC is through the [`sparse`](@ref) function.
See also [`spzeros`](@ref), [`spdiagm`](@ref) and [`sprand`](@ref).
"""
struct SparseMatrixCSC{Tv,Ti<:Integer} <: AbstractSparseMatrixCSC{Tv,Ti}
    m::Int                  # Number of rows
    n::Int                  # Number of columns
    colptr::Vector{Ti}      # Column i is in colptr[i]:(colptr[i+1]-1)
    rowval::Vector{Ti}      # Row indices of stored values
    nzval::Vector{Tv}       # Stored values, typically nonzeros

    function SparseMatrixCSC{Tv,Ti}(m::Integer, n::Integer, colptr::Vector{Ti}, # <--- sparsematrix.jl:26 in the trace
                            rowval::Vector{Ti}, nzval::Vector{Tv}) where {Tv,Ti<:Integer}
        sparse_check_Ti(m, n, Ti)
        _goodbuffers(Int(m), Int(n), colptr, rowval, nzval) ||
            throw(ArgumentError("Invalid buffers for SparseMatrixCSC construction n=$n, colptr=$(summary(colptr)), rowval=$(summary(rowval)), nzval=$(summary(nzval))"))
        new(Int(m), Int(n), colptr, rowval, nzval)
    end
end
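For reference, this sketch shows what the constructor receives. The buffers below describe the same kind of 3×3 diagonal matrix that spdiagm(x) builds in the original example, but with plain numbers instead of tracers. It uses only the SparseArrays standard library, and the values are arbitrary.

```julia
using SparseArrays

# Buffers for a 3×3 diagonal matrix, i.e. what spdiagm([1.0, 2.0, 3.0]) stores.
# Column j's entries live at positions colptr[j]:(colptr[j+1]-1) of rowval/nzval.
m, n   = 3, 3
colptr = [1, 2, 3, 4]      # one stored entry per column
rowval = [1, 2, 3]         # row index of each stored entry
nzval  = [1.0, 2.0, 3.0]   # the stored values themselves

# The outer constructor infers Tv and Ti from the buffers and forwards
# to the inner constructor, which only validates the buffers.
A = SparseMatrixCSC(m, n, colptr, rowval, nzval)
@assert A == spdiagm([1.0, 2.0, 3.0])

# Non-float element types are accepted as-is; nzval is stored without conversion.
B = SparseMatrixCSC(m, n, colptr, rowval, [1, 2, 3])
@assert eltype(B) == Int
@assert eltype(float(B)) == Float64   # float on a sparse matrix converts nzval
```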
gdalle commented 4 months ago

Note that regardless of whether we fix this with local tracing, array-level overloads will still be needed for global tracing, because lots of these linear algebraic functions have control flow.

adrhill commented 4 months ago

lots of these linear algebraic functions have control flow.

Isn't that a good argument for scalar-level tracing?

gdalle commented 4 months ago

Isn't that a good argument for scalar-level tracing?

It is a good argument for local tracing (with dual numbers), but with a global tracer, logdet will error. We need it to work, because asking for global sparsity patterns of functions that contain linear algebra is perfectly reasonable.

If we want to be lazy and give an overestimate, we can do something like:

function LinearAlgebra.logdet(M::AbstractMatrix{<:AbstractTracer})
    return exp(sum(M))  # or any nonlinear fully mixing function
end
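The following self-contained sketch makes the overestimate concrete. It uses a hypothetical `ToyTracer`, not SparseConnectivityTracer's actual tracer types or API. The toy only tracks which inputs each value depends on, and `logdet` is overloaded so that every matrix entry is treated as influencing the result. No control flow inside `logdet` is ever executed.

```julia
using LinearAlgebra

# Hypothetical toy "global" tracer, NOT SparseConnectivityTracer's type:
# it only records the set of input indices a value may depend on.
struct ToyTracer
    deps::Set{Int}
end

Base.zero(::Type{ToyTracer}) = ToyTracer(Set{Int}())   # used by diagm for off-diagonals
Base.:+(a::ToyTracer, b::ToyTracer) = ToyTracer(union(a.deps, b.deps))
Base.exp(a::ToyTracer) = ToyTracer(copy(a.deps))       # output depends on the same inputs

# Overestimating overload in the spirit of the snippet above:
# sum mixes all entries, exp stands in for "any nonlinear function".
LinearAlgebra.logdet(M::AbstractMatrix{ToyTracer}) = exp(sum(M))

g(x) = logdet(diagm(x))

x = [ToyTracer(Set([i])) for i in 1:3]   # input i depends only on itself
y = g(x)
@assert y.deps == Set([1, 2, 3])         # every input is reported as relevant
```

The trade-off is precision. The overload never inspects the structure of `M`, so it cannot exploit the fact that the off-diagonal entries of `diagm(x)` are structurally zero.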
adrhill commented 4 months ago

There are two issues here:

First and foremost, when calling logdet with spdiagm, an error occurs when calling the SparseMatrixCSC constructor with tracers.

The second issue, which I think you are alluding to, is calling logdet with non-sparse matrices, in which case local tracers are currently needed:

julia> using SparseConnectivityTracer, SparseArrays, LinearAlgebra

julia> g(x) = logdet(diagm(x)) # <--- NOT spdiagm
g (generic function with 1 method)

julia> hessian_pattern(g, randn(3))
ERROR: TypeError: non-boolean (SparseConnectivityTracer.HessianTracer{BitSet, Set{Tuple{Int64, Int64}}}) used in boolean context
Stacktrace:
  [1] generic_lufact!(A::Matrix{SparseConnectivityTracer.HessianTracer{BitSet, Set{…}}}, pivot::RowMaximum; check::Bool)
    @ LinearAlgebra /Applications/Julia-1.10.app/Contents/Resources/julia/share/julia/stdlib/v1.10/LinearAlgebra/src/lu.jl:152
...
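This minimal sketch shows why control flow breaks global tracing. It uses a hypothetical `CmpTracer` rather than SparseConnectivityTracer's types. Branching on a comparison requires a `Bool`, which a tracer without a primal value cannot provide. The result is the same kind of TypeError that generic_lufact! raises above.

```julia
# Hypothetical toy "global" tracer, NOT SparseConnectivityTracer's type.
# It carries no primal value, so it cannot decide a comparison;
# in this toy, `>` therefore returns another tracer instead of a Bool.
struct CmpTracer
    deps::Set{Int}
end
Base.:>(a::CmpTracer, b::CmpTracer) = CmpTracer(union(a.deps, b.deps))

# Value-dependent branch, similar in spirit to a pivot search in LU.
pick_larger(a, b) = a > b ? a : b

a, b = CmpTracer(Set([1])), CmpTracer(Set([2]))
err = try
    pick_larger(a, b)
    nothing
catch e
    e
end
@assert err isa TypeError   # "non-boolean (CmpTracer) used in boolean context"

# With actual values the branch is decidable, which is what local tracing relies on.
@assert pick_larger(1.0, 2.0) == 2.0
```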
adrhill commented 4 months ago

I opened #115 for that discussion and am renaming this issue to track the specific SparseMatrixCSC problem.

adrhill commented 3 months ago

The specific issue of logdet erroring on sparse matrices of tracers has been resolved by #131.

However, when using sparse matrices of tracers, stack overflows still occur due to lu calling itself:

lu(A::AbstractSparseMatrixCSC; check::Bool = true) = lu(float(A); check = check)

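This is because we "abuse" float to return a tracer. As a result, float(A) keeps the tracer element type, so lu(float(A)) dispatches back to the same method and calls itself until the stack overflows. I'm opening a new issue to track this.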
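The following self-contained toy reproduces this recursion pattern using only Base. The `Toy` type and the `fake_lu` function are hypothetical stand-ins, not SparseArrays' or SparseConnectivityTracer's code. When `float` maps an element type to itself, `float(A)` keeps that element type. A fallback of the form `f(A) = f(float(A))` then dispatches straight back to itself.

```julia
# Hypothetical element type standing in for a tracer (NOT SparseConnectivityTracer's type).
struct Toy
    val::Float64
end
Base.zero(::Type{Toy}) = Toy(0.0)
Base.float(x::Toy) = x            # the "abuse": float returns the same type

# Base's float(A::AbstractArray{T}) converts to eltype typeof(float(zero(T))),
# which is Toy again, so float(A) leaves the element type unchanged.
A = Toy.([1.0 0.0; 0.0 2.0])
@assert eltype(float(A)) == Toy

# Hypothetical stand-in for the pattern
#     lu(A::AbstractSparseMatrixCSC; check) = lu(float(A); check)
# A specialized method handles Float64; the generic fallback converts and re-dispatches.
fake_lu(A::AbstractMatrix{Float64}; depth = 0) = :factorized
function fake_lu(A::AbstractMatrix; depth = 0)
    depth >= 10 && return :recursing_forever   # guard so the sketch terminates
    return fake_lu(float(A); depth = depth + 1)
end

@assert fake_lu([1 0; 0 2]) == :factorized   # Int -> Float64, then the specific method
@assert fake_lu(A) == :recursing_forever     # Toy -> Toy, same method again
```

The depth guard exists only so the sketch terminates. The real method has no such guard, which is consistent with the traces above showing lu repeated thousands of times.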