mehta-lab / waveorder

Wave optical models and inverse algorithms for label-agnostic imaging of density & orientation.
BSD 3-Clause "New" or "Revised" License

Update background estimation to use torch #153

Closed ziw-liu closed 6 months ago

ziw-liu commented 6 months ago

Introduced a new module, waveorder.correction, to replace waveorder.background_estimator.

Consistency

The new method produces the same result:

import matplotlib.pyplot as plt
import torch

from waveorder.background_estimator import BackgroundEstimator2D
from waveorder.correction import estimate_background

# make example image
image = torch.zeros(360, 480)
image[:180, :] += 1
image[:, 120:360] += 1
image += torch.rand_like(image) * 2
plt.imshow(image)

[Figure: the example image]

f, ax = plt.subplots(2, 4, figsize=(12, 6))

for i in range(4):
    new_surface = estimate_background(image, order=i + 1, block_size=32)
    old_surface = BackgroundEstimator2D(block_size=32).get_background(
        image, order=i + 1, normalize=False
    )
    ax[0, i].imshow(new_surface)
    ax[0, i].set_title(f"torch, order={i+1}")
    ax[0, i].axis("off")
    ax[1, i].imshow(old_surface)
    ax[1, i].set_title(f"numpy, order={i+1}")
    ax[1, i].axis("off")

f.tight_layout()

[Figure: estimated background surfaces for orders 1 to 4; top row torch, bottom row numpy]
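The visual comparison could be backed by a numerical check. The following is a minimal sketch, not part of waveorder: `surfaces_match` is a hypothetical helper, and the tolerances are assumed values that would need tuning for real data. It accepts either torch tensors or NumPy arrays, so it should work whichever type the two estimators return.

```python
import numpy as np
import torch


def surfaces_match(new_surface, old_surface, rtol=1e-4, atol=1e-4):
    """Hypothetical helper: compare two background surfaces numerically.

    Returns (match, max_abs_diff). The tolerances are assumptions.
    """
    # Convert both inputs (torch tensor or NumPy array) to float64 on the CPU.
    new = torch.as_tensor(new_surface, dtype=torch.float64).cpu()
    old = torch.as_tensor(old_surface, dtype=torch.float64).cpu()
    if new.shape != old.shape:
        raise ValueError(f"shape mismatch: {new.shape} vs {old.shape}")
    max_abs_diff = (new - old).abs().max().item()
    return torch.allclose(new, old, rtol=rtol, atol=atol), max_abs_diff


# Synthetic stand-ins for the two surfaces (not real estimator output).
a = torch.linspace(0, 1, 12).reshape(3, 4)
b = a.numpy().copy() + 1e-6
match, diff = surfaces_match(a, b)
print(match, diff)
```

Inside the loop above, this could be called as `surfaces_match(new_surface, old_surface)` for each order.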

Speed

Test on a large image:

im = torch.rand(2048, 2048)
np_im = im.numpy()  # NumPy view of the same data
cuda_im = im.to("cuda")  # copy on the GPU

NumPy (AMD EPYC 7302P CPU):

BackgroundEstimator2D(block_size=32).get_background(np_im, order=2, normalize=False)
# 293 ms ± 844 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

The PyTorch implementation is about 4x faster on CPU:

estimate_background(im, order=2, block_size=32)
# 68.6 ms ± 4.08 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

An NVIDIA A40 GPU gives a further 8x speedup, which is about 35x faster than NumPy:

estimate_background(cuda_im, order=2, block_size=32)
# 8.31 ms ± 110 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
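
The timings above appear to be in IPython `%timeit` format. To reproduce that kind of measurement in a plain script, a sketch using only the standard library might look like the following. `bench` and `speedup` are hypothetical helper names, not part of waveorder. The last lines recompute the speedup ratios from the numbers reported above.

```python
import statistics
import timeit


def bench(fn, repeat=7, number=10, sync=None):
    """Return (mean, std) of the per-call time in milliseconds.

    Mirrors the "7 runs, 10 loops each" layout. `sync` is an optional
    callable run after every call, for work that executes asynchronously.
    """
    def call():
        fn()
        if sync is not None:
            sync()

    call()  # warm-up call, not timed
    totals = timeit.repeat(call, repeat=repeat, number=number)
    per_call_ms = [t / number * 1e3 for t in totals]
    return statistics.mean(per_call_ms), statistics.stdev(per_call_ms)


def speedup(baseline_ms, new_ms):
    """Ratio of two timings given in the same unit."""
    return baseline_ms / new_ms


# Ratios implied by the reported timings (293 ms, 68.6 ms, 8.31 ms).
print(f"torch CPU vs NumPy: {speedup(293, 68.6):.1f}x")   # 4.3x
print(f"GPU vs torch CPU:   {speedup(68.6, 8.31):.1f}x")  # 8.3x
print(f"GPU vs NumPy:       {speedup(293, 8.31):.1f}x")   # 35.3x
```

CUDA kernels launch asynchronously, so when timing a CUDA tensor by hand it is safer to pass `sync=torch.cuda.synchronize`, for example `bench(lambda: estimate_background(cuda_im, order=2, block_size=32), sync=torch.cuda.synchronize)`.
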
ziw-liu commented 6 months ago

> I see you dropped the normalize option (which was set to True by default), but it seems unused. Any idea what this option was doing?

It rescales the output to unit mean. Everywhere it is used, it is set to False.
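
If that behavior is ever needed again, it can be applied to the result afterwards. The following is a minimal sketch based only on the description above (rescaling to unit mean), not on the removed waveorder code. `normalize_to_unit_mean` is a hypothetical name.

```python
import torch


def normalize_to_unit_mean(surface):
    """Hypothetical helper: rescale a surface so that its mean is 1."""
    return surface / surface.mean()


# Stand-in surface with mean 3.0 (not real estimator output).
surface = torch.tensor([[1.0, 2.0], [3.0, 6.0]])
normalized = normalize_to_unit_mean(surface)
print(normalized.mean())
```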