kaanaksit / odak

Scientific computing library for optics, computer graphics and visual perception.
https://kaanaksit.com/odak
Mozilla Public License 2.0

Update function "custom" in classical.py #101

Closed by WeijieXie 4 months ago.

WeijieXie commented 4 months ago

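Remove the `torch.fft.fftshift` and `torch.fft.ifftshift` calls applied in the spatial domain (the ones wrapped directly around the field before `fft2` and around the result after `ifft2`), which may help accelerate the program. The shifts in the frequency domain are kept.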
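The change is exact for any number of samples, not only for the sizes tested. A short derivation: write $S$ for `fftshift` (a circular shift by $s = \lfloor N/2 \rfloor$ along each axis), $S^{-1}$ for `ifftshift`, $\mathcal{F}$ for `fft2`, $M = H \odot A$ for the element-wise product of `H` and `aperture`, and use the DFT shift theorem $\mathcal{F}\{S f\} = \phi \odot \mathcal{F} f$ with $\phi[k] = e^{-2\pi i k s / N}$ per axis (in 2D, $\phi$ is the product of the two per-axis factors):

$$
\begin{aligned}
r_1 &= S^{-1}\,\mathcal{F}^{-1}\,S^{-1}\!\left[M \odot S\,\mathcal{F}\,S f\right] \\
    &= S^{-1}\,\mathcal{F}^{-1}\,S^{-1}\!\left[M \odot S(\phi \odot \mathcal{F} f)\right] \\
    &= S^{-1}\,\mathcal{F}^{-1}\!\left[\phi \odot S^{-1}(M \odot S\,\mathcal{F} f)\right] \\
    &= S^{-1} S\,\mathcal{F}^{-1} S^{-1}(M \odot S\,\mathcal{F} f) \\
    &= \mathcal{F}^{-1} S^{-1}(M \odot S\,\mathcal{F} f) = r_2 .
\end{aligned}
$$

The second line applies the shift theorem; the third uses the fact that a circular shift of an element-wise product equals the product of the shifted factors; the fourth applies the shift theorem in reverse, $\mathcal{F}^{-1}(\phi \odot G) = S\,\mathcal{F}^{-1} G$. The leftover $S^{-1} S$ cancels for odd and even $N$ alike.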

Validated by the following two tests.

When the number of sampling points is odd:

```python
import torch

# odd
field_padded = torch.rand(3, 3)
H = torch.rand(3, 3)
aperture = torch.rand(3, 3)

# original code
U1_1 = torch.fft.fftshift(torch.fft.fft2(torch.fft.fftshift(field_padded)))
U2_1 = H * aperture * U1_1
result_1 = torch.fft.ifftshift(torch.fft.ifft2(torch.fft.ifftshift(U2_1)))

# remove the fftshift and ifftshift in the spatial domain
U1_2 = torch.fft.fftshift(torch.fft.fft2(field_padded))
U2_2 = H * aperture * U1_2
result_2 = torch.fft.ifft2(torch.fft.ifftshift(U2_2))

# the spectra differ element-wise by a unit-magnitude phase factor,
# but the final results should match
print(U1_1)
print(U1_2)

assert torch.allclose(result_1, result_2)
```
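The odd case is the more telling one: `fftshift` circularly shifts by `N // 2` and `ifftshift` by `-(N // 2)`, so the two coincide for even `N` but differ by one sample for odd `N`. A minimal sketch of this, using NumPy's `numpy.fft` as a stand-in on the assumption that it follows the same shift convention as `torch.fft`:

```python
import numpy as np

x = np.arange(5)  # odd length
print(np.fft.fftshift(x))   # roll by 5 // 2 = 2   -> [3 4 0 1 2]
print(np.fft.ifftshift(x))  # roll by -(5 // 2)    -> [2 3 4 0 1]

y = np.arange(4)  # even length: both shifts are a roll by 2
print(np.fft.fftshift(y), np.fft.ifftshift(y))  # -> [2 3 0 1] [2 3 0 1]

# ifftshift is the exact inverse of fftshift for either parity
assert np.array_equal(np.fft.ifftshift(np.fft.fftshift(x)), x)
assert np.array_equal(np.fft.ifftshift(np.fft.fftshift(y)), y)
```

Swapping `fftshift` and `ifftshift` by mistake would therefore go unnoticed in an even-sized test but not in an odd-sized one, while the cancellation in this change only relies on `ifftshift` undoing `fftshift`.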


When the number of sampling points is even:

```python
import torch

# even
field_padded = torch.rand(4, 4)
H = torch.rand(4, 4)
aperture = torch.rand(4, 4)

# original code
U1_1 = torch.fft.fftshift(torch.fft.fft2(torch.fft.fftshift(field_padded)))
U2_1 = H * aperture * U1_1
result_1 = torch.fft.ifftshift(torch.fft.ifft2(torch.fft.ifftshift(U2_1)))

# remove the fftshift and ifftshift in the spatial domain
U1_2 = torch.fft.fftshift(torch.fft.fft2(field_padded))
U2_2 = H * aperture * U1_2
result_2 = torch.fft.ifft2(torch.fft.ifftshift(U2_2))

# the spectra differ element-wise by a unit-magnitude phase factor,
# but the final results should match
print(U1_1)
print(U1_2)

assert torch.allclose(result_1, result_2)
```
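The two tests above use square arrays and a single random draw each. A slightly broader sketch repeats the comparison over several odd, even and rectangular shapes in double precision. It uses NumPy's `numpy.fft` as a stand-in for `torch.fft` (assumed to share the same shift convention), and the helper names `original`, `modified` and `shifts_cancel` are hypothetical:

```python
import numpy as np


def original(field, H, aperture):
    # Original path: shifts in both the spatial and the frequency domain.
    U1 = np.fft.fftshift(np.fft.fft2(np.fft.fftshift(field)))
    U2 = H * aperture * U1
    return np.fft.ifftshift(np.fft.ifft2(np.fft.ifftshift(U2)))


def modified(field, H, aperture):
    # Modified path: frequency-domain shifts only.
    U1 = np.fft.fftshift(np.fft.fft2(field))
    U2 = H * aperture * U1
    return np.fft.ifft2(np.fft.ifftshift(U2))


def shifts_cancel(shape, seed=0):
    # Draw random float64 inputs of the given shape and compare both paths.
    rng = np.random.default_rng(seed)
    field, H, aperture = (rng.random(shape) for _ in range(3))
    return np.allclose(original(field, H, aperture), modified(field, H, aperture))


for shape in [(3, 3), (4, 4), (5, 8), (7, 6), (64, 33)]:
    print(shape, shifts_cancel(shape))
```

Working in float64 keeps round-off far below the default `np.allclose` tolerances, so a failure here would point to a genuine mismatch rather than float32 noise.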


kaanaksit commented 4 months ago

Thank you, @WeijieXie, for this suggestion. Is there any chance you could give me an idea of the speed-up in this case? Perhaps you could try both versions with various matrix sizes (e.g., 1920 by 1080) over a number of trials (e.g., 100) to get a sense of it.

WeijieXie commented 4 months ago

Thanks for reviewing, Kaan.

```python
import time
import torch

def cpu_timer(operation, repeat=100):
    # time.perf_counter() is in seconds, so this average is in seconds
    total_time = 0
    for _ in range(repeat):
        start_time = time.perf_counter()
        operation()
        end_time = time.perf_counter()
        total_time += end_time - start_time
    return total_time / repeat

def gpu_timer(operation, repeat=100):
    # CUDA events report elapsed time in milliseconds
    total_time = 0
    for _ in range(repeat):
        torch.cuda.synchronize()
        start_time = torch.cuda.Event(enable_timing=True)
        end_time = torch.cuda.Event(enable_timing=True)
        start_time.record()
        operation()
        end_time.record()
        torch.cuda.synchronize()
        total_time += start_time.elapsed_time(end_time)
    return total_time / repeat

def original_code(field_padded, H, aperture):
    U1 = torch.fft.fftshift(torch.fft.fft2(torch.fft.fftshift(field_padded)))
    U2 = H * aperture * U1
    return torch.fft.ifftshift(torch.fft.ifft2(torch.fft.ifftshift(U2)))

def modified_code(field_padded, H, aperture):
    U1 = torch.fft.fftshift(torch.fft.fft2(field_padded))
    U2 = H * aperture * U1
    return torch.fft.ifft2(torch.fft.ifftshift(U2))

# comparison on CPU
field_padded = torch.rand(1920, 1080)
H = torch.rand(1920, 1080)
aperture = torch.rand(1920, 1080)
original_time = cpu_timer(lambda: original_code(field_padded, H, aperture))
modified_time = cpu_timer(lambda: modified_code(field_padded, H, aperture))
print(f'Original code on CPU: {original_time:.8f} s')
print(f'Modified code on CPU: {modified_time:.8f} s')

# comparison on GPU
field_padded = field_padded.to('cuda')
H = H.to('cuda')
aperture = aperture.to('cuda')
original_time = gpu_timer(lambda: original_code(field_padded, H, aperture))
modified_time = gpu_timer(lambda: modified_code(field_padded, H, aperture))
print(f'Original code on GPU: {original_time:.8f} ms')
print(f'Modified code on GPU: {modified_time:.8f} ms')
```

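On an Intel Core i7-9750H CPU and a GeForce GTX 1650 GPU, the results are as follows. Note that `cpu_timer` measures with `time.perf_counter()`, which returns seconds, so the two CPU figures are in seconds even though this run printed them with an `ms` label:

```bash
Original code on CPU: 0.03148508 ms
Modified code on CPU: 0.02354449 ms
Original code on GPU: 7.29974783 ms
Modified code on GPU: 4.84190688 ms
```

The modified code seems considerably faster on both the CPU and the GPU.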
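One caveat on this measurement: both timers start from the very first call, so one-off costs (for example memory allocation or FFT plan setup) can land on `original_code`, which runs first in each comparison. A minimal sketch of a timer with untimed warm-up calls follows. `benchmark_ms` is a hypothetical helper, NumPy stands in for PyTorch so the snippet runs without a GPU, and for CUDA tensors one would pass `sync=torch.cuda.synchronize`:

```python
import time

import numpy as np


def benchmark_ms(operation, repeat=100, warmup=5, sync=lambda: None):
    """Return the average time of one operation() call, in milliseconds.

    `sync` must block until queued device work is done: pass
    torch.cuda.synchronize for CUDA tensors, keep the no-op on the CPU.
    """
    for _ in range(warmup):
        operation()  # untimed calls absorb one-off setup costs
    sync()
    start = time.perf_counter()
    for _ in range(repeat):
        operation()
    sync()  # include all queued work before stopping the clock
    return (time.perf_counter() - start) / repeat * 1e3  # seconds -> ms


rng = np.random.default_rng(0)
field, H, aperture = (rng.random((1920, 1080)) for _ in range(3))


def original():
    U1 = np.fft.fftshift(np.fft.fft2(np.fft.fftshift(field)))
    return np.fft.ifftshift(np.fft.ifft2(np.fft.ifftshift(H * aperture * U1)))


def modified():
    U1 = np.fft.fftshift(np.fft.fft2(field))
    return np.fft.ifft2(np.fft.ifftshift(H * aperture * U1))


print(f"original: {benchmark_ms(original, repeat=10, warmup=2):.2f} ms")
print(f"modified: {benchmark_ms(modified, repeat=10, warmup=2):.2f} ms")
```

Timing the whole loop and synchronizing only at its ends measures throughput and avoids one synchronization per call; giving each variant its own warm-up makes the comparison less sensitive to which one runs first.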

kaanaksit commented 4 months ago

Thank you, @WeijieXie! I have incorporated your changes into the repository. For now, users have to install odak from the repository to benefit from them, but they will be reflected in the pip version once I release odak==0.2.6.

I have also added your name to our THANKS.txt and CITATION.cff; please give them a quick visual check. If you notice any missing information, please do not hesitate to let me know. I haven't added an ORCID iD for you in CITATION.cff; if you have one, please let me know so that I can update it accordingly.