adrhill / SparseConnectivityTracer.jl

Fast operator-overloading Jacobian & Hessian sparsity detection.
MIT License

Specialize array overloads #192

Open · gdalle opened this issue 1 week ago

gdalle commented 1 week ago

At the moment, src/overloads/arrays.jl contains linear algebra functions defined on AbstractMatrix{<:Tracer}. Unfortunately, these methods will most likely never be hit, because each array type (Matrix, Diagonal, SparseMatrixCSC, etc.) has its own implementation of functions like * or det. Since tracers are subtypes of Real, a method f(::ConcreteMatrixType{<:Real}) also applies to matrices of tracers, and it will always take precedence over f(::AbstractMatrix{<:Tracer}).

My suggestion: only define these overloads for the basic Array types (Vector / Matrix). That way, we make sure that our methods are actually hit when we want them to be. We also give people who are stuck with failing linear algebra methods an actionable workaround: put everything inside a normal Array and you should be fine.
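A minimal sketch of that workaround, using only standard LinearAlgebra and SparseArrays calls (the variable names are illustrative). Float64 stands in for the element type here; with tracers the eltype would be the tracer type instead, and the same Matrix(...) conversion should apply as long as zero is defined for it, since the conversion has to fill in the entries that the structured type does not store.

```julia
using LinearAlgebra, SparseArrays

# Three structured or lazy array types, each with its own methods for *, det, etc.
D = Diagonal([1.0, 2.0, 3.0])
S = sparse([1, 2, 3], [1, 2, 3], [1.0, 2.0, 3.0])
Tr = transpose([1.0 2.0; 3.0 4.0])

# Copy each one into a plain Matrix, so that only Matrix methods are candidates for dispatch.
# Note: Matrix(D) and Matrix(S) fill the unstored entries with zero(eltype(...)).
dense = [Matrix(A) for A in (D, S, Tr)]
```

The trade-off is an extra allocation and the loss of the structure that Diagonal or SparseMatrixCSC encode.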

gdalle commented 1 week ago

Related issues which are hopeless on abstract arrays:

adrhill commented 1 week ago

To add insult to injury, on top of being of limited use, array overloads require a huge amount of effort to write and test.

An alternative mentioned in #144 is to wrap all array overloads in a function that uses metaprogramming to generate methods on XYZArray{<:AbstractTracer}. However, this approach is limited to single-argument functions on arrays. As mentioned in #133, multi-argument functions on arrays are even more of a pain. Quoting myself:
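A minimal sketch of that metaprogramming approach, assuming a hypothetical abstract tracer type (AbstractToyTracer) and a hypothetical single-argument function (overload_hit) standing in for something like det. It only illustrates the dispatch mechanics and is not SCT's actual code: one @eval loop generates a specialized method per array type.

```julia
using LinearAlgebra

# Hypothetical stand-ins for a tracer type hierarchy (not SCT's actual types).
abstract type AbstractToyTracer <: Real end
struct ToyTracer <: AbstractToyTracer
    id::Int
end

# Hypothetical single-argument array function; this fallback catches every AbstractMatrix.
overload_hit(A::AbstractMatrix) = "generic fallback"

# Generate one specialized method per concrete array type.
for ArrayType in (:Matrix, :Diagonal, :UpperTriangular)
    @eval overload_hit(A::$ArrayType{<:AbstractToyTracer}) = $(string(ArrayType, " of tracers"))
end

overload_hit(fill(ToyTracer(1), 2, 2))                   # "Matrix of tracers"
overload_hit(Diagonal([ToyTracer(1), ToyTracer(2)]))     # "Diagonal of tracers"
overload_hit(UpperTriangular(fill(ToyTracer(1), 2, 2)))  # "UpperTriangular of tracers"
overload_hit(rand(2, 2))                                 # "generic fallback"
```

Because the function takes a single array, one loop over array types is enough. A two-argument function like * would need a loop over pairs of array types and over which argument holds tracers.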

> Matrix multiplication is already complex enough for simple Matrix and Vector.
>
> It requires methods for:
>
>   1. Matrix of tracers * Matrix of tracers
>   2. Matrix of tracers * Vector of tracers
>   3. Vector transposed of tracers * Matrix of tracers
>   4. Vector transposed of tracers * Vector of tracers
>   5. Matrix of reals * Matrix of tracers
>   6. Matrix of reals * Vector of tracers
>   7. Vector transposed of reals * Matrix of tracers
>   8. Vector transposed of reals * Vector of tracers
>   9. Matrix of tracers * Matrix of reals
>   10. Matrix of tracers * Vector of reals
>   11. Vector transposed of tracers * Matrix of reals
>   12. Vector transposed of tracers * Vector of reals
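
The list of methods for matrix multiplication is the product of two left shapes, two right shapes and three element-type combinations (2 * 2 * 3 = 12). A small sketch makes the growth explicit; mul_signatures is a hypothetical helper that only builds description strings and defines no methods.

```julia
# Element-type pairs that need a method: at least one operand holds tracers.
const ELTYPE_PAIRS = (("tracers", "tracers"), ("reals", "tracers"), ("tracers", "reals"))

# Hypothetical helper: list every signature needed for a two-argument function like *.
function mul_signatures(lhs_shapes, rhs_shapes)
    sigs = String[]
    for (le, re) in ELTYPE_PAIRS, ls in lhs_shapes, rs in rhs_shapes
        push!(sigs, "$ls of $le * $rs of $re")
    end
    return sigs
end

sigs = mul_signatures(("Matrix", "Vector transposed"), ("Matrix", "Vector"))
foreach(println, sigs)  # the 12 cases of the list, in the same order

# Adding Diagonal and SparseMatrixCSC on both sides:
more = mul_signatures(("Matrix", "Vector transposed", "Diagonal", "SparseMatrixCSC"),
                      ("Matrix", "Vector", "Diagonal", "SparseMatrixCSC"))
length(more)  # 48
```

Every additional array type on either side multiplies the count, whether the methods are written by hand or generated.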

adrhill commented 1 week ago

> My suggestion: only define these overloads for the basic Array types (Vector / Matrix).

Arguably, a codebase that has use for SCT is written in a sparse manner and will rarely use the non-sparse Vector and Matrix types. Instead, it's more likely to perform scalar operations or use types from SparseArrays.

gdalle commented 1 week ago

> Arguably, a codebase that has use for SCT is written in a sparse manner and will rarely use the non-sparse Vector and Matrix types. Instead, it's more likely to perform scalar operations or use types from SparseArrays.

Not true. SCT materializes the sparsity pattern of the Jacobian, but inside the code, the Jacobian itself never needs to exist at all. Typically, the Brusselator gives rise to sparse Jacobians without ever creating a SparseMatrixCSC. The same goes for the Conv layer.
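To make this concrete, here is a toy sketch of operator-overloading sparsity detection on a Brusselator-like periodic stencil that only ever touches plain Vectors. DepTracer, stencil and jacobian_pattern are hypothetical names made up for illustration; this is not how SCT is implemented internally, only the general idea of propagating sets of input indices through scalar operations.

```julia
# Hypothetical toy tracer (NOT SCT's actual type): the set of input indices a value depends on.
struct DepTracer <: Real
    deps::Set{Int}
end

# Every binary operation merges dependencies; plain numbers contribute none.
for op in (:+, :-, :*)
    @eval Base.$op(a::DepTracer, b::DepTracer) = DepTracer(union(a.deps, b.deps))
    @eval Base.$op(a::DepTracer, b::Real) = DepTracer(a.deps)
    @eval Base.$op(a::Real, b::DepTracer) = DepTracer(b.deps)
end

# Brusselator-like periodic stencil written with scalar operations on a plain Vector:
# y[i] depends only on x[i-1], x[i] and x[i+1].
function stencil(x)
    n = length(x)
    return [x[mod1(i - 1, n)] - 2 * x[i] + x[mod1(i + 1, n)] for i in 1:n]
end

# Seed input j with the dependency set {j}, run f, and record each output's
# dependencies in a dense Matrix{Bool}; no sparse matrix type is involved.
function jacobian_pattern(f, n)
    x = [DepTracer(Set([j])) for j in 1:n]
    y = f(x)
    P = zeros(Bool, length(y), n)
    for (i, yi) in enumerate(y)
        for j in yi.deps
            P[i, j] = true
        end
    end
    return P
end

P = jacobian_pattern(stencil, 5)
```

The resulting pattern has three nonzeros per row (15 of 25 entries for n = 5), yet no sparse array type appears anywhere in stencil.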

gdalle commented 1 week ago

> As mentioned in, multi-argument functions on arrays are even more of a pain.

Yeah, we definitely don't wanna go down the ReverseDiff road of generating truckloads of methods for combinations of lists of types. It's extremely brittle and has even broken the package's tests for a good long while.

adrhill commented 1 week ago

> Typically, the Brusselator gives rise to sparse Jacobians without ever creating a SparseMatrixCSC.

Sure, but the Brusselator falls into my first category of functions:

> Instead, it's more likely to perform scalar operations

And I'm not convinced Conv layers are a common use case for sparsity detection. I put them in the README because they nicely demonstrate how generic our code is. In fact, I'm not sure the NNlib implementation of generic convolution uses Matrix multiplication either; I'm pretty sure that example predates #131. I think it also falls into the category of functions using scalar operations.

gdalle commented 1 week ago

Sorry, I had missed the "perform scalar operations" part.