AlecThomson / RACS-tools

Useful scripts for RACS
BSD 3-Clause "New" or "Revised" License

Migrate from Fortran #40

Closed: AlecThomson closed this issue 11 months ago

AlecThomson commented 1 year ago

Building the Fortran code seems to be a common source of installation woes for users (including myself!). @wasimraja81, do you think it's possible to migrate the Fortran to a Pythonic version? I'm happy to sketch some options out here, and then we can test a possible migration.

If we can't make it fast enough with a NumPy implementation, we could consider a Numba-compiled version.
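A hedged sketch of what the Numba route could look like: an explicit-loop kernel, which is the style Numba compiles most directly and which avoids allocating the full-size temporary arrays a broadcasting NumPy version creates. It assumes the same argument order and maths as gaussft_vec in the implementation below; the names gaussft_loops and FWHM_TO_SIGMA and the fallback decorator are hypothetical, added only for illustration.

```python
import math

import numpy as np

try:
    import numba as nb

    jit = nb.njit
except ImportError:  # hypothetical fallback: run as plain Python without Numba

    def jit(func):
        return func


# Gaussian FWHM = 2 * sqrt(2 ln 2) * sigma
FWHM_TO_SIGMA = 1.0 / (2.0 * math.sqrt(2.0 * math.log(2.0)))


@jit
def gaussft_loops(bmin_in, bmaj_in, bpa_in, bmin, bmaj, bpa, u, v):
    """Hypothetical loop-based kernel returning (g_final, g_ratio)."""
    deg2rad = math.pi / 180.0
    # Beam FWHMs in degrees -> Gaussian sigmas in radians
    sx, sy = bmaj * deg2rad * FWHM_TO_SIGMA, bmin * deg2rad * FWHM_TO_SIGMA
    sx_in = bmaj_in * deg2rad * FWHM_TO_SIGMA
    sy_in = bmin_in * deg2rad * FWHM_TO_SIGMA
    c, s = math.cos(bpa * deg2rad), math.sin(bpa * deg2rad)
    c_in, s_in = math.cos(bpa_in * deg2rad), math.sin(bpa_in * deg2rad)
    # Same value as sqrt(2 pi sx sy) / sqrt(2 pi sx_in sy_in)
    g_ratio = math.sqrt(sx * sy / (sx_in * sy_in))
    g_final = np.empty((u.shape[0], v.shape[0]))
    for i in range(u.shape[0]):
        for j in range(v.shape[0]):
            # Rotate (u, v) by each beam's position angle
            ur, vr = u[i] * c - v[j] * s, u[i] * s + v[j] * c
            ur_in, vr_in = u[i] * c_in - v[j] * s_in, u[i] * s_in + v[j] * c_in
            g_arg = -2.0 * math.pi**2 * ((sx * ur) ** 2 + (sy * vr) ** 2)
            dg_arg = -2.0 * math.pi**2 * ((sx_in * ur_in) ** 2 + (sy_in * vr_in) ** 2)
            g_final[i, j] = g_ratio * math.exp(g_arg - dg_arg)
    return g_final, g_ratio
```

Because each element is computed independently, the outer loop could also be tried with Numba's parallel=True and prange, though whether that pays off at these grid sizes would need measuring.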

AlecThomson commented 11 months ago

@wasimraja81 here's my attempt:

Implementation:

from typing import Tuple

import numba as nb
import numpy as np
from tqdm import trange

import racs_tools.gaussft as gaussft  # existing compiled Fortran extension, used here only for comparison

def gaussft_vec(
    bmin_in: float,
    bmaj_in: float,
    bpa_in: float,
    bmin: float,
    bmaj: float,
    bpa: float,
    u: np.ndarray,
    v: np.ndarray,
) -> Tuple[np.ndarray, float]:
    """
    Compute the Fourier transform of a 2D Gaussian for convolution.

    Parameters:
        bmin_in (float): Intrinsic PSF BMIN (degrees)
        bmaj_in (float): Intrinsic PSF BMAJ (degrees)
        bpa_in (float): Intrinsic PSF BPA (degrees)
        bmin (float): Final PSF BMIN (degrees)
        bmaj (float): Final PSF BMAJ (degrees)
        bpa (float): Final PSF BPA (degrees)
        u (np.ndarray): Fourier coordinates corresponding to image coord x
        v (np.ndarray): Fourier coordinates corresponding to image coord y

    Returns:
        g_final (np.ndarray): Array of shape (len(u), len(v)) to multiply with FT(image) for convolution in the Fourier domain.
        g_ratio (float): Factor for flux scaling
    """
    deg2rad = np.pi / 180.0

    bmaj_in_rad, bmin_in_rad, bpa_in_rad = (
        bmaj_in * deg2rad,
        bmin_in * deg2rad,
        bpa_in * deg2rad,
    )
    bmaj_rad, bmin_rad, bpa_rad = bmaj * deg2rad, bmin * deg2rad, bpa * deg2rad

    # Convert FWHM to Gaussian sigma: FWHM = 2 * sqrt(2 ln 2) * sigma
    fwhm_to_sigma = 1.0 / (2.0 * np.sqrt(2.0 * np.log(2.0)))
    sx, sy = bmaj_rad * fwhm_to_sigma, bmin_rad * fwhm_to_sigma
    sx_in, sy_in = bmaj_in_rad * fwhm_to_sigma, bmin_in_rad * fwhm_to_sigma

    u_cosbpa, u_sinbpa = u * np.cos(bpa_rad), u * np.sin(bpa_rad)
    u_cosbpa_in, u_sinbpa_in = u * np.cos(bpa_in_rad), u * np.sin(bpa_in_rad)

    v_cosbpa, v_sinbpa = v * np.cos(bpa_rad), v * np.sin(bpa_rad)
    v_cosbpa_in, v_sinbpa_in = v * np.cos(bpa_in_rad), v * np.sin(bpa_in_rad)

    g_amp = np.sqrt(2.0 * np.pi * sx * sy)

    dg_amp = np.sqrt(2.0 * np.pi * sx_in * sy_in)

    g_ratio = g_amp / dg_amp

    # Rotate (u, v) by the position angle; broadcasting gives arrays of shape (len(u), len(v))
    ur = u_cosbpa[:, np.newaxis] - v_sinbpa[np.newaxis, :]
    vr = u_sinbpa[:, np.newaxis] + v_cosbpa[np.newaxis, :]
    g_arg = -2.0 * np.pi**2 * ((sx * ur) ** 2 + (sy * vr) ** 2)

    ur_in = u_cosbpa_in[:, np.newaxis] - v_sinbpa_in[np.newaxis, :]
    vr_in = u_sinbpa_in[:, np.newaxis] + v_cosbpa_in[np.newaxis, :]
    dg_arg = -2.0 * np.pi**2 * ((sx_in * ur_in) ** 2 + (sy_in * vr_in) ** 2)

    # Vectorized calculation of g_final
    g_final = g_ratio * np.exp(g_arg - dg_arg)

    return g_final, g_ratio

# Numba JIT-compiles the same function on first call; cache=True stores the compiled code on disk
gaussft_numba = nb.njit(gaussft_vec, cache=True)
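For reference, here is a restatement of what gaussft_vec computes, read directly off the code above. The notation is introduced here for illustration: θ_maj and θ_min are BMAJ and BMIN converted to radians, φ is BPA in radians, a(u, v) corresponds to g_arg, and a_in(u, v) (dg_arg) is the same expression evaluated with the intrinsic beam (σ_x,in, σ_y,in, φ_in).

```latex
\begin{aligned}
\sigma_x &= \frac{\theta_\mathrm{maj}}{2\sqrt{2\ln 2}}, \qquad
\sigma_y = \frac{\theta_\mathrm{min}}{2\sqrt{2\ln 2}} \\
u_r &= u\cos\phi - v\sin\phi, \qquad v_r = u\sin\phi + v\cos\phi \\
a(u, v) &= -2\pi^2\left(\sigma_x^2 u_r^2 + \sigma_y^2 v_r^2\right) \\
g_\mathrm{ratio} &= \sqrt{\frac{2\pi\,\sigma_x\sigma_y}{2\pi\,\sigma_{x,\mathrm{in}}\sigma_{y,\mathrm{in}}}}
  = \sqrt{\frac{\sigma_x\sigma_y}{\sigma_{x,\mathrm{in}}\sigma_{y,\mathrm{in}}}} \\
g_\mathrm{final}(u, v) &= g_\mathrm{ratio}\,\exp\left[a(u, v) - a_\mathrm{in}(u, v)\right]
\end{aligned}
```

Two consequences follow directly: g_final(0, 0) = g_ratio, and identical input and final beams give g_final = 1 everywhere.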

Testing:


ntest = 1000
bmajs = np.random.uniform(11, 50, ntest)
bmins = np.random.uniform(1, 10, ntest)
bpas = np.random.uniform(-90, 90, ntest)
u = np.arange(10)
v = np.arange(10)
for i in trange(ntest):
    # Argument order follows the signature: (bmin_in, bmaj_in, bpa_in, bmin, bmaj, bpa, u, v)
    args = (
        bmins[i],
        bmajs[i],
        bpas[i],
        bmins[i] + 5,
        bmajs[i] + 5,
        bpas[i],
        u,
        v,
    )
    g_final, g_ratio = gaussft.gaussft(*args)
    g_final_vec, g_ratio_vec = gaussft_vec(*args)
    g_final_numba, g_ratio_numba = gaussft_numba(*args)
    assert np.isclose(g_final, g_final_vec).all()
    assert np.isclose(g_ratio, g_ratio_vec).all()
    assert np.isclose(g_final, g_final_numba).all()
    assert np.isclose(g_ratio, g_ratio_numba).all()
# passes
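Once the Fortran module is removed, the equivalence test above can no longer run. A hedged sketch of analytic checks that need no compiled code: an identical input and final beam should give an identity filter, and the filter at u = v = 0 should equal the flux ratio. gaussft_ref is a hypothetical compact NumPy restatement of gaussft_vec (same argument order); with gaussft_vec or gaussft_numba in scope, the same tests apply after swapping the function name.

```python
import numpy as np

FWHM_TO_SIGMA = 1.0 / (2.0 * np.sqrt(2.0 * np.log(2.0)))


def gaussft_ref(bmin_in, bmaj_in, bpa_in, bmin, bmaj, bpa, u, v):
    """Hypothetical compact restatement of gaussft_vec (same argument order)."""
    sx, sy = np.deg2rad([bmaj, bmin]) * FWHM_TO_SIGMA
    sx_in, sy_in = np.deg2rad([bmaj_in, bmin_in]) * FWHM_TO_SIGMA
    uu, vv = np.meshgrid(u, v, indexing="ij")  # shape (len(u), len(v))

    def arg(s_maj, s_min, pa_deg):
        # Rotate (u, v) by the position angle, then evaluate the Gaussian exponent
        pa = np.deg2rad(pa_deg)
        ur = uu * np.cos(pa) - vv * np.sin(pa)
        vr = uu * np.sin(pa) + vv * np.cos(pa)
        return -2.0 * np.pi**2 * ((s_maj * ur) ** 2 + (s_min * vr) ** 2)

    g_ratio = np.sqrt(sx * sy / (sx_in * sy_in))
    g_final = g_ratio * np.exp(arg(sx, sy, bpa) - arg(sx_in, sy_in, bpa_in))
    return g_final, g_ratio


def test_identical_beams_is_identity():
    u = v = np.linspace(-10.0, 10.0, 21)
    g_final, g_ratio = gaussft_ref(5.0, 12.0, 30.0, 5.0, 12.0, 30.0, u, v)
    np.testing.assert_allclose(g_final, 1.0)
    np.testing.assert_allclose(g_ratio, 1.0)


def test_origin_equals_flux_ratio():
    rng = np.random.default_rng(42)  # seeded so failures are reproducible
    u = v = np.linspace(-10.0, 10.0, 21)  # index 10 is u = v = 0
    for _ in range(100):
        bmaj_in, bmin_in = rng.uniform(11, 50), rng.uniform(1, 10)
        bpa = rng.uniform(-90, 90)
        g_final, g_ratio = gaussft_ref(
            bmin_in, bmaj_in, bpa, bmin_in + 5, bmaj_in + 5, bpa, u, v
        )
        np.testing.assert_allclose(g_final[10, 10], g_ratio)
        # A wider final beam at the same PA tapers the filter, never boosting it
        assert np.all(g_final <= g_ratio)
```

Using np.testing.assert_allclose rather than assert np.isclose(...).all() also makes pytest report the mismatching elements when a check fails.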

Speed tests (reusing the final args from the test loop, i.e. a 10 × 10 (u, v) grid):

Fortran (current):

%%timeit
g_final, g_ratio = gaussft.gaussft(*args)
# 1.95 µs ± 9.94 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

Pure Python/numpy:

%%timeit
g_final_vec, g_ratio_vec = gaussft_vec(*args)
# 34.5 µs ± 452 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

Numba JIT:

%%timeit
g_final_numba, g_ratio_numba = gaussft_numba(*args)
# 3.14 µs ± 23.7 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
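The timings above use the 10 × 10 grid from the test loop, where fixed per-call overhead probably dominates the NumPy version's time, so the ranking may change on image-sized grids. A hedged sketch of a plain-Python harness (no IPython needed) that makes one untimed warm-up call per size, so one-off costs such as Numba's JIT compilation are excluded, and then reports the best time per call. bench, make_args and stand_in are hypothetical names; stand_in is only a placeholder with the same call signature.

```python
import timeit

import numpy as np


def bench(func, make_args, sizes=(10, 256, 1024), repeat=5, number=10):
    """Return {n: best seconds per call} for func(*make_args(n))."""
    results = {}
    for n in sizes:
        args = make_args(n)
        func(*args)  # untimed warm-up: triggers Numba compilation for these arg types
        times = timeit.repeat(lambda: func(*args), repeat=repeat, number=number)
        results[n] = min(times) / number  # best run is least affected by noise
    return results


def make_args(n):
    """Hypothetical argument builder: fixed beams on an n x n (u, v) grid."""
    u = np.arange(n, dtype=float)
    v = np.arange(n, dtype=float)
    return (5.0, 12.0, 30.0, 10.0, 17.0, 30.0, u, v)


def stand_in(bmin_in, bmaj_in, bpa_in, bmin, bmaj, bpa, u, v):
    """Placeholder with the same signature; swap in gaussft_vec, gaussft_numba, etc."""
    return np.exp(-np.add.outer(u**2, v**2) * 1e-6), 1.0


if __name__ == "__main__":
    for n, t in bench(stand_in, make_args).items():
        print(f"n={n:5d}: {t * 1e6:10.2f} µs per call")
```

With the definitions from the implementation in scope, calling bench(gaussft_vec, make_args) and bench(gaussft_numba, make_args) gives directly comparable numbers across grid sizes.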
AlecThomson commented 11 months ago

I have a working branch in #50.