adrhill / SparseConnectivityTracer.jl

Fast operator-overloading Jacobian & Hessian sparsity detection.
MIT License

Behavior of zeros #172

Closed SouthEndMusic closed 3 months ago

SouthEndMusic commented 3 months ago

This might be quite a general problem, but I'd like to hear your thoughts.

using SparseConnectivityTracer:
    TracerLocalSparsityDetector, jacobian_sparsity

N = 5

function f!(du, u)
    du .= sign.(u) .* u
end

u = rand(N)
du = zero(u)
jacobian_sparsity(f!, du, u, TracerLocalSparsityDetector())

In this example, the resulting Jacobian has no non-zeros purely because of the choice of u. Here it is quite easy to solve, but in applications it is non-trivial to construct a u that avoids this problem. How would you get around this problem?

SouthEndMusic commented 3 months ago

I now did this:

function Base.sign(x::D) where {P <: Real, T <: GradientTracer, D <: Dual{P, T}}
    p = primal(x)
    p = p > 0 ? one(p) : -one(p)
    Dual(p, tracer(x))
end
SouthEndMusic commented 3 months ago

With 3 simple custom overloads I made SparseConnectivityTracer work for my application 😃

adrhill commented 3 months ago

In this example, the resulting Jacobian has no non-zeros purely because of the choice of u.

This is the intended behavior of TracerLocalSparsityDetector. Quoting the README:

TracerSparsityDetector returns conservative sparsity patterns over the entire input domain of x. It is not compatible with functions that require information about the primal values of a computation (e.g. iszero, >, ==).

To compute a less conservative sparsity pattern at an input point x, use TracerLocalSparsityDetector instead. Note that patterns computed with TracerLocalSparsityDetector depend on the input x.

Multiplication by zero "removes" all derivative information, that's why your local Jacobian pattern ends up being zero.

Using TracerSparsityDetector will give you the right pattern:


julia> jacobian_sparsity(f!, du, u, TracerSparsityDetector())
5×5 SparseArrays.SparseMatrixCSC{Bool, Int64} with 5 stored entries:
 1  ⋅  ⋅  ⋅  ⋅
 ⋅  1  ⋅  ⋅  ⋅
 ⋅  ⋅  1  ⋅  ⋅
 ⋅  ⋅  ⋅  1  ⋅
 ⋅  ⋅  ⋅  ⋅  1
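The zero-multiplication effect can also be seen in isolation. Here is a minimal sketch (using the same API as above; the exact stored-entry counts are what I'd expect from the behavior described here, not verified output) contrasting the two detectors on a function that multiplies its input by a literal zero:

```julia
using SparseConnectivityTracer:
    TracerLocalSparsityDetector, TracerSparsityDetector, jacobian_sparsity

g(x) = [0.0 * x[1]]  # derivative is zero at every input point

# Local detector: multiplication by zero discards the derivative
# information, so the local pattern should have no stored entries.
jacobian_sparsity(g, [1.0], TracerLocalSparsityDetector())

# Global detector: conservative over the whole input domain, so it
# should keep the entry (a "false positive" is acceptable here).
jacobian_sparsity(g, [1.0], TracerSparsityDetector())
```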

Why are things this way?

I can imagine that some users might be wondering "Why can't we have a non-local Dual tracer that supports comparisons?". This is something I intend to document in depth (#143), as it is an intentional design choice. Let me motivate this decision with a simple example function:

x = [1, 2]

function foo(x)
    if x[1] > x[2]
        return x[1]
    else 
        return x[2]
    end
end

The desired global Jacobian sparsity pattern over the entire input domain $x \in \mathbb{R}^2$ is [1 1]. Two local sparsity patterns are possible: [1 0] for $\{x | x_1 > x_2\}$, [0 1] for $\{x | x_1 \le x_2\}$.

The local sparsity patterns are easy to compute using operator overloading by using dual numbers which contain primal values on which we can evaluate the comparison >.

The global sparsity pattern is impossible to compute when code branches on an if-else condition: we can only ever hit one branch! If we made `>` on tracers return true or false, we'd get the local patterns [1 0] and [0 1] respectively. But SCT's goal is to guarantee conservative sparsity patterns, which means that "false positives" (ones) are acceptable, but "false negatives" (zeros) are not. In our opinion, the right thing to do here is to throw an error.
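Concretely, this plays out as follows for the function above (a sketch; I've wrapped the return values in vectors so the output is a Jacobian row, and the expected patterns follow from the discussion rather than from verified output):

```julia
using SparseConnectivityTracer:
    TracerLocalSparsityDetector, TracerSparsityDetector, jacobian_sparsity

function foo(x)
    if x[1] > x[2]
        return [x[1]]
    else
        return [x[2]]
    end
end

# Local patterns depend on the input point:
jacobian_sparsity(foo, [2.0, 1.0], TracerLocalSparsityDetector())  # expect [1 0]
jacobian_sparsity(foo, [1.0, 2.0], TracerLocalSparsityDetector())  # expect [0 1]

# The global detector has no primal values to evaluate `>` on,
# so it should throw an error rather than silently pick a branch:
jacobian_sparsity(foo, [1.0, 2.0], TracerSparsityDetector())
```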

Since Dual tracers could hit branches, we can only ever guarantee that they return local sparsity patterns! This is the crux of the problem. But since they only return local patterns, we can at least be smart about it and return the sparsest local patterns possible, which is what we do in the case of multiplication. Multiplication with a Dual will evaluate whether the function has zero derivatives at the given input point:

https://github.com/adrhill/SparseConnectivityTracer.jl/blob/8251cb3f9558671fddea913b415f6becbf2bd9e0/src/operators.jl#L264-L266

And that's how you end up with a local Jacobian sparsity pattern of zeros.

adrhill commented 3 months ago

I now did this:

function Base.sign(x::D) where {P <: Real, T <: GradientTracer, D <: Dual{P, T}}
    p = primal(x)
    p = p > 0 ? one(p) : -one(p)
    Dual(p, tracer(x))
end

Unfortunately, this code is incorrect: the derivative of the sign function is globally zero. You therefore can't just return the original tracer(x).
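A variant consistent with the zero derivative would drop the tracer instead of passing it through. The following is a hypothetical sketch (not the package's actual overload; `Dual`, `GradientTracer`, and `primal` are used as in your snippet) that returns a plain number carrying no derivative information:

```julia
# d/dx sign(x) = 0 wherever it is defined, so a sign overload must not
# propagate the input's tracer. Returning the primal result alone
# (no tracer attached) encodes the zero derivative correctly:
function Base.sign(x::D) where {P <: Real, T <: GradientTracer, D <: Dual{P, T}}
    p = primal(x)
    return p > 0 ? one(p) : -one(p)
end
```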

adrhill commented 3 months ago

I'll close this issue since the output is the intended behavior. We're on the finishing straight of writing our SCT paper. Once this is done, I'll prioritize writing documentation (#143, #163).