aperiosoftware / astropy

Repository for the Astropy core package
www.astropy.org
BSD 3-Clause "New" or "Revised" License

Implement parallel fitting #61

Closed astrofrog closed 3 months ago

astrofrog commented 3 months ago

This PR currently implements a single function for fitting models to N-dimensional cubes. The internal code is still very much a prototype and lacks a number of features, but in principle it should already be possible to pass in an N-dimensional cube and fit an M-dimensional model along a subset of the axes: for example, a 1D model to each spectrum in a spectral cube, or 2D Gaussians to the celestial axes.
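To make the "subset of the axes" idea concrete, here is a minimal NumPy-only sketch of how a cube can be split into independent fitting problems. The helper `split_into_fit_problems` is hypothetical and is not part of this PR or of astropy; it only illustrates the bookkeeping, under the assumption that each slice along the non-fitting axes is fitted independently.

```python
import numpy as np


def split_into_fit_problems(data, fitting_axes):
    """Return an array of shape (n_problems, *fit_shape).

    Hypothetical helper for illustration only: it separates the axes that
    are fitted from the axes that are iterated over.
    """
    # Accept a single axis or a sequence, and normalise negative indices
    axes = tuple(int(a) % data.ndim for a in np.atleast_1d(fitting_axes))
    n_fit = len(axes)

    # Move the fitting axes to the end, keeping their relative order
    dest = tuple(range(data.ndim - n_fit, data.ndim))
    moved = np.moveaxis(data, axes, dest)

    # Flatten all remaining axes into one "problem" axis
    fit_shape = moved.shape[data.ndim - n_fit:]
    return moved.reshape((-1,) + fit_shape)


cube = np.zeros((100, 30, 40))  # assumed layout: (spectral, y, x)

# 1D model per spectrum: 30 * 40 = 1200 problems of length 100
spectra = split_into_fit_problems(cube, fitting_axes=0)
print(spectra.shape)  # (1200, 100)

# 2D model per celestial plane: 100 problems of shape (30, 40)
images = split_into_fit_problems(cube, fitting_axes=(1, 2))
print(images.shape)  # (100, 30, 40)
```

Each row of the result could then be handed to a fitter on its own, which is what makes the problem straightforward to parallelise.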

Simple example

    import matplotlib.pyplot as plt
    import numpy as np
    from tqdm.dask import TqdmCallback

    from astropy.modeling.fitting import LMLSQFitter
    from astropy.modeling.fitting_parallel import parallel_fit_model_nd
    from astropy.modeling.models import Gaussian1D

    NWAV = 100  # number of wavelength samples per spectrum
    NMOD = 200  # number of independent spectra to fit

    # True parameters, varying smoothly from one spectrum to the next
    wav = np.linspace(10, 20, NWAV)
    amplitude = np.linspace(1, 2, NMOD)
    mean = np.linspace(12, 18, NMOD)
    sigma = np.linspace(1, 2, NMOD)

    # Synthetic data of shape (NWAV, NMOD): one Gaussian per column, plus uniform noise
    data = amplitude * np.exp(-((wav[:, None] - mean) ** 2) / 2 / sigma**2)
    data += np.random.random(data.shape) / 5

    # Initial guess, with array-valued mean and stddev (one entry per spectrum)
    g = Gaussian1D(amplitude=2, mean=np.ones(NMOD), stddev=np.ones(NMOD))

    # Fit along axis 0 (wavelength), with a progress bar for the dask computation
    with TqdmCallback(desc="fitting"):
        g_fit = parallel_fit_model_nd(
            model=g,
            fitter=LMLSQFitter(),
            data=data,
            fitting_axes=0,
            world={0: wav},
        )

    # Compare the true and fitted parameters, one panel per parameter
    panels = [
        ("amplitude", amplitude, g_fit.amplitude.value),
        ("mean", mean, g_fit.mean.value),
        ("stddev", sigma, g_fit.stddev.value),
    ]

    fig = plt.figure()
    for i, (label, truth, fitted) in enumerate(panels, start=1):
        ax = fig.add_subplot(3, 1, i)
        ax.text(0.05, 0.8, label, transform=ax.transAxes)
        ax.plot(truth)
        ax.plot(fitted)
    fig.savefig("results.png")

[Figure: results.png, three panels comparing the input and fitted amplitude, mean, and stddev]

Things that still need to be done

Related issues/PRs:

Open questions:

Beyond this PR:

astrofrog commented 3 months ago

We might also want a way to specify a mask indicating which pixels to fit and which to ignore.
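As a rough illustration of what such a mask could look like, here is a NumPy-only sketch. Nothing here exists in the PR: the helper `select_unmasked`, the convention that `True` means "fit this pixel", and the assumption that the mask has the shape of the non-fitting axes are all hypothetical choices made for the example.

```python
import numpy as np


def select_unmasked(data, mask, fitting_axis=0):
    """Gather the 1D slices that should be fitted.

    Hypothetical helper: `mask` has the shape of the non-fitting axes and
    True marks a pixel to fit (an assumed convention, not an existing API).
    """
    # Put the fitting axis last so the mask indexes the leading axes
    moved = np.moveaxis(data, fitting_axis, -1)
    if mask.shape != moved.shape[:-1]:
        raise ValueError("mask must match the shape of the non-fitting axes")
    # Rows of the first array and of np.argwhere(mask) share the same order
    return moved[mask], np.argwhere(mask)


rng = np.random.default_rng(0)
cube = rng.random((100, 30, 40))  # assumed layout: (spectral, y, x)

# Only fit a 10 x 10 patch of pixels
mask = np.zeros((30, 40), dtype=bool)
mask[10:20, 5:15] = True

to_fit, positions = select_unmasked(cube, mask)
print(to_fit.shape)     # (100, 100): 100 selected spectra of length 100
print(positions.shape)  # (100, 2): (y, x) index of each selected spectrum

# Scatter a per-pixel result back into an image, NaN where nothing was fitted.
# The per-spectrum maximum is only a stand-in for a fitted parameter.
result = np.full(mask.shape, np.nan)
result[mask] = to_fit.max(axis=1)
```

Leaving skipped pixels as NaN is one possible way to report "not fitted"; whether the fitted parameters should use NaN or some other sentinel is an open design choice.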

astrofrog commented 3 months ago

Closing as the astropy PR is now open.