jeanfeydy / geomloss

Geometric loss functions between point clouds, images and volumes
MIT License

Evaluating the Brenier potential in a new point #50

Open fcocquemas opened 3 years ago


Hello,

And thank you for the awesome library!

I was wondering whether there is a way to evaluate the potential at a point that was not part of the input samples. For instance, adapting your example from here:

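```python
import torch
from geomloss import SamplesLoss

# Two point clouds in 3D, created directly on the GPU (x stays a leaf tensor)
x = torch.randn(100, 3, device="cuda", requires_grad=True)
y = torch.randn(100, 3, device="cuda")

# Plain (non-debiased) Sinkhorn solver that returns the dual potentials
OT_solver = SamplesLoss(loss="sinkhorn", p=2, blur=0.05,
                        debias=False, potentials=True)
F, G = OT_solver(x, y)  # Dual potentials: F is sampled on x, G on y
```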
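For context, a hedged sketch of what the returned potentials should satisfy. Assuming the balanced setting (`reach=None`), `debias=False`, and what I understand to be geomloss's conventions for `p = 2` (cost $C(x, y) = \tfrac{1}{2}\|x - y\|^2$, $\varepsilon = \text{blur}^2$, uniform weights $a_i = 1/N$, $b_j = 1/M$), the Sinkhorn potentials satisfy, up to the solver's convergence tolerance, the entropic fixed-point relations:

$$
F(x_i) \approx -\varepsilon \log \sum_{j=1}^{M} b_j \exp\!\left(\frac{G(y_j) - C(x_i, y_j)}{\varepsilon}\right),
\qquad
G(y_j) \approx -\varepsilon \log \sum_{i=1}^{N} a_i \exp\!\left(\frac{F(x_i) - C(x_i, y_j)}{\varepsilon}\right).
$$

The right-hand side of the first relation only involves $G$ and the $y_j$, so it is defined for any $x$, not just the samples $x_i$; this soft c-transform of $G$ is the usual way to extend an entropic potential off the sample points. For small blur, $\tfrac{1}{2}\|x\|^2 - F(x)$ then approximates the Brenier potential, and $x \mapsto x - \nabla F(x)$ approximates the transport map.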

For example, I would like to find the value of F at the point torch.tensor([0., 0., 0.]). I could certainly interpolate between the values of F at the points of x, but is there a more direct way to evaluate the potential itself? I could not find an example of this.
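One possible answer, sketched under stated assumptions: evaluate F at a new point through the soft c-transform of G, F(x) = −ε log Σ_j b_j exp((G_j − C(x, y_j))/ε), assuming p = 2, cost C(x, y) = ½‖x − y‖², ε = blur² and uniform weights b_j = 1/M. The helpers `extend_potential`, `transport_map` and `sinkhorn_potentials` below are hypothetical, not geomloss functions. The small log-domain Sinkhorn loop only stands in for `SamplesLoss` so the snippet runs on the CPU without geomloss, and the demo uses `blur=1.0` rather than 0.05 so that this plain loop converges quickly.

```python
import math

import torch


def half_sq_cost(x, y):
    # C(x, y) = |x - y|^2 / 2 for every pair, shape (len(x), len(y)).
    # Assumption: this is the cost geomloss uses for p = 2.
    return 0.5 * ((x[:, None, :] - y[None, :, :]) ** 2).sum(dim=-1)


def extend_potential(G, y, x_new, blur, b=None):
    """Hypothetical helper (not part of geomloss).

    Evaluates F at arbitrary points x_new via the soft c-transform of G:
        F(x) = -eps * log sum_j b_j * exp((G_j - C(x, y_j)) / eps),
    assuming p = 2, C = |x - y|^2 / 2 and eps = blur**2.
    """
    eps = blur ** 2
    if b is None:
        # Uniform weights on y, as when no weights are passed to the solver.
        b = torch.full_like(G, 1.0 / G.shape[0])
    logits = torch.log(b)[None, :] + (G[None, :] - half_sq_cost(x_new, y)) / eps
    return -eps * torch.logsumexp(logits, dim=1)


def transport_map(G, y, x_new, blur):
    # Hypothetical helper: approximate transport map x - grad F(x), via autograd.
    # For this cost it equals the softmax-weighted average of the y_j.
    x_new = x_new.detach().requires_grad_(True)
    F_new = extend_potential(G.detach(), y, x_new, blur)
    (grad_F,) = torch.autograd.grad(F_new.sum(), x_new)
    return (x_new - grad_F).detach()


def sinkhorn_potentials(x, y, blur, n_iter=1000):
    # Plain log-domain Sinkhorn with uniform weights, used only so this sketch
    # runs without geomloss (whose solver also uses epsilon-scaling).
    eps = blur ** 2
    C = half_sq_cost(x, y)
    n, m = x.shape[0], y.shape[0]
    log_a = torch.full((n,), -math.log(n), dtype=x.dtype, device=x.device)
    log_b = torch.full((m,), -math.log(m), dtype=y.dtype, device=y.device)
    F = torch.zeros(n, dtype=x.dtype, device=x.device)
    G = torch.zeros(m, dtype=y.dtype, device=y.device)
    for _ in range(n_iter):
        F = -eps * torch.logsumexp(log_b[None, :] + (G[None, :] - C) / eps, dim=1)
        G = -eps * torch.logsumexp(log_a[:, None] + (F[:, None] - C) / eps, dim=0)
    return F, G


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(100, 3, dtype=torch.float64)
    y = torch.randn(100, 3, dtype=torch.float64)
    F, G = sinkhorn_potentials(x, y, blur=1.0)

    # Off-sample evaluation at the origin, as asked above
    origin = torch.zeros(1, 3, dtype=torch.float64)
    print(extend_potential(G, y, origin, blur=1.0))
    print(transport_map(G, y, origin, blur=1.0))

    # On-sample consistency check: should be tiny once Sinkhorn has converged
    print((extend_potential(G, y, x, blur=1.0) - F).abs().max())
```

With geomloss itself, one would instead take `G` from `F, G = OT_solver(x, y)` and call something like `extend_potential(G, y, torch.zeros(1, 3, device=y.device), blur=0.05)`. Since the cost and ε conventions above are assumptions, a cheap sanity check is to confirm that `extend_potential(G, y, x, blur=0.05)` comes out close to the returned `F`.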

Thank you so much!