SNEWS2 / snewpy

A Python package for working with supernova neutrinos
https://snewpy.readthedocs.io
BSD 3-Clause "New" or "Revised" License

times broadcasting for Fornax_2021 #310

Closed JostMigenda closed 3 months ago

JostMigenda commented 3 months ago

Fixes #259. This also fixes a bug I discovered during this work: with interpolation="nearest", passing a scalar value for E (e.g. 10*u.MeV) would crash, so a length-1 array (e.g. [10]*u.MeV) was required instead.
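For illustration, the scalar-versus-array distinction can be reproduced with plain NumPy. The sketch below is not snewpy's code: `nearest_indices` and the grid values are hypothetical, and astropy units are left out to keep it self-contained. It shows one common way to make a nearest-neighbour lookup accept both a scalar and an array, by normalising the input with `np.atleast_1d` before indexing.

```python
import numpy as np

def nearest_indices(grid, values):
    """Return, for each value, the index of the closest grid point.

    Hypothetical helper, not part of snewpy. `values` may be a scalar
    or a 1-D array; np.atleast_1d turns a scalar into a length-1 array
    so the broadcasting below always sees a 1-D input.
    """
    grid = np.asarray(grid, dtype=float)
    values = np.atleast_1d(np.asarray(values, dtype=float))
    # Shape (len(values), len(grid)): distance of every value to every grid point.
    distances = np.abs(values[:, np.newaxis] - grid[np.newaxis, :])
    return distances.argmin(axis=1)

energy_grid = np.array([0.0, 5.0, 10.0, 20.0])  # made-up grid, in MeV

scalar_result = nearest_indices(energy_grid, 10.0)   # scalar input
array_result = nearest_indices(energy_grid, [10.0])  # length-1 array input
```

With this normalisation, both calls return the same length-1 index array, so callers do not need to wrap a single energy in a list.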

In addition to the usual tests, I’ve verified manually that the outputs for a few example t/E values were unchanged before/after this change. Additionally, I ran some performance checks, comparing the old and new versions with the following code snippet:

from snewpy.neutrino import Flavor
from snewpy.models.ccsn import Fornax_2021
from astropy import units as u
import numpy as np

fornax = Fornax_2021(progenitor_mass=Fornax_2021.param['progenitor_mass'][1])
n_times = 100  # set this to 1, 10, 100 for benchmarking
times = np.linspace(1, 2, n_times) * u.s

# old: manually loop over times
for t in times:
    fornax.get_initial_spectra(t=t, E=list(range(100))*u.MeV, flavors=[Flavor.NU_E], interpolation="nearest")[Flavor.NU_E]

# this PR
fornax.get_initial_spectra(t=times, E=list(range(100))*u.MeV, flavors=[Flavor.NU_E], interpolation="nearest")[Flavor.NU_E]
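The snippet above does not show how the timings were taken. One possible way, sketched here with the standard-library `timeit` module, is to wrap each variant in a callable and take the best of several repeats. The helper `time_call` and the two stand-in workloads are hypothetical; they only mimic the shape of "one call per time value" versus "one call for all time values".

```python
import timeit

def time_call(func, repeat=5, number=10):
    """Best-of-`repeat` wall-clock time per call of `func`, in milliseconds."""
    timings = timeit.repeat(func, repeat=repeat, number=number)
    return min(timings) / number * 1e3

# Stand-in workloads (assumptions, not snewpy calls).
def looped():
    # Mimics calling a function once per time value.
    return [sum(range(1000)) for _ in range(10)]

def batched():
    # Mimics a single call that handles all time values at once.
    return sum(range(1000))

looped_ms = time_call(looped)
batched_ms = time_call(batched)
print(f"looped:  {looped_ms:.4f} ms per call")
print(f"batched: {batched_ms:.4f} ms per call")
```

To benchmark the real code, the stand-ins would be replaced by callables wrapping the loop over `times` and the single `get_initial_spectra(t=times, ...)` call.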

Overall, this PR adds about 10% overhead for a single time value. For multiple time values, however, the run time grows much less than linearly, so we see a performance boost for two or more time values, and about an order of magnitude better performance for 100 time values.

Benchmark results for interpolation = "linear"

Here, the length of the energies array passed in had no significant impact on results.

| n_times | old | this PR |
|---|---|---|
| 1 | 0.62 ms | 0.69 ms |
| 10 | 6.2 ms | 1.2 ms |
| 100 | 61 ms | 5.5 ms |
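As a quick check of the overhead and speed-up figures quoted earlier, the ratios can be computed directly from the linear-interpolation timings above. This is plain arithmetic on the reported numbers; the variable names are only for illustration.

```python
# Timings in milliseconds, copied from the linear-interpolation table above:
# n_times -> (old, this PR)
timings_ms = {1: (0.62, 0.69), 10: (6.2, 1.2), 100: (61.0, 5.5)}

# Ratio old/new: values above 1 mean this PR is faster.
speedups = {n: old / new for n, (old, new) in timings_ms.items()}

for n, factor in speedups.items():
    print(f"n_times={n:>3}: old/new = {factor:.2f}")
# n_times=  1: old/new = 0.90
# n_times= 10: old/new = 5.17
# n_times=100: old/new = 11.09
```

In other words, the new version is roughly 11% slower for a single time value, about 5 times faster for 10 time values, and about 11 times faster for 100.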

Benchmark results for interpolation = "nearest"

Here, the range in each cell shows the impact of using an energy array of length 1–100 for a given n_times. There’s likely some room for improvement here, especially for large arrays of energies and times: as you’ll see in the diff, for one line I gave up and used a for loop over times, rather than spending even more time trying to figure out how to use NumPy broadcasting. Still, this PR is a major improvement over the status quo and I don’t think it’s worth investing more time into this. (interpolation = "nearest" is a non-default value with little physics motivation; I doubt it’s used much.)

| n_times | old | this PR |
|---|---|---|
| 1 | 0.55 – 0.65 ms | 0.59 – 0.70 ms |
| 10 | 5.1 – 6.2 ms | 0.97 – 2.0 ms |
| 100 | 51 – 61 ms | 4.3 – 15 ms |
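The loop-versus-broadcasting trade-off mentioned above can be illustrated on a toy table. The sketch below is not the code in this PR and makes no claim about how Fornax_2021 stores its data: the grids, the random `table`, and the helper `nearest` are all made up. It only shows that a per-time loop over a 2-D (time, energy) table gives the same result as a single NumPy advanced-indexing expression with broadcast index arrays.

```python
import numpy as np

rng = np.random.default_rng(0)

# Made-up tabulated data: rows are time bins, columns are energy bins.
time_grid = np.linspace(0.0, 2.0, 21)
energy_grid = np.linspace(0.0, 100.0, 51)
table = rng.random((time_grid.size, energy_grid.size))

query_times = np.array([0.32, 1.0, 1.97])
query_energies = np.array([3.0, 10.0, 47.9, 99.0])

def nearest(grid, values):
    """Index of the closest grid point for each value (hypothetical helper)."""
    return np.abs(values[:, None] - grid[None, :]).argmin(axis=1)

t_idx = nearest(time_grid, query_times)
e_idx = nearest(energy_grid, query_energies)

# Loop version: one row lookup per requested time.
looped = np.stack([table[i, e_idx] for i in t_idx])

# Broadcast version: index arrays of shape (n_times, 1) and (1, n_energies)
# broadcast to (n_times, n_energies) in a single indexing operation.
broadcast = table[t_idx[:, None], e_idx[None, :]]
```

Both results have shape `(len(query_times), len(query_energies))` and are element-wise identical. Whether the same trick applies to the remaining loop in the PR depends on details of the data layout that this toy example does not capture.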