I was wondering if there is a way to evaluate the dual potential at a point that is not among the input samples. For instance, adapting your example from here:
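```python
import torch
from geomloss import SamplesLoss

# Create the tensors directly on the GPU so that x stays a leaf tensor
# (calling .cuda() after requires_grad=True returns a non-leaf copy).
x = torch.randn(100, 3, device="cuda", requires_grad=True)
y = torch.randn(100, 3, device="cuda")

OT_solver = SamplesLoss(loss="sinkhorn", p=2, blur=0.05,
                        debias=False, potentials=True)
F, G = OT_solver(x, y)  # Dual potentials, sampled on x and y
```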
I would like, e.g., to find the value of `F` at `torch.tensor([0., 0., 0.])`. I can certainly interpolate between the points of `x`, but is there a more direct way to evaluate the potential there? I could not find an example of it.
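For reference, here is a minimal sketch of one direct approach I am considering. In entropic OT, the potential F can be extended off the samples with the softmin ("c-transform") formula F(x) = -ε log Σ_j b_j exp((G_j - C(x, y_j)) / ε). The helper `extend_potential` is hypothetical and not part of GeomLoss. It assumes p = 2 (cost C(x, y) = |x - y|² / 2 and ε = blur²), uniform weights b_j = 1/M on `y` unless given, and `debias=False`, which I believe matches the settings above.

```python
import math
import torch

def extend_potential(x_new, y, G, blur=0.05, p=2, b=None):
    """Hypothetical helper (not a GeomLoss API): evaluate F at arbitrary points.

    Uses the softmin formula F(x) = -eps * log sum_j b_j exp((G_j - C(x, y_j)) / eps),
    assuming p = 2, cost C(x, y) = |x - y|^2 / 2 and eps = blur**p.
    """
    if p != 2:
        raise NotImplementedError("this sketch only covers p = 2")
    eps = blur ** p
    x_new = x_new.reshape(-1, y.shape[-1])  # accept a single point or an (N, D) batch
    M = y.shape[0]
    if b is None:
        # Assumed uniform weights 1/M on the samples y, stored as log-weights.
        b_log = torch.full((M,), -math.log(M), dtype=y.dtype, device=y.device)
    else:
        b_log = b.log()
    # Cost matrix C[i, j] = |x_i - y_j|^2 / 2, shape (N, M).
    C = ((x_new[:, None, :] - y[None, :, :]) ** 2).sum(-1) / 2
    # Log-sum-exp over j keeps the computation stable for small eps.
    return -eps * torch.logsumexp(b_log + G / eps - C / eps, dim=1)
```

Usage, continuing from the snippet above: `F0 = extend_potential(torch.zeros(3, device=y.device), y, G.detach(), blur=0.05)`. Evaluated at the points of `x` themselves, the helper should approximately reproduce `F`, which would be a quick way to check that the cost and ε conventions match the solver's.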
Thank you so much, and thanks for the awesome library!