adrhill / SparseConnectivityTracer.jl

Fast operator-overloading Jacobian & Hessian sparsity detection.
MIT License

Handle StaticArrays? #144

Closed: gdalle closed this issue 1 week ago

gdalle commented 1 month ago
julia> using LinearAlgebra, StaticArrays, SparseConnectivityTracer

julia> function f(x)
           y = SVector(x[1], x[2])
           return norm(y)
       end
f (generic function with 1 method)

julia> jacobian_sparsity(f, ones(2), TracerSparsityDetector())
ERROR: TypeError: non-boolean (SparseConnectivityTracer.GradientTracer{SparseConnectivityTracer.IndexSetGradientPattern{Int64, BitSet}}) used in boolean context
Stacktrace:
 [1] macro expansion
   @ ~/.julia/packages/StaticArrays/MSJcA/src/linalg.jl:278 [inlined]
 [2] _norm
   @ ~/.julia/packages/StaticArrays/MSJcA/src/linalg.jl:266 [inlined]
 [3] norm
   @ ~/.julia/packages/StaticArrays/MSJcA/src/linalg.jl:265 [inlined]
 [4] f(x::Vector{SparseConnectivityTracer.GradientTracer{SparseConnectivityTracer.IndexSetGradientPattern{Int64, BitSet}}})
   @ Main ~/Work/GitHub/Julia/DifferentiationInterface.jl/DifferentiationInterface/test/playground.jl:56
 [5] trace_function(::Type{SparseConnectivityTracer.GradientTracer{…}}, f::typeof(f), x::Vector{Float64})
   @ SparseConnectivityTracer ~/.julia/packages/SparseConnectivityTracer/pVwOW/src/interface.jl:33
 [6] _jacobian_sparsity(f::Function, x::Vector{Float64}, ::Type{SparseConnectivityTracer.GradientTracer{…}})
   @ SparseConnectivityTracer ~/.julia/packages/SparseConnectivityTracer/pVwOW/src/interface.jl:59
 [7] jacobian_sparsity(f::Function, x::Vector{…}, ::TracerSparsityDetector{…})
   @ SparseConnectivityTracer ~/.julia/packages/SparseConnectivityTracer/pVwOW/src/adtypes.jl:49
 [8] top-level scope
   @ ~/Work/GitHub/Julia/DifferentiationInterface.jl/DifferentiationInterface/test/playground.jl:59
Some type information was truncated. Use `show(err)` to see complete types.

See https://github.com/JuliaSmoothOptimizers/ADNLPModels.jl/issues/247#issuecomment-2260558148

adrhill commented 1 month ago

Your linked comment points out the issue: https://github.com/adrhill/SparseConnectivityTracer.jl/blob/ec44afc4b84cb1b6b7f234870776b4658a4987ac/src/overloads/arrays.jl#L72

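We overload norm(A::AbstractArray{<:AbstractTracer}), which is less specific than StaticArrays' norm(A::StaticArray), so calling norm on an SVector of tracers dispatches to the StaticArrays method instead. That method then uses a value computed from the tracer elements in a boolean context (linalg.jl:278 in the stack trace), which throws the TypeError.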
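A minimal, self-contained sketch of how a tracer ends up in a boolean context, reproducing the same kind of TypeError as in the report. `Tracer` and `check_zero` are hypothetical stand-ins, not SparseConnectivityTracer's `GradientTracer` or StaticArrays' `_norm`; we assume, for illustration only, that comparing two tracers returns another tracer (its dependency pattern) rather than a `Bool`.

```julia
# Hypothetical stand-in for a tracer: it records which inputs a value
# depends on and carries no numeric value of its own.
struct Tracer <: Real
    ids::Set{Int}
end

# Assumption for illustration: comparisons propagate dependencies, so they
# return a Tracer instead of a Bool.
Base.:(==)(a::Tracer, b::Tracer) = Tracer(union(a.ids, b.ids))
Base.zero(::Tracer) = Tracer(Set{Int}())

# Hypothetical value-dependent code: the ternary needs a Bool to pick a branch.
check_zero(x) = x == zero(x) ? "zero" : "nonzero"

err = try
    check_zero(Tracer(Set([1])))
    nothing
catch e
    e
end
println(err isa TypeError)  # prints true: non-boolean (Tracer) used in boolean context
```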

This is a general issue with array overloads since LinearAlgebra uses array-level types as traits for dispatch (e.g. Symmetric, Hermitian, ...). See also #133.

We could solve this problem by sticking our array overloads in a big loop over different array types.
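A minimal sketch of the loop idea, under stated assumptions: `Tracer`, `Static`, `mynorm` and `norm_pattern` are hypothetical stand-ins for a tracer type, a StaticArray-like wrapper, `norm` and its sparsity rule, not the real types or internals of either package. The `@eval` loop defines one tracer method per listed array type. Since `Static{<:Tracer}` is the intersection of `Static` and `AbstractArray{<:Tracer}`, it is more specific than both existing methods, so it wins dispatch for tracer inputs while numeric inputs still reach the value-based method.

```julia
# Hypothetical stand-in for a tracer: the set of input indices a value depends on.
struct Tracer <: Real
    ids::Set{Int}
end

# Hypothetical StaticArray-like wrapper type.
struct Static{T} <: AbstractVector{T}
    data::Vector{T}
end
Base.size(a::Static) = size(a.data)
Base.getindex(a::Static, i::Int) = a.data[i]
Base.IndexStyle(::Type{<:Static}) = IndexLinear()

# A "third-party" method that computes with values, like a scaled 2-norm.
function mynorm(A::Static)
    m = maximum(abs, A)
    iszero(m) && return zero(m)
    return m * sqrt(sum(x -> abs2(x / m), A))
end

# Sparsity rule: the output depends on every input any element depends on.
norm_pattern(A) = Tracer(reduce(union, (t.ids for t in A); init = Set{Int}()))

# Generic tracer overload: on its own, it does not beat `mynorm(::Static)`.
mynorm(A::AbstractArray{<:Tracer}) = norm_pattern(A)

# The loop: one tracer method per array type that needs special handling.
for AT in (:Static,)
    @eval mynorm(A::$AT{<:Tracer}) = norm_pattern(A)
end

x = Static([Tracer(Set([1])), Tracer(Set([2]))])
println(mynorm(x).ids == Set([1, 2]))  # prints true: tracer method wins
println(mynorm(Static([3.0, 4.0])))    # prints 5.0: value method still used for numbers
```

Supporting another wrapper then means adding its name to the tuple, provided that name is visible in the module running the loop.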

adrhill commented 1 month ago

I guess we could start by supporting types from StaticArrays as well as common LinearAlgebra matrix types?

gdalle commented 1 month ago

Yeah but where does it stop?

gdalle commented 1 month ago

Also, I don't want StaticArrays in the main dependencies, so we might have to make the loop parametric and handle static types in a package extension.
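One way the loop could be made parametric, sketched with hypothetical names (`Tracer`, `Static`, `mynorm`, `norm_pattern` and a `@tracer_overloads` macro) that are not part of either package. A macro expands in the caller's module, so a package extension can define its methods locally instead of `eval`-ing into the parent package's module, which Julia rejects during incremental precompilation. The extension itself would be declared in Project.toml under `[weakdeps]` (StaticArrays with its UUID) and `[extensions]`; the extension module name in the comments is hypothetical, and a local `Static` type stands in for `StaticArray` so the block runs on its own.

```julia
# Hypothetical stand-ins: a tracer (input dependencies only) and a
# StaticArray-like wrapper type.
struct Tracer <: Real
    ids::Set{Int}
end
struct Static{T} <: AbstractVector{T}
    data::Vector{T}
end
Base.size(a::Static) = size(a.data)
Base.getindex(a::Static, i::Int) = a.data[i]
Base.IndexStyle(::Type{<:Static}) = IndexLinear()

# Sparsity rule shared by all overloads: union of the elements' dependencies.
norm_pattern(A) = Tracer(reduce(union, (t.ids for t in A); init = Set{Int}()))

# Parametric loop body: define the tracer overload for one array type.
# `esc` makes the definition happen in the caller's module.
macro tracer_overloads(AT)
    return esc(quote
        mynorm(A::$AT{<:Tracer}) = norm_pattern(A)
    end)
end

# Main package: loop over the array types it already depends on.
for AT in (:AbstractArray,)
    @eval @tracer_overloads $AT
end

# Package extension (hypothetical ext/SparseConnectivityTracerStaticArraysExt.jl):
#     module SparseConnectivityTracerStaticArraysExt
#     using StaticArrays: StaticArray
#     import SparseConnectivityTracer: mynorm, Tracer, norm_pattern, @tracer_overloads
#     @tracer_overloads StaticArray
#     end
# Here the local `Static` stands in for `StaticArray`:
@tracer_overloads Static
```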

adrhill commented 1 week ago

Closing in favor of the more general discussion in #192.