SciML / ADTypes.jl

Repository for automatic differentiation backend types
https://sciml.github.io/ADTypes.jl/
MIT License

Change handling of sparse backends #38

Closed: gdalle closed this issue 3 months ago

gdalle commented 3 months ago

Problems:

Suggested solutions (instead of #37):

struct AutoSparse{B<:AbstractADType,S,C} <: AbstractADType
    backend::B               # underlying (dense) AD backend
    sparsity_detector::S     # strategy to detect the sparsity pattern
    coloring_algorithm::C    # strategy to color the pattern for compressed differentiation
end
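For illustration, construction would then look like this (a sketch: AutoForwardDiff already exists in ADTypes, while the detector and coloring types here are hypothetical placeholders):

sparse_backend = AutoSparse(
    AutoForwardDiff(),       # dense backend that actually computes derivatives
    MySparsityDetector(),    # placeholder: detects the Jacobian/Hessian pattern
    MyColoringAlgorithm(),   # placeholder: groups structurally orthogonal columns
)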

To preserve backward compatibility, we can try to define shortcuts like the following, to be removed in v0.3.0:

# "Backend" is a stand-in for any concrete backend name (ForwardDiff, Zygote, ...)
const AutoSparseBackend = AutoSparse{<:AutoBackend}

function AutoSparseBackend(args...; sparsity_detector, coloring_algorithm, kwargs...)
    # forward the old positional/keyword arguments to the dense constructor
    backend = AutoBackend(args...; kwargs...)
    return AutoSparse(backend, sparsity_detector, coloring_algorithm)
end
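For instance, instantiating that pattern for ForwardDiff would give (a sketch reusing the existing AutoForwardDiff type; AutoSparseForwardDiff is the name of the current sparse variant):

const AutoSparseForwardDiff = AutoSparse{<:AutoForwardDiff}

function AutoSparseForwardDiff(args...; sparsity_detector, coloring_algorithm, kwargs...)
    backend = AutoForwardDiff(args...; kwargs...)
    return AutoSparse(backend, sparsity_detector, coloring_algorithm)
end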

However I'm not 100% sure we can keep this from being breaking.

Vaibhavdixit02 commented 3 months ago

Definitely agree that the current AutoSparse* approach is pretty tedious and not modular.

I'd pick a different name, since AutoSparseBackend makes it look like a choice of AD backend, so maybe just Sparse(args...) seems nicer?

gdalle commented 3 months ago

I thought about that, but Sparse alone is a bit too generic, hence the Auto prefix. Besides, we would actually have

AutoSparse <: AbstractADType

so it seems coherent.

Vaibhavdixit02 commented 3 months ago

I see, that seems reasonable

gdalle commented 3 months ago

@avik-pal @ChrisRackauckas any thoughts? It's mostly SciML folks who are using this feature so far

Vaibhavdixit02 commented 3 months ago

One more option: AutoSparsifyBackend for the constructor? 😅

gdalle commented 3 months ago

just Sparse(args...) seems nicer?

The killer argument is that we will export that name, and we cannot decently export something named Sparse

One more option: AutoSparsifyBackend for the constructor? 😅

What do you mean?

ChrisRackauckas commented 3 months ago

AutoSparse <: AbstractADType is a very good middle ground IMO.

And there's an interconnected piece here. It would take quite a bit of work to make SparseDiffTools work with all AD backends in its current form. But DI could take over the front end of SparseDiffTools and redesign the colored-Jacobian functions to not be ForwardDiff-exclusive but instead call DI's high-level chunked modes; then every backend that supports those would be color-differentiation compatible. Ultimately that's the best solution here, and no one has had the time to pull it off, but @gdalle it seems like you have the time to finally put the polishing steps on this and see it through, and indeed that's the right thing to do here.
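To make that concrete, here is a minimal sketch of a backend-agnostic colored Jacobian (everything hypothetical except SparseArrays; jvp stands in for whatever pushforward primitive a backend exposes):

using SparseArrays

# One JVP per color: columns sharing a color are structurally orthogonal,
# so the compressed column sum can be scattered back without ambiguity.
function colored_jacobian(f, backend, x, pattern::SparseMatrixCSC, colors)
    J = similar(pattern, Float64)
    rows, cols, _ = findnz(pattern)
    for c in unique(colors)
        seed = float.(colors .== c)            # indicator of columns in color c
        compressed = jvp(f, backend, x, seed)  # hypothetical JVP primitive
        for (i, j) in zip(rows, cols)
            colors[j] == c && (J[i, j] = compressed[i])
        end
    end
    return J
end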

gdalle commented 3 months ago

My vision is that there are 3 phases to sparse AD, and they can (should) be separated:

1. Sparsity pattern detection: currently done by Symbolics, possibly 50x faster with SparseConnectivityTracer (see benchmarks at https://github.com/adrhill/SparseConnectivityTracer.jl/pull/4)
2. Coloring: currently done by SparseDiffTools
3. Jacobian evaluation: currently done by SparseDiffTools, should be done by DifferentiationInterface in chunked mode to be backend-agnostic

The first step is this ADTypes revamp.
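Wired together, the three phases would look roughly like this (hypothetical function names, one per phase, with colored_jacobian as sketched above):

pattern = detect_sparsity(f, x)   # phase 1: sparsity pattern detection
colors  = color_columns(pattern)  # phase 2: coloring
J = colored_jacobian(f, backend, x, pattern, colors)  # phase 3: compressed evaluation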

ChrisRackauckas commented 3 months ago

Yes, that vision is correct. We've just been strapped for hands, so... welcome aboard!

Sparsity pattern detection: currently done by Symbolics, possibly 50x faster with SparseConnectivityTracer (see benchmarks at https://github.com/adrhill/SparseConnectivityTracer.jl/pull/4)

Nice! Yes, this has been a bit of a pain for us. This should probably be in JuliaDiff?

Coloring: currently done by SparseDiffTools

Yup and that's ultimately all it should be.

Jacobian evaluation: currently done by SparseDiffTools, should be done by DifferentiationInterface in chunked mode to be backend-agnostic

Agreed.

gdalle commented 3 months ago

Ultimately everything should be in JuliaDiff, but @adrhill and I are iterating very fast so it's easier to keep it local. I fully intend to transfer DI to JuliaDiff eventually

avik-pal commented 3 months ago

Sparsity pattern detection: currently done by Symbolics, possibly 50x faster with SparseConnectivityTracer (see benchmarks at https://github.com/adrhill/SparseConnectivityTracer.jl/pull/4)

This is based on input propagation and won't work (it will give an incorrect pattern) for control flow or branching that affects the sparsity pattern, right? I am not familiar with how Symbolics works, but from Shashi's paper, I thought that was one of the challenges it addresses.
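A small illustration of the concern (hedged; not a claim about SparseConnectivityTracer's actual behavior):

f(x) = x[1] > 0 ? x[1] * x[2] : x[3]
# For x[1] > 0 the output depends on x[1] and x[2]; otherwise only on x[3].
# A value-free tracer cannot evaluate the comparison, so it must either
# take the union of both branches' patterns or raise an error.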

Regardless, this seems like a more principled way to at least replace the approximate sparsity detection methods in SparseDiffTools.

ChrisRackauckas commented 3 months ago

This is based on input propagation and won't work (it will give an incorrect pattern) for control flow or branching that affects the sparsity pattern, right? I am not familiar with how Symbolics works, but from Shashi's paper, I thought that was one of the challenges it addresses.

SparsityDetection.jl had extra machinery for union sparsity, but it suffered from the maintenance troubles of Cassette, so those features got dropped. The simplified version was then to trace with Symbolics, which became what we have today. A simpler tracer is a fine continuation of that evolution.

Yes, it is limited, but at least it's correct and it errors when tracing is not possible. Not ideal, but it tends to work in most use cases on model code, which is good enough.

gdalle commented 3 months ago

This is based on input propagation and won't work (it will give an incorrect pattern) for control flow or branching that affects the sparsity pattern, right?

Yes indeed, this assumes the same control flow across subsequent function executions. In that way it is consistent with DifferentiationInterface semantics (otherwise we couldn't, e.g., build a ReverseDiff tape during the preparation step).

that had maintenance troubles of Cassette and those got dropped.

It was kind of our original idea: redo the "Sparsity Programming" paper with a much more lightweight approach. Maybe not as powerful, but a hell of a lot faster.

adrhill commented 3 months ago

This is based on input propagation and won't work (it will give an incorrect pattern) for control flow or branching that affects the sparsity pattern, right?

We could in theory provide a second dual number tracer that uses primal values for control flow / branches.

Currently, SparseConnectivityTracer returns input-output connectivities, which are a conservative estimate of the Jacobian sparsity. We could be a bit smarter in our operator overloading for functions like round whose derivatives are zero.
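A toy sketch of that overloading idea (all types hypothetical): a tracer carries the set of input indices a value depends on, and zero-derivative functions like round can drop that set, tightening the pattern:

struct IndexTracer
    inputs::Set{Int}  # input indices this intermediate value depends on
end
Base.:*(a::IndexTracer, b::IndexTracer) = IndexTracer(union(a.inputs, b.inputs))
Base.round(a::IndexTracer) = IndexTracer(Set{Int}())  # zero derivative a.e. => no Jacobian entries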

gdalle commented 3 months ago

I don't think we can handle control flow; there is no way to do that with operator overloading alone.

gdalle commented 3 months ago

That's also why StochasticAD, for instance, has to rely on workarounds to support if/else:

https://gaurav-arya.github.io/StochasticAD.jl/dev/limitations.html

ChrisRackauckas commented 3 months ago

Yes, it requires global not local information, so it's simply not possible to do with operator overloading.