jeanfeydy / geomloss

Geometric loss functions between point clouds, images and volumes
MIT License

ImagesLoss #10

Open · agostbiro opened this issue 4 years ago

agostbiro commented 4 years ago

Hi,

I'm very excited about this work! I was wondering if ImagesLoss with Sinkhorn divergence would be usable (when it's ready) as a reconstruction loss for an image autoencoder for representation learning?

Pseudocode to illustrate what I was thinking about:

encoder = ConvNet()
decoder = DeconvNet()
optimizer = Optimizer(chain(encoder.parameters(), decoder.parameters()))
loss_fn = ImagesLoss(loss="sinkhorn", p=2, blur=.05)

for images in dataset:
    optimizer.zero_grad()
    encoding = encoder(images)
    image_preds = decoder(encoding)
    loss = loss_fn(images, image_preds)
    loss.backward()
    optimizer.step()

Thanks, Agost

jeanfeydy commented 4 years ago

Hi @abiro ,

Sure! However, please keep in mind that OT theory is always concerned with measures, i.e. distributions of mass. In your example, loss_fn(images, image_preds) would return a vector of Wasserstein-like distances between each pair of original and encoded-decoded images, all of them being understood as density maps that represent measures on the rectangular image domain. (As a side note, before calling .backward(), you would have to add loss = loss.sum().) Consequently, ImagesLoss and VolumesLoss will be best suited to the processing of segmentation masks in medical imaging and computer vision.

Note that using OT on raw image data is asking for trouble: measure theory handles "black" and "white" as "absence" and "presence" of mass, which is rarely what researchers want to do with natural images. Fortunately, as discussed for instance in the color transfer tutorial, you can perfectly represent an image as a point cloud in some arbitrary feature space, and use SamplesLoss to define a relevant loss. In your example, a simple way of defining a transport-based loss function between two "natural" images is to represent them as point clouds (x, y, r, g, b) in R^5 and feed them to SamplesLoss. Of course, you could then tune the ratio between the "geometric" coordinates (x, y) and the color coordinates (r, g, b), or use a perceptual color space to improve your results.
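
For concreteness, here is a minimal sketch of such an encoding (the helper to_rgb_cloud, its scale_xy balance factor and the tensors img_a, img_b are illustrative names, not part of GeomLoss; with a single (N, D) tensor per measure, SamplesLoss assumes uniform weights):

import torch
from geomloss import SamplesLoss

def to_rgb_cloud(img, scale_xy=1.0):
    "img is an RGB image tensor of shape (H, W, 3) with values in [0, 1]."
    H, W, _ = img.shape
    # Pixel coordinates, mapped to the unit square:
    ys, xs = torch.meshgrid(torch.linspace(0, 1, H),
                            torch.linspace(0, 1, W), indexing="ij")
    xy = scale_xy * torch.stack((xs, ys), dim=-1).reshape(-1, 2)
    rgb = img.reshape(-1, 3)
    return torch.cat((xy, rgb), dim=1)  # Point cloud of shape (H*W, 5)

loss_fn = SamplesLoss(loss="sinkhorn", p=2, blur=.05)
loss = loss_fn(to_rgb_cloud(img_a), to_rgb_cloud(img_b))  # Uniform weights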

Best regards,

Jean

agostbiro commented 4 years ago

Thank you so much for your detailed answer and great work on this package!

Best, Agost

JunMa11 commented 4 years ago

Dear @jeanfeydy ,

Thanks for the excellent work.

OT theory is always concerned with measures, distributions of mass. Consequently, ImagesLoss and VolumesLoss will be best suited to the processing of segmentation masks in medical imaging and computer vision.

For the commonly used Dice loss in segmentation tasks, the inputs are the softmax outputs and the ground truth, e.g. loss_seg_dice = dice_loss(outputs_soft, ground_truth):

import torch

def dice_loss(score, target):
    """
    Soft Dice loss for binary segmentation. Typical usage:
    outputs = net(volume_batch)
    outputs_soft = F.softmax(outputs, dim=1)
    loss_seg_dice = dice_loss(outputs_soft[:, 1, :, :, :], label_batch == 1)
    """
    target = target.float()
    smooth = 1e-5  # Avoids division by zero on empty masks
    intersect = torch.sum(score * target)
    y_sum = torch.sum(target * target)
    z_sum = torch.sum(score * score)
    dice = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
    return 1 - dice

What are the inputs of ImagesLoss and VolumesLoss in a segmentation task? Are they the same as for dice_loss?

loss_fn = ImagesLoss(loss="sinkhorn", p=2, blur=.05)
loss = loss_fn(?, ?)

JunMa11 commented 4 years ago

Moreover, it seems that the Wasserstein distance has computational scalability problems.

How big an image can ImagesLoss and VolumesLoss handle in practice? (Typical 3D medical CT volumes have a size of (512, 512, 100+).)

[Screenshot of runtime benchmarks for Wasserstein distance solvers; source link in the original comment.]

jeanfeydy commented 4 years ago

Hi @JunMa11 ,

Indeed, that will be the interface of ImagesLoss and VolumesLoss: (soft or sharp) segmentation masks as input, scalar loss values as output. As for performance: there has been massive progress in the literature over the last two years, and GeomLoss is all about making it easily accessible to the community. Using nothing but the SamplesLoss layer on these images encoded as point clouds, you could already widely outperform the runtimes above. Simply turn your segmentation maps into (weights, positions) arrays with:

import numpy as np
import torch

def to_pointcloud(A, dtype=torch.cuda.FloatTensor):
    "A is a density map encoded as a NumPy array of shape (w,h)."

    A[A <= 0] = 1e-8  # We'd rather not compute log(0) later on...
    a_i = A.ravel() / A.sum()  # Normalized vector of weights

    # Let's map the image grid to the unit square, with one sample at the
    # center of each pixel ("ij" indexing keeps the positions aligned with
    # the weights a_i = A.ravel()):
    x, y = np.meshgrid( np.linspace(0, 1, A.shape[0] + 1)[:-1],
                        np.linspace(0, 1, A.shape[1] + 1)[:-1],
                        indexing="ij" )
    x += .5 / A.shape[0] ; y += .5 / A.shape[1]

    x_i = np.vstack( (x.ravel(), y.ravel()) ).T  # Point cloud of shape (w*h, 2)

    return torch.from_numpy(a_i).type(dtype), \
           torch.from_numpy(x_i).contiguous().type(dtype)

and you'll be good to go. (The key advantage of GeomLoss solvers over standard linear programming routines is that they fully leverage the structure of the cost matrix: when C[i,j] = (1/p) * |x_i - y_j|^p, the OT problem is much closer to a simple sorting problem than to a generic Kantorovitch-style assignment problem. Software packages such as Gurobi are designed for operations research with arbitrary cost matrices (a very hard problem), not geometry (which is generally much simpler).)
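
For instance, a minimal call sketch (A and B stand for two (w, h) NumPy masks; SamplesLoss accepts weighted measures as (weights, positions) pairs):

from geomloss import SamplesLoss

a_i, x_i = to_pointcloud(A)  # Weights and positions of the first mask
b_j, y_j = to_pointcloud(B)  # ... and of the second one

loss_fn = SamplesLoss("sinkhorn", p=2, blur=.01)
loss = loss_fn(a_i, x_i, b_j, y_j)  # Differentiable scalar value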

Going further, ImagesLoss and VolumesLoss will take advantage of the fact that when you work with measures sampled on a grid (instead of a generic point cloud), non-trivial, grid-specific schemes can bring massive speed-ups: think, for instance, of the Gaussian blur, which can be implemented with separable filters in O(N) instead of O(N^2), where N is the number of pixels.
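
As a minimal illustration of this separable trick (the helper below is just a sketch, with an arbitrary kernel-size rule):

import torch
import torch.nn.functional as F

def gaussian_blur_separable(img, sigma):
    "img is a (B, 1, H, W) tensor; two 1D convolutions replace one 2D filter."
    k = int(6 * sigma) | 1  # Odd kernel size, covering ~3*sigma on each side
    t = torch.arange(k, dtype=img.dtype, device=img.device) - k // 2
    g = torch.exp(-t ** 2 / (2 * sigma ** 2))
    g = g / g.sum()  # Normalized 1D Gaussian kernel
    img = F.conv2d(img, g.view(1, 1, 1, -1), padding=(0, k // 2))  # Blur rows
    img = F.conv2d(img, g.view(1, 1, -1, 1), padding=(k // 2, 0))  # Blur columns
    return img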

I've been fairly busy writing my PhD thesis over the last few months, but am now slowly starting to get back to work on GeomLoss. To get maximum performance with grid images, I am currently learning how to implement super-efficient (log-)convolutions with TVM... But that's fairly intricate, and the whole process will probably take me a few weeks/months :-)

Best regards, Jean

JunMa11 commented 4 years ago

Hi @jeanfeydy ,

Thank you very much for your quick and detailed reply.

I'm confused when comparing the above to_pointcloud function with the RGB_cloud function in the color transfer demo.

def RGB_cloud(fname, sampling=1, dtype=torch.FloatTensor) :
    A = load_image(fname)
    A = A[::sampling, ::sampling, :]
    return torch.from_numpy(A).type(dtype).view(-1,3)

If I input a common RGB 2D image of shape (256, 256, 3) or a 3D CT volume of shape (512, 512, 100), the expected point-cloud outputs are (x, y, r, g, b) and (x, y, z, CT_value) respectively, right?

However, for the function RGB_cloud, the output shape is (256*256, 3), while the output shape of to_pointcloud is (256*256, 2). I'm confused about the differences.

Specifically, I want to use the Wasserstein distance in the following two settings.

1. compute the Wasserstein distance between the label and softmax outputs

trainloader = DataLoader(data_train, batch_size, shuffle=True,  num_workers=4)
net = UNet(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=True)
net = net.cuda()
net.train()

loss_fn = SamplesLoss("sinkhorn", p=2, blur=0.01)
for epoch_num in range(max_epoch):
    for i_batch, sampled_batch in enumerate(trainloader):
        volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
        # outputs.shape = (b, c, x, y, z), label_batch.shape = (b, 1, x, y, z)
        outputs = net(volume_batch)
        outputs_soft = F.softmax(outputs, dim=1)
        # how should I pass the outputs_soft and label_batch into the loss_fn?
        loss = loss_fn(outputs_soft ?, label_batch?)

Question: how should I pass the outputs_soft and label_batch into the loss_fn?

outputs_soft.shape = (b, c, x, y, z); label_batch.shape = (b, 1, x, y, z), or (b, c, x, y, z) with one-hot encoding

2. compute the Wasserstein distance between two distance maps

The L1/L2 norms and the KL-divergence are commonly used in this setting. Now, I want to try the Wasserstein distance.

loss_fn = SamplesLoss("sinkhorn", p=2, blur=0.01)
for epoch_num in range(max_epoch):
    for i_batch, sampled_batch in enumerate(trainloader):
        volume_batch, label_distance_map = sampled_batch['image'], sampled_batch['label']
        # outputs.shape = (b, c, x, y, z), label_distance_map.shape = (b, c, x, y, z)
        outputs = net(volume_batch)
        # how should I pass the outputs and label_distance_map into the loss_fn?
        loss = loss_fn(outputs ?, label_distance_map?)

Question: how should I pass the outputs and label_distance_map into the loss_fn?

outputs.shape = label_distance_map.shape = (b, c, x, y, z)

Note: here, outputs are the raw logits rather than the softmax outputs, and label_distance_map is the distance transform of the ground truth, computed with scipy.

Looking forward to your reply at your convenience.

Best regards, Jun