ericphanson / UnbalancedOptimalTransport.jl

Sinkhorn divergences for measures of unequal mass

Provide precomputed cost matrix #4

Closed baggepinnen closed 4 years ago

baggepinnen commented 4 years ago

Thanks for this great package :)

I was wondering if it would be possible to support a precomputed cost matrix instead of the cost function C. As far as I can tell, the call to C on the line tmp_f[i] = a.log_density[i] + (f[i] - C(x[i], y[j])) / ϵ recomputes the same value on every iteration, which is wasteful.
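To make the redundancy concrete, here is a minimal sketch of one dual update written both ways. It is a simplified, balanced log-domain Sinkhorn step, not the package's actual code: the function names, the logsumexp helper, and the argument layout are all assumptions made for illustration. Only the inner line mirrors the expression quoted above.

```julia
# Illustrative sketch only: names and signatures are hypothetical, and the
# update is a simplified balanced Sinkhorn step (no unbalanced correction).

# Numerically stable log-sum-exp.
logsumexp(v) = (m = maximum(v); m + log(sum(exp.(v .- m))))

# Cost function variant: C(x[i], y[j]) is re-evaluated on every call,
# i.e. on every Sinkhorn iteration.
function update_g_with_function!(g, f, log_a, x, y, C, ϵ)
    tmp = similar(f)
    for j in eachindex(g)
        for i in eachindex(f)
            tmp[i] = log_a[i] + (f[i] - C(x[i], y[j])) / ϵ
        end
        g[j] = -ϵ * logsumexp(tmp)
    end
    return g
end

# Precomputed variant: the cost is a plain array lookup.
function update_g_with_matrix!(g, f, log_a, Cmat, ϵ)
    tmp = similar(f)
    for j in eachindex(g)
        for i in eachindex(f)
            tmp[i] = log_a[i] + (f[i] - Cmat[i, j]) / ϵ
        end
        g[j] = -ϵ * logsumexp(tmp)
    end
    return g
end

x = [0.0, 1.0, 2.0]
y = [0.5, 1.5]
C(p, q) = abs2(p - q)
Cmat = [C(p, q) for p in x, q in y]   # computed once, reused every iteration

log_a = log.(fill(1 / 3, 3))
f = zeros(3)
g1 = update_g_with_function!(zeros(2), f, log_a, x, y, C, 0.1)
g2 = update_g_with_matrix!(zeros(2), f, log_a, Cmat, 0.1)
```

Both variants produce the same potentials; the second one just moves all evaluations of C out of the iteration loop.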

I would provide a PR, but this would change the interface to the function somewhat as this would not require supplying the support points of the measures, just the weights and the cost matrix.

EDIT: Maybe the user could be free to supply either a function C or a matrix, and whatever the user passes in would be sent through something like


# A precomputed cost matrix is used as-is.
handle_C(C::AbstractMatrix, a, b) = C

# A cost function is evaluated once on all pairs of support points.
function handle_C(C, a, b)
    x = a.set
    y = b.set
    (x === nothing || y === nothing) && throw(ArgumentError("If a cost function is provided, the support points of the measures are required."))
    return [C(xi, yj) for xi in x, yj in y]
end
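For illustration, the dispatch above can be exercised with a stand-in measure type. ToyMeasure below is hypothetical and only mimics the set field. The handle_C definitions are repeated so the snippet runs on its own, with the nothing check applied to both supports.

```julia
# Hypothetical stand-in for the package's measure type; only `set` matters here.
struct ToyMeasure{S}
    density::Vector{Float64}
    set::S   # support points, or `nothing` when only a cost matrix is used
end

# A precomputed cost matrix is passed through untouched.
handle_C(C::AbstractMatrix, a, b) = C

# A cost function is turned into a matrix once, up front.
function handle_C(C, a, b)
    x, y = a.set, b.set
    (x === nothing || y === nothing) && throw(ArgumentError(
        "If a cost function is provided, the support points of the measures are required."))
    return [C(xi, yj) for xi in x, yj in y]
end

a = ToyMeasure([0.5, 0.5], [0.0, 1.0])
b = ToyMeasure([1.0], [2.0])
cost(p, q) = abs(p - q)

M = handle_C(cost, a, b)      # function: builds a 2×1 cost matrix
same = handle_C(M, a, b)      # matrix: returned as-is, no copy

no_support = ToyMeasure([1.0], nothing)
# handle_C(cost, a, no_support) would throw an ArgumentError,
# while handle_C(M, a, no_support) works because the support is never read.
```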

and then the algorithms internally only make use of the precomputed matrix. This would be made even smoother if the DiscreteMeasure could accept nothing for the support points as those would not be used if the matrix is supplied.

I'll provide a PR for your consideration

ericphanson commented 4 years ago

Ah, good point about the inefficiency. Yeah, a precomputed matrix sounds like a good way to go since we need to compute all pairwise costs in general and if we want to avoid recomputing them then we need to store them somewhere anyway.

I don’t mind breaking changes at this stage; any improvements to the API are welcome. For example, we could make the cost a positional argument (e.g. after the divergence) and then add a method accepting a matrix (and support cost functions by creating a matrix and passing it on).

One caveat: although we don’t need the support if we have the cost matrix, the DiscreteMeasures also carry the dual potentials and caches so that the algorithm can be allocation-free. Many of the functions call the Sinkhorn algorithm to populate the dual potentials and then use them to calculate, e.g., the divergence. So we could still use the same DiscreteMeasures but let the set field be nothing, or we could have the user explicitly supply dual potential vectors to populate instead, etc.
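A rough sketch of the first option, a measure whose set field may be nothing but which still owns its dual potential and scratch buffer. SketchMeasure, its field names, and update_potential! are assumptions for illustration, not the package's definitions, and the update is a simplified balanced step.

```julia
# Hypothetical measure type: the support is optional, the buffers are not.
struct SketchMeasure{S}
    density::Vector{Float64}
    log_density::Vector{Float64}
    set::S                          # support points, or `nothing`
    dual_potential::Vector{Float64}
    cache::Vector{Float64}          # scratch space reused by every update
end

function SketchMeasure(density, set = nothing)
    d = collect(Float64, density)
    return SketchMeasure(d, log.(d), set, zeros(length(d)), zeros(length(d)))
end

# Writes b's dual potential in place, using a's cache as scratch space
# and reading costs from a precomputed matrix (rows index a, columns index b).
function update_potential!(b, a, Cmat, ϵ)
    for j in eachindex(b.dual_potential)
        for i in eachindex(a.cache)
            a.cache[i] = a.log_density[i] + (a.dual_potential[i] - Cmat[i, j]) / ϵ
        end
        m = maximum(a.cache)
        s = 0.0
        for v in a.cache
            s += exp(v - m)
        end
        b.dual_potential[j] = -ϵ * (m + log(s))
    end
    return b
end

a = SketchMeasure([0.5, 0.5])       # no support points supplied
b = SketchMeasure([1.0])
Cmat = fill(1.0, 2, 1)
update_potential!(b, a, Cmat, 1.0)
```

The struct itself is immutable, but its vector fields can be overwritten in place, so the caller can read b.dual_potential after the update without any support points ever being stored.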

I think the current API is not the most generic, though, because I create Vectors, whereas maybe the user wants CuArrays etc. However, it is at least simple. Open to any ideas :).
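One possible direction, sketched under assumptions: parameterize the measure on its storage type and allocate the buffers with similar, so they follow whatever array type the user passes in. GenericMeasure is a hypothetical name, and the example uses a Float32 Vector only; behavior with CuArrays is not tested here.

```julia
# Hypothetical generic measure: all buffers share the caller's array type V.
struct GenericMeasure{T, V<:AbstractVector{T}}
    density::V
    log_density::V
    dual_potential::V
    cache::V
end

function GenericMeasure(density::AbstractVector)
    d = float.(density)               # keeps Float32 as Float32
    log_d = log.(d)                   # broadcast result has the same array type here
    dual = fill!(similar(log_d), 0)   # `similar` allocates matching storage
    cache = similar(log_d)
    return GenericMeasure(d, log_d, dual, cache)
end

m = GenericMeasure(Float32[0.25, 0.75])
```

Here nothing in the constructor names Vector explicitly, so the element type (and, in principle, the array family) is decided by the input rather than by the package.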

PRs welcome, though I can try to get to it soon too.

ericphanson commented 4 years ago

Oops just saw your edit. Looks like we had somewhat the same idea for what to do! Looking forward to the PR.