JuliaOptimalTransport / OptimalTransport.jl

Optimal transport algorithms for Julia
MIT License

Question about optimal transport and possibilities of OptimalTransport.jl #157

Open ignace-computing opened 2 years ago

ignace-computing commented 2 years ago

Sorry if this is not the right place to ask this kind of question.

I am looking for an algorithm that has the following characteristics (see below). Is there any functionality in OptimalTransport.jl that could be used for this?

Many thanks!



(I can of course give more information about the problem if the description is not very clear to you.)

davibarreira commented 2 years ago

When you say histogram, do you mean a sum of Diracs?

davibarreira commented 2 years ago

Do you have a more formal (mathematical) description of your problem? I don't think there is an easy function that does what you are looking for. It looks more like an optimization problem, such as linear programming.

devmotion commented 2 years ago

The p-Wasserstein distance between two univariate histograms H and H' (i.e. sums of equally weighted Diracs) with the same number n of components satisfies W_p(H, H')^p = 1/n ||H - H'||_p^p (see e.g. eq. 2.33 in Computational Optimal Transport). Thus, based on the p-Wasserstein distance, it seems you could (try to) solve the optimization problem

min_{H'} 1/n ||H - H'||_p^p
s.t. 1/n \sum_{i=1}^n H'_i = mean
     1/n \sum_{i=1}^n (H'_i - mean)^2 = variance

Of course, the factor 1/n in the objective function could be ignored.

Edit: An important assumption that I forgot to mention is of course that both H and H' are sorted.
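To make the sorted-histogram formula concrete, here is a minimal Python sketch (illustrative only; the function name is hypothetical and not an OptimalTransport.jl API). It assumes both histograms have the same number of equally weighted atoms and computes the p-th power of the p-Wasserstein distance by pairing the i-th smallest atoms:

```python
def wasserstein_pp_1d(h, h_new, p=2.0):
    """p-th power of the p-Wasserstein distance between two uniform
    empirical distributions on the real line with the same number of atoms."""
    if len(h) != len(h_new):
        raise ValueError("both histograms must have the same number of atoms")
    n = len(h)
    # Sorting pairs the i-th smallest atom of h with the i-th smallest atom of h_new,
    # which is the optimal coupling in 1D.
    a = sorted(h)
    b = sorted(h_new)
    return sum(abs(ai - bi) ** p for ai, bi in zip(a, b)) / n


print(wasserstein_pp_1d([0.0, 1.0, 2.0], [1.0, 2.0, 3.0]))  # 1.0
```

If H and H' are already sorted, the two `sorted` calls are no-ops and the result is exactly 1/n ||H - H'||_p^p.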

devmotion commented 2 years ago

It seems this optimization problem can be solved analytically, at least for p = 2. Using Lagrange multipliers (better check this carefully, I just did a quick sketch), one obtains the candidate optima H' = m - (H - mean(H)) / std(H) * s and H' = m + (H - mean(H)) / std(H) * s, where m is the desired mean, s = sqrt(variance) is the desired standard deviation, and std uses the same 1/n normalization as the variance constraint (in Julia, std(H; corrected=false)). Clearly, both satisfy mean(H') = m and var(H') = s^2 = variance. However, since H is sorted in ascending order, only H' = m + (H - mean(H)) / std(H) * s is sorted in ascending order; the other candidate is sorted in reverse order. Hence H' = m + (H - mean(H)) / std(H) * s seems to be the unique solution of the optimization problem.
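A quick numerical check of the closed-form p = 2 candidate, written as a Python sketch (the function name is hypothetical). It assumes H is not constant and uses the population standard deviation `statistics.pstdev`, i.e. the 1/n normalization that matches the variance constraint:

```python
from statistics import fmean, pstdev


def match_mean_variance(h, target_mean, target_variance):
    """Affine map H' = m + (H - mean(H)) / std(H) * s with s = sqrt(variance).

    std(H) is the population (1/n) standard deviation, matching the
    1/n definition of the variance constraint."""
    mu = fmean(h)
    sigma = pstdev(h)
    if sigma == 0:
        raise ValueError("H must not be constant")
    s = target_variance ** 0.5
    # A positive scale factor s / sigma preserves the ascending order of H.
    return [target_mean + (hi - mu) / sigma * s for hi in h]


h = [1.0, 2.0, 4.0, 7.0]  # sorted in ascending order
h_new = match_mean_variance(h, target_mean=0.0, target_variance=1.0)
print(h_new)
```

The result has mean 0 and (1/n) variance 1 up to rounding, and it stays sorted because the scale factor is positive.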

ignace-computing commented 2 years ago

Dear @davibarreira and @devmotion, thank you for answering my question! And a happy new year, of course.

I think that what @devmotion suggests is perfectly correct, it looks like a good idea to try to minimize the p-Wasserstein distance, and your solution to the optimization problem also intuitively makes sense. However, it might not work in the setting I am envisaging. Therefore let me clarify:

We have a grid x, and on that grid we define a (discretized) density function u. More specifically, in each grid cell x_i, u_i is the average value of u (integrated over the grid cell). So let the pair (x, u) define a histogram H. The mean value of H is given by mu = u'*x and the variance equals var = u'*(x .- mu).^2.
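The two moment formulas can be checked with a small Python sketch, assuming u already holds the probability mass of each cell (non-negative, summing to 1); if u stores average densities instead, it would first need to be multiplied by the cell widths. The function name is hypothetical:

```python
def histogram_moments(x, u):
    """Mean and variance of a histogram with grid points x and cell masses u.

    Mirrors mu = u'*x and var = u'*(x .- mu).^2, assuming sum(u) == 1."""
    if len(x) != len(u):
        raise ValueError("x and u must have the same length")
    if abs(sum(u) - 1.0) > 1e-9:
        raise ValueError("cell masses u must sum to 1")
    mu = sum(ui * xi for xi, ui in zip(x, u))
    var = sum(ui * (xi - mu) ** 2 for xi, ui in zip(x, u))
    return mu, var


x = [0.0, 1.0, 2.0]
u = [0.25, 0.5, 0.25]
print(histogram_moments(x, u))  # (1.0, 0.5)
```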

Now we wish to obtain a new histogram H', thus a pair (x', u'), with the desired properties (namely, mean and variance (M, V)). It is also desired that x' = x (this makes it different from what you suggested).

ignace-computing commented 2 years ago

@davibarreira Is it now more clear what I meant with histogram? It's like, you specify a number of bins and then you evaluate a probability density function on that grid.

davibarreira commented 2 years ago

Yes, it's clear. But the package does not have a solution to your problem. Is your grid 1D? If yes, you could use the solution @devmotion proposed. Otherwise, I don't think there is an easy answer.

ignace-computing commented 2 years ago

Yeah, so for completeness, let me add the form of the optimization problem that I wish to solve. x and H are vectors of the same dimension carrying the grid and the probability density on it, respectively. The ' symbol denotes transpose (for clarity, it is never applied to H or H').

min_{H'} ||H - H'||_p^p
s.t. x'*H' = mean
     ((x .- mean).^2)'*H' = variance
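For p = 2, and if the non-negativity of H' is not enforced, this fixed-grid problem is a Euclidean projection of H onto an affine subspace, which has the closed form H' = H - A'λ with (A A')λ = A H - b, where the rows of A are x' and ((x .- mean).^2)'. Below is a minimal pure-Python sketch under exactly those assumptions; the helper names are hypothetical and not part of OptimalTransport.jl. Enforcing H' >= 0 would instead require a QP solver (or an LP for p = 1, as suggested earlier in the thread).

```python
def solve_linear(M, v):
    """Solve M z = v by Gaussian elimination with partial pivoting."""
    k = len(v)
    a = [row[:] + [v[i]] for i, row in enumerate(M)]  # augmented matrix [M | v]
    for col in range(k):
        pivot = max(range(col, k), key=lambda i: abs(a[i][col]))
        if abs(a[pivot][col]) < 1e-14:
            raise ValueError("constraints are linearly dependent")
        a[col], a[pivot] = a[pivot], a[col]
        for i in range(col + 1, k):
            f = a[i][col] / a[col][col]
            for j in range(col, k + 1):
                a[i][j] -= f * a[col][j]
    z = [0.0] * k
    for i in reversed(range(k)):
        z[i] = (a[i][k] - sum(a[i][j] * z[j] for j in range(i + 1, k))) / a[i][i]
    return z


def project_onto_constraints(h, rows, rhs):
    """Closest point (in the 2-norm) to h satisfying rows[i] . y == rhs[i] for all i.

    Non-negativity of y is NOT enforced."""
    k, n = len(rows), len(h)
    # Gram matrix G = A A' and residual r = A h - b
    G = [[sum(rows[i][j] * rows[l][j] for j in range(n)) for l in range(k)] for i in range(k)]
    r = [sum(rows[i][j] * h[j] for j in range(n)) - rhs[i] for i in range(k)]
    lam = solve_linear(G, r)
    # y = h - A' lam satisfies A y = b
    return [h[j] - sum(rows[i][j] * lam[i] for i in range(k)) for j in range(n)]


# Hypothetical example data: grid x, current histogram h, target moments.
x = [0.0, 1.0, 2.0, 3.0, 4.0]
h = [0.1, 0.2, 0.4, 0.2, 0.1]
target_mean, target_var = 2.2, 1.5
rows = [x, [(xi - target_mean) ** 2 for xi in x]]
rhs = [target_mean, target_var]
# A normalization constraint sum(H') = 1 could be added as a further row:
# rows.append([1.0] * len(x)); rhs.append(1.0)
h_new = project_onto_constraints(h, rows, rhs)
print(h_new)
```

Entries of `h_new` may come out negative, in which case the constrained (non-negative) problem genuinely needs an optimization solver.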