jeanfeydy / geomloss

Geometric loss functions between point clouds, images and volumes
MIT License

propagate gradient for weights #14

Open AdrienCorenflos opened 4 years ago

AdrienCorenflos commented 4 years ago

Hi,

This line detaches the gradient for the weights associated with the locations, for (I believe) no real reason: the weights serve the same purpose as the cost matrix in the fixed-point iteration.

It's probably worth detaching only a_x, a_y, b_x and b_y.

https://github.com/jeanfeydy/geomloss/blob/9188d051302cbc4a9a7c2224cfab8ed4a31b23b7/geomloss/sinkhorn_divergence.py#L219

jeanfeydy commented 4 years ago

Hi @AdrienCorenflos ,

Thanks for your interest! I may have made a mistake in my derivations, but this line is, I think, essential to get the correct gradient with respect to the x_i's. As discussed in this tutorial (+ more details in my PhD thesis, available in January), getting the correct expression for:

v(x_i) = - (1/α_i) ∂_{x_i} S(α,β) = - ∇F(x_i)

is essential for many (most) applications. Implemented this way, the "detach" trick allows us to get the correct expressions for all derivatives with respect to (α_i, x_i; β_j, y_j) without wasting time backpropagating through the loop, or writing an explicit derivative by hand. Note that I don't guarantee the correctness of order-2 gradients, though! (If it is of use to somebody, I could think about it... but I don't expect it to be a real use case for now.) What do you think?
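For intuition, here is a minimal, self-contained sketch of the "detach" trick on a toy 1D Sinkhorn loop. This is my own illustration under simplifying assumptions (tiny point clouds, a hand-rolled softmin, a quadratic cost), not the GeomLoss implementation:

```python
import math
import torch

def softmin(eps, C, h):
    # -eps * log sum_j exp( h_j - C_ij / eps ), reduced over the second axis
    return -eps * (h.view(1, -1) - C / eps).logsumexp(dim=1)

def sinkhorn_potentials(x, y, a_log, b_log, eps=0.1, iters=200):
    # Toy solver: C(x_i, y_j) = |x_i - y_j|^2 / 2
    C = 0.5 * (x.view(-1, 1) - y.view(1, -1)) ** 2
    f, g = torch.zeros_like(x), torch.zeros_like(y)
    for _ in range(iters):
        # The "detach" trick: the incoming potentials are cut from the graph,
        # so autograd never unrolls the loop. At convergence, the last update
        # alone yields the correct first-order gradients (envelope theorem).
        f = softmin(eps, C, b_log + g.detach() / eps)
        g = softmin(eps, C.t(), a_log + f.detach() / eps)
    return f, g

x = torch.linspace(-1.0, 1.0, 5, requires_grad=True)
y = torch.linspace(-2.0, 2.0, 5)
a_log = torch.full((5,), -math.log(5.0))  # log of uniform weights
b_log = a_log.clone()

f, g = sinkhorn_potentials(x, y, a_log, b_log)
ot = (a_log.exp() * f).sum() + (b_log.exp() * g).sum()  # OT_eps(α, β)
ot.backward()  # cheap: only one softmin layer per potential is traversed
print(x.grad.shape)  # torch.Size([5])
```

The key point is that the backward pass only traverses the final softmin updates, not the 200 iterations of the loop.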

Best regards,

Jean

AdrienCorenflos commented 4 years ago

Hi @jeanfeydy

Thanks for the quick response. Maybe I should have been a bit more specific :)

I am interested in the gradient of the coupling matrix w.r.t. the inputs, not necessarily (for this application) in the gradient of the loss function itself. This means that I'm interested in the gradient of the dual potentials w.r.t. w_x and X (see below).

I think the most straightforward way of getting them would be to move the .detach calls around, but I'm not completely sure. I've monkey-patched the code in sinkhorn_divergences with the change and it seems to do the trick, but I haven't verified the maths...

import geomloss
import torch

N = 1000

X = torch.rand((N, 1), requires_grad=True) * 4 - 2
Y = torch.rand((N, 1)) * 6 - 3

theta = torch.tensor(0.5, requires_grad=True)
w_x = torch.exp(-X.squeeze() ** 2 / theta)
w_y = torch.full_like(w_x, 1 / N)  # uniform weights

loss = geomloss.SamplesLoss("sinkhorn", p=2, blur=0.05, diameter=1.,
                            scaling=0.9, debias=False, potentials=True)

alpha, beta = loss(w_x, X, w_y, Y)

print(torch.autograd.grad(alpha.mean(), [w_x], retain_graph=True))  # raises: the gradient w.r.t. w_x is None
print(torch.autograd.grad(alpha.mean(), [X]))  # all good
jeanfeydy commented 4 years ago

Hi @AdrienCorenflos ,

I see. Indeed, what you're looking for here is precisely the second derivative of the Sinkhorn loss: the values of the potentials "f_i" and "g_j" (potentials = True) are none other than the derivatives of "OT_ε( α , β )" with respect to the weights α_i and β_j. Computing an explicit formula by hand would be doable, but not completely straightforward: you'd have to linearize the "optimality" equations satisfied by the dual potentials and solve for the variations (δf, δg) as functions of (δα, δx, δβ, δy). With the notations of the GeomLoss papers and code:

f_i = - ε log ∑_j β_j exp( (g_j - C(x_i,y_j)) / ε )   and   g_j = - ε log ∑_i α_i exp( (f_i - C(x_i,y_j)) / ε )

to be linearized w.r.t. f, g, α, β, x and y. This is doable, but requires more than one page of computations.
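As a rough sketch of the first step of that linearization (my own derivation, to be double-checked, written in the same notations as above): differentiating the first optimality equation gives

δf_i = - ∑_j π_ij ( δg_j + ε δβ_j / β_j - δC(x_i,y_j) ),   where   π_ij = β_j exp( (f_i + g_j - C(x_i,y_j)) / ε )

and ∑_j π_ij = 1 at optimality. The symmetric relation holds for δg_j, and solving the resulting linear system in (δf, δg) yields the derivatives of the potentials with respect to the inputs.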

Fortunately, thanks to the nature of the fixed point iterations, your detach implementation should indeed do a reasonable job - as long as we can assume that the algorithm has converged to a solution of the OT_ε problem. I would have thought that using:

a_y, b_x = λ * softmin(ε, C_yx, α_log + (b_x/ε).detach() ), \
           λ * softmin(ε, C_xy, β_log + (a_y/ε).detach() )

would be a good first guess, but I don't have much experience with this. Needless to say, you should check your derivatives numerically using some kind of gradcheck routine. Feel free to tell me what ends up working for you: right now, I don't have much time to do the derivations myself, but I will definitely try to work out a way of making these derivatives "work" without messing up the other use cases. You've convinced me that it could indeed be relevant to people in 2020 :-)
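As a minimal illustration of such a gradcheck routine (the wrapped function here is a hypothetical placeholder, not the GeomLoss loss):

```python
import torch

# torch.autograd.gradcheck compares autograd gradients against finite
# differences; it needs double precision for tight tolerances.
# `potential_mean` is a stand-in: in practice you would wrap the patched
# GeomLoss call instead, e.g. lambda X: loss(w_x, X, w_y, Y)[0].mean().
def potential_mean(x):
    return (x ** 3).sum()  # placeholder differentiable function

x = torch.randn(4, dtype=torch.float64, requires_grad=True)
ok = torch.autograd.gradcheck(potential_mean, (x,), eps=1e-6, atol=1e-4)
print(ok)  # True
```

gradcheck returns True on success and raises an informative error otherwise, so it slots easily into a unit test for the modified .detach placement.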

Best regards,

Jean