masadcv / FastGeodis

Fast Implementation of Generalised Geodesic Distance Transform for CPU (OpenMP) and GPU (CUDA)
https://fastgeodis.readthedocs.io
BSD 3-Clause "New" or "Revised" License

[BUG] FastGeodis.generalised_geodesic2d is slower on CPU than sp.ndimage.distance_transform_cdt #57

Closed londumas closed 8 months ago

londumas commented 8 months ago

Thank you for the lib, the GPU acceleration saved my day!

When comparing the run time of FastGeodis.generalised_geodesic2d with the corresponding Euclidean sp.ndimage.distance_transform_cdt, the FastGeodis function is slower on CPU.

This might be expected, but it is a little counterintuitive, especially given the plots in the README. However, I understand that those plots are for a non-Euclidean space.

[Image: Capture]

import torch
import time
from functools import partial
import scipy as sp
import scipy.ndimage  # load the submodule explicitly so sp.ndimage resolves on older SciPy
import FastGeodis
import matplotlib.pyplot as plt
import numpy as np

device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

generalised_geodesic2d = partial(
    FastGeodis.generalised_geodesic2d, v=1.0e10, lamb=0.0, iter=1
)
distance_transform_cdt = sp.ndimage.distance_transform_cdt

nb_warmup = 5
nb_repeates = 10
sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]  # , 8194]
times = {
    "FastGeodis-generalised_geodesic2d-gpu": [],
    "FastGeodis-generalised_geodesic2d-cpu": [],
    "scipy-distance_transform_cdt-cpu": [],
}

# Compare values
for size in sizes:
    # Setup the image
    image = torch.ones((1, 1, size, size))
    image[0, 0, int(size // 2) : int(size // 2) + 1, :] = 0.0

    ## For GPU-FastGeodis
    image = image.to(device)
    with torch.no_grad():
        euclidean_dist = generalised_geodesic2d(
            image=image,
            softmask=image,
        )
    image = image[0, 0]
    image = image.detach().cpu().numpy()
    euclidean_dist = euclidean_dist.detach().cpu().numpy()[0, 0]

    # For scipy
    np_euclidean_dist = distance_transform_cdt(image)

    # Plot
    # plt.imshow(euclidean_dist)
    # plt.grid()
    # plt.show()
    # plt.imshow(np_euclidean_dist)
    # plt.grid()
    # plt.show()

    print(
        (euclidean_dist - np_euclidean_dist).min(),
        (euclidean_dist - np_euclidean_dist).max(),
    )

# Compare time
for size in sizes:
    # Setup the image
    image = torch.ones((1, 1, size, size))
    image[0, 0, int(size // 2) : int(size // 2) + 1, :] = 0.0

    ## For GPU-FastGeodis
    image = image.to(device)
    with torch.no_grad():
        for _ in range(nb_warmup):
            euclidean_dist = generalised_geodesic2d(
                image=image,
                softmask=image,
            )
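        if device == "cuda":
            torch.cuda.synchronize()  # finish warm-up kernels before starting the clock
        start = time.time()
        for _ in range(nb_repeates):
            euclidean_dist = generalised_geodesic2d(
                image=image,
                softmask=image,
            )
        if device == "cuda":
            torch.cuda.synchronize()  # CUDA launches are asynchronous; wait before stopping the clock
        end = time.time()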
    times["FastGeodis-generalised_geodesic2d-gpu"] += [end - start]

    ## For CPU-FastGeodis
    image = image.detach().cpu()
    with torch.no_grad():
        for _ in range(nb_warmup):
            euclidean_dist = generalised_geodesic2d(
                image=image,
                softmask=image,
            )
        start = time.time()
        for _ in range(nb_repeates):
            euclidean_dist = generalised_geodesic2d(
                image=image,
                softmask=image,
            )
        end = time.time()
    times["FastGeodis-generalised_geodesic2d-cpu"] += [end - start]

    # For scipy
    np_image = image.detach().cpu().numpy()[0, 0]
    for _ in range(nb_warmup):
        euclidean_dist = distance_transform_cdt(np_image)
    start = time.time()
    for _ in range(nb_repeates):
        euclidean_dist = distance_transform_cdt(np_image)
    end = time.time()
    times["scipy-distance_transform_cdt-cpu"] += [end - start]

sizes = np.array(sizes)
for k in times.keys():
    times[k] = np.array(times[k])
    print(k, times[k])

for k in times.keys():
    plt.plot(sizes, times[k], marker="o", label=rf"{k}")
plt.xscale("log")
plt.yscale("log")
plt.xlabel("Size")
plt.ylabel(f"Total time for {nb_repeates} runs [s]")
plt.legend()
plt.grid()
plt.show()

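# Label each ratio curve explicitly; the loop variable k would otherwise
# still hold the last key from the loop above for both curves.
plt.plot(
    sizes,
    times["FastGeodis-generalised_geodesic2d-gpu"] / times["scipy-distance_transform_cdt-cpu"],
    marker="o",
    label="FastGeodis-generalised_geodesic2d-gpu",
)
plt.plot(
    sizes,
    times["FastGeodis-generalised_geodesic2d-cpu"] / times["scipy-distance_transform_cdt-cpu"],
    marker="o",
    label="FastGeodis-generalised_geodesic2d-cpu",
)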
plt.xscale("log")
plt.yscale("log")
plt.xlabel("Size")
plt.ylabel("Time / distance_transform_cdt-cpu")
plt.legend()
plt.grid()
plt.show()
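
The warm-up-then-measure pattern repeated three times in the script above could be factored into a single helper. The following is only a sketch under stated assumptions: `benchmark` and its `sync` parameter are hypothetical names, not part of FastGeodis or SciPy. For a CUDA run, one would presumably pass `sync=torch.cuda.synchronize` so that queued kernels finish before the clock is read.

```python
import time


def benchmark(fn, nb_warmup=5, nb_repeats=10, sync=None):
    """Return the mean wall-clock time of one call to fn(), in seconds.

    `sync` is an optional zero-argument callable (hypothetical parameter)
    that blocks until any asynchronous work has finished.
    """
    for _ in range(nb_warmup):
        fn()
    if sync is not None:
        sync()  # drain warm-up work before starting the clock
    start = time.perf_counter()
    for _ in range(nb_repeats):
        fn()
    if sync is not None:
        sync()  # wait for the timed calls to finish
    return (time.perf_counter() - start) / nb_repeats


if __name__ == "__main__":
    # Stand-in workload so the sketch runs without torch or scipy.
    mean_s = benchmark(lambda: sum(range(10_000)))
    print(f"mean time per call: {mean_s:.6f} s")
```

Returning the mean per call, rather than the total over all repeats, keeps the numbers comparable if the repeat count changes.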
masadcv commented 8 months ago

Hi, thanks for raising this. There are different methods for implementing a distance transform. For the FastGeodis library, the primary focus has been to improve the runtime of geodesic distance transforms using a specific algorithm that can be parallelized on both CPU and GPU.

This library can also be used to compute Euclidean distance transforms. However, optimizing that case has not been our focus (optimized implementations already exist), so it is no surprise that existing methods are faster.

By the way, if you want to work with Euclidean distance transforms, it is recommended to use PixelQueue or Fast Marching methods, as the generalised_geodesic* functions may at best produce approximations of the Euclidean distance transform.
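
To illustrate why a scan-based method can only approximate the Euclidean distance transform, here is a minimal pure-Python sketch. It uses a generic two-pass chamfer scan (step costs 1 and sqrt(2)) chosen purely for illustration; it is not the FastGeodis algorithm or implementation, and the function names `chamfer_distance` and `exact_euclidean_distance` are made up for this example.

```python
import math


def chamfer_distance(mask):
    """Two-pass raster-scan chamfer distance to the nearest zero pixel.

    `mask` is a list of rows; 0 marks a seed pixel, anything else is background.
    Steps cost 1 (axis-aligned) or sqrt(2) (diagonal).
    """
    h, w = len(mask), len(mask[0])
    diag = math.sqrt(2.0)
    dist = [[0.0 if mask[y][x] == 0 else math.inf for x in range(w)] for y in range(h)]

    # Forward pass: pull distances from neighbours already visited (above and left).
    forward = [(-1, -1, diag), (-1, 0, 1.0), (-1, 1, diag), (0, -1, 1.0)]
    for y in range(h):
        for x in range(w):
            for dy, dx, cost in forward:
                ny, nx = y + dy, x + dx
                if 0 <= ny < h and 0 <= nx < w:
                    dist[y][x] = min(dist[y][x], dist[ny][nx] + cost)

    # Backward pass: pull distances from neighbours below and to the right.
    backward = [(1, 1, diag), (1, 0, 1.0), (1, -1, diag), (0, 1, 1.0)]
    for y in range(h - 1, -1, -1):
        for x in range(w - 1, -1, -1):
            for dy, dx, cost in backward:
                ny, nx = y + dy, x + dx
                if 0 <= ny < h and 0 <= nx < w:
                    dist[y][x] = min(dist[y][x], dist[ny][nx] + cost)
    return dist


def exact_euclidean_distance(mask):
    """Brute-force exact Euclidean distance to the nearest zero pixel."""
    h, w = len(mask), len(mask[0])
    seeds = [(y, x) for y in range(h) for x in range(w) if mask[y][x] == 0]
    return [
        [min(math.hypot(y - sy, x - sx) for sy, sx in seeds) for x in range(w)]
        for y in range(h)
    ]


if __name__ == "__main__":
    size = 9
    mask = [[1] * size for _ in range(size)]
    mask[size // 2][size // 2] = 0  # single seed in the centre
    approx = chamfer_distance(mask)
    exact = exact_euclidean_distance(mask)
    max_err = max(abs(a - e) for ra, re in zip(approx, exact) for a, e in zip(ra, re))
    print(f"largest chamfer-vs-Euclidean gap: {max_err:.4f}")
```

With a single seed pixel the two results differ at off-axis positions. With a full row of seeds, as in the test image used in this issue, the nearest seed always lies straight above or below, so the two agree exactly and the comparison cannot reveal the approximation.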

I am closing this issue because, as described above, optimizing Euclidean distance transforms is not relevant for this library.