mehta-lab / waveorder

Wave optical models and inverse algorithms for label-agnostic imaging of density & orientation.
BSD 3-Clause "New" or "Revised" License

Device agnostic compute for polarization #150

Closed (ziw-liu closed 9 months ago)

ziw-liu commented 1 year ago

Fixed the stokes module and its tests to work on both CPU and GPU.

Caveat: the background estimation in waveorder.models.inplane_oriented_thick_pol3D.apply_inverse_transfer_function ~is still NumPy code~ does not work with the MPS backend. See #153.

Tested on cuda (NVIDIA A40, CUDA 12.2; AMD EPYC 7302P, Linux 4.18.0) and mps (Apple M1 Pro, macOS 13.5.1), both with the native PyTorch builds from PyPI.

tests/test_stokes.py::test_S2I_matrix PASSED                                                                    [  5%]
tests/test_stokes.py::test_I2S_matrix PASSED                                                                    [ 11%]
tests/test_stokes.py::test_s12_to_orientation[cpu] PASSED                                                       [ 17%]
tests/test_stokes.py::test_s12_to_orientation[cuda] PASSED                                                      [ 23%]
tests/test_stokes.py::test_stokes_recon[cpu] PASSED                                                             [ 29%]
tests/test_stokes.py::test_stokes_recon[cuda] PASSED                                                            [ 35%]
tests/test_stokes.py::test_stokes_after_adr_usage PASSED                                                        [ 41%]
tests/test_stokes.py::test_mueller_from_stokes PASSED                                                           [ 47%]
tests/test_stokes.py::test_mmul[cpu] PASSED                                                                     [ 52%]
tests/test_stokes.py::test_mmul[cuda] PASSED                                                                    [ 58%]
tests/test_stokes.py::test_copying[cpu] PASSED                                                                  [ 64%]
tests/test_stokes.py::test_copying[cuda] PASSED                                                                 [ 70%]
tests/test_stokes.py::test_orientation_offset[cpu] PASSED                                                       [ 76%]
tests/test_stokes.py::test_orientation_offset[cuda] PASSED                                                      [ 82%]
tests/models/test_inplane_oriented_thick_pol3D.py::test_calculate_transfer_function PASSED                      [ 88%]
tests/models/test_inplane_oriented_thick_pol3D.py::test_apply_inverse_transfer_function[cpu] PASSED             [ 94%]
tests/models/test_inplane_oriented_thick_pol3D.py::test_apply_inverse_transfer_function[cuda] PASSED            [100%]
tests/test_stokes.py::test_S2I_matrix PASSED                                                                    [  5%]
tests/test_stokes.py::test_I2S_matrix PASSED                                                                    [ 11%]
tests/test_stokes.py::test_s12_to_orientation[cpu] PASSED                                                       [ 17%]
tests/test_stokes.py::test_s12_to_orientation[mps] PASSED                                                       [ 23%]
tests/test_stokes.py::test_stokes_recon[cpu] PASSED                                                             [ 29%]
tests/test_stokes.py::test_stokes_recon[mps] PASSED                                                             [ 35%]
tests/test_stokes.py::test_stokes_after_adr_usage PASSED                                                        [ 41%]
tests/test_stokes.py::test_mueller_from_stokes PASSED                                                           [ 47%]
tests/test_stokes.py::test_mmul[cpu] PASSED                                                                     [ 52%]
tests/test_stokes.py::test_mmul[mps] PASSED                                                                     [ 58%]
tests/test_stokes.py::test_copying[cpu] PASSED                                                                  [ 64%]
tests/test_stokes.py::test_copying[mps] PASSED                                                                  [ 70%]
tests/test_stokes.py::test_orientation_offset[cpu] PASSED                                                       [ 76%]
tests/test_stokes.py::test_orientation_offset[mps] PASSED                                                       [ 82%]
tests/models/test_inplane_oriented_thick_pol3D.py::test_calculate_transfer_function PASSED                      [ 88%]
tests/models/test_inplane_oriented_thick_pol3D.py::test_apply_inverse_transfer_function[cpu] PASSED             [ 94%]
tests/models/test_inplane_oriented_thick_pol3D.py::test_apply_inverse_transfer_function[mps] PASSED             [100%]
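
The `[cpu]`, `[cuda]`, and `[mps]` suffixes in the logs above are parametrized test IDs. A minimal sketch of how such a device list could be built follows. `available_devices` is a hypothetical helper, not necessarily what the waveorder test suite uses, and the import guard is only there so the snippet also runs where PyTorch is not installed.

```python
import importlib.util


def available_devices():
    """List the PyTorch device types usable on this machine (hypothetical helper)."""
    devices = ["cpu"]
    if importlib.util.find_spec("torch") is None:
        # PyTorch is not installed: only report the CPU.
        return devices
    import torch

    # ROCm builds of PyTorch also expose AMD GPUs under the "cuda" device type.
    if torch.cuda.is_available():
        devices.append("cuda")
    if torch.backends.mps.is_available():
        devices.append("mps")
    return devices


print(available_devices())
```

With pytest, the list can be passed to `@pytest.mark.parametrize("device", available_devices())`, so each test runs once per device and gets an ID such as `test_mmul[cuda]`.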
ziw-liu commented 1 year ago

A step towards #144.

ziw-liu commented 1 year ago

Also tested the cuda backend on an AMD GPU (RX6800XT, ROCm 5.6.1, Linux 6.5.3), although we probably won't officially support it. This requires installing a ROCm build of PyTorch before installing waveorder:

pip install torch --index-url https://download.pytorch.org/whl/rocm5.4.2
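
To confirm which backend an installed PyTorch build targets, the version metadata can be inspected. This is only a sketch: `describe_torch_build` is a hypothetical helper, and it reads `torch.version.cuda` and `torch.version.hip` through `getattr` in case a build lacks either attribute.

```python
import importlib.util


def describe_torch_build():
    """Report which GPU toolkit the installed PyTorch build targets (hypothetical helper)."""
    if importlib.util.find_spec("torch") is None:
        return {"installed": False}
    import torch

    return {
        "installed": True,
        "version": torch.__version__,
        # Set on CUDA builds, None otherwise.
        "cuda": getattr(torch.version, "cuda", None),
        # Set on ROCm builds, None otherwise.
        "hip": getattr(torch.version, "hip", None),
        # True when a GPU is usable through the "cuda" device type
        # (NVIDIA with CUDA builds, AMD with ROCm builds).
        "gpu_available": torch.cuda.is_available(),
    }


print(describe_torch_build())
```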
ziw-liu commented 9 months ago

Speed comparison:

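import torch
from waveorder.models import inplane_oriented_thick_pol3d

def test_apply_inverse_transfer_function(device):
    """Run one inverse transfer function call on the given device, for timing."""
    # (C, Z, Y, X): 5 channels for the 5-State scheme, 100 slices of 2048 x 2048
    input_shape = (5, 100, 2048, 2048)
    # Random input created directly on the target device (part of the timed call)
    czyx_data = torch.rand(input_shape, device=device)

    # Move the intensity-to-Stokes matrix to the same device as the data
    intensity_to_stokes_matrix = (
        inplane_oriented_thick_pol3d.calculate_transfer_function(
            swing=0.1,
            scheme="5-State",
        ).to(device)
    )

    # Discard the result: only the run time matters here
    _ = inplane_oriented_thick_pol3d.apply_inverse_transfer_function(
        czyx_data=czyx_data,
        intensity_to_stokes_matrix=intensity_to_stokes_matrix,
    )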
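
For scale, the size of the benchmark input can be worked out with a quick calculation. The sketch below assumes PyTorch's default `float32` dtype for `torch.rand` (4 bytes per element); the variable names are illustrative.

```python
import math

input_shape = (5, 100, 2048, 2048)  # (C, Z, Y, X), as in the benchmark
bytes_per_element = 4  # assumption: default float32 dtype

n_elements = math.prod(input_shape)
n_bytes = n_elements * bytes_per_element

print(f"{n_elements:,} elements")    # 2,097,152,000 elements
print(f"{n_bytes / 2**30:.4f} GiB")  # 7.8125 GiB
```

The input alone therefore occupies about 7.8 GiB, and intermediate results add to this, so devices with less memory may need a smaller `input_shape`.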

AMD EPYC 7302P CPU:

test_apply_inverse_transfer_function("cpu")
# 11.6 s ± 20.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

A single NVIDIA A40 GPU is about 60x faster (11.6 s vs. 193 ms), which is also faster than a typical camera can acquire the same 500 frames:

test_apply_inverse_transfer_function("cuda")
# 193 ms ± 25.6 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
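
CUDA and MPS operations are queued asynchronously, so GPU timings are most reliable when the device is synchronized before the clock stops. The thread does not show whether the measurements above did this, so the following is only a sketch of one way to do it; `benchmark` and its parameters are hypothetical names.

```python
import statistics
import time


def benchmark(fn, sync=None, repeats=7, warmup=1):
    """Time fn() over several runs, optionally synchronizing the device.

    Returns (mean, standard deviation) in seconds. `repeats` must be >= 2.
    """
    # Warm-up runs absorb one-time costs such as kernel compilation.
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()

    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # wait for queued device work before stopping the clock
        times.append(time.perf_counter() - start)
    return statistics.mean(times), statistics.stdev(times)


# CPU-only usage example with a trivial workload
mean_s, std_s = benchmark(lambda: sum(range(10_000)))
print(f"{mean_s * 1e3:.3f} ms ± {std_s * 1e3:.3f} ms")
```

On a GPU, the call could look like `benchmark(lambda: test_apply_inverse_transfer_function("cuda"), sync=torch.cuda.synchronize)`; on Apple silicon, `torch.mps.synchronize` plays the same role.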