GalacticDynamics / galax

Galactic and Gravitational Dynamics in Python (+ GPU and autodiff)
MIT License

feat: re-vectorize #461

Closed: nstarman closed this 2 weeks ago

nstarman commented 2 weeks ago

Achieves speeds much closer to native diffrax! It also speeds up the doctests and simplifies both the vectorization structure and the interpolation structure.
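As a generic illustration of what vectorizing an orbit integration means in JAX (a sketch, not galax's implementation), a fixed-step integrator written for a single orbit can be mapped over a batch of initial conditions with `jax.vmap`. The toy units with GM = 1, the kick-drift-kick leapfrog scheme, and the names `leapfrog_orbit` and `make_batched` are assumptions for illustration; galax itself delegates integration to diffrax.

```python
import jax
import jax.numpy as jnp

GM = 1.0  # assumption: toy units with G * M = 1


def accel(q):
    """Kepler acceleration a = -GM q / |q|^3 for a single 3-vector q."""
    r = jnp.linalg.norm(q)
    return -GM * q / r**3


def leapfrog_orbit(w0, dt, n_steps):
    """Integrate one orbit (w0 = [x, y, z, vx, vy, vz]) with kick-drift-kick
    leapfrog; return the position after each step, shape (n_steps, 3)."""

    def step(carry, _):
        q, p = carry
        p_half = p + 0.5 * dt * accel(q)  # half kick
        q_new = q + dt * p_half  # full drift
        p_new = p_half + 0.5 * dt * accel(q_new)  # half kick
        return (q_new, p_new), q_new

    _, qs = jax.lax.scan(step, (w0[:3], w0[3:]), None, length=n_steps)
    return qs


def make_batched(dt, n_steps):
    """Vectorize the single-orbit integrator over a leading batch axis."""
    return jax.jit(jax.vmap(lambda w0: leapfrog_orbit(w0, dt, n_steps)))


# Two circular orbits (v = sqrt(GM / r)) at r = 1 and r = 4
w0s = jnp.array([
    [1.0, 0.0, 0.0, 0.0, 1.0, 0.0],
    [4.0, 0.0, 0.0, 0.0, 0.5, 0.0],
])
qs = make_batched(dt=0.01, n_steps=1000)(w0s)  # shape (2, 1000, 3)
radii = jnp.linalg.norm(qs, axis=-1)  # shape (2, 1000)
```

Under `jax.vmap`, a `jax.lax.while_loop` (the kind of loop adaptive-step solvers rely on) keeps iterating until every batch element is done, so how an integration is vectorized can change its cost noticeably.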

Follow-up:

  1. What triggers the random 100s integrations? This appears to originate in diffrax or in what is passed to diffrax.
  2. Calls to `.w(units=...)` are slow. Working with PSPs (`PhaseSpacePosition` objects) directly as PyTrees, instead of round-tripping PSP -> array -> PSP through `.w()`, would make this faster.
  3. Point 2 may explain why the jitted code in `gd.evaluate_orbit` becomes slower for more than 10^6 saved times. This still needs to be diagnosed and fixed.
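Follow-up 2 contrasts round-tripping a PSP through a flat array with operating on PyTrees directly. A minimal sketch of the two styles in plain JAX, where the `PSP` NamedTuple is a hypothetical stand-in for `gc.PhaseSpacePosition` and the Kepler derivative uses toy units with GM = 1 (both assumptions):

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class PSP(NamedTuple):
    """Hypothetical stand-in for a phase-space position; NamedTuples are PyTrees."""

    q: jax.Array  # positions, shape (..., 3)
    p: jax.Array  # velocities, shape (..., 3)


def deriv_pytree(w: PSP, gm: float = 1.0) -> PSP:
    """Kepler equations of motion evaluated directly on the PyTree."""
    r = jnp.linalg.norm(w.q, axis=-1, keepdims=True)
    return PSP(q=w.p, p=-gm * w.q / r**3)


def to_array(w: PSP) -> jax.Array:
    """Flatten a PSP into a (..., 6) array."""
    return jnp.concatenate([w.q, w.p], axis=-1)


def from_array(arr: jax.Array) -> PSP:
    """Split a (..., 6) array back into a PSP."""
    return PSP(q=arr[..., :3], p=arr[..., 3:])


def deriv_roundtrip(arr: jax.Array, gm: float = 1.0) -> jax.Array:
    """Same dynamics, but unpacking and repacking a flat array on every call."""
    return to_array(deriv_pytree(from_array(arr), gm))


w = PSP(q=jnp.array([[1.0, 0.0, 0.0]]), p=jnp.array([[0.0, 1.0, 0.0]]))
d_tree = jax.jit(deriv_pytree)(w)  # PyTree in, PyTree out
d_flat = jax.jit(deriv_roundtrip)(to_array(w))  # array in, array out
```

Both functions compute the same derivative; the difference is where packing happens. diffrax accepts arbitrary PyTree states, so the first style can in principle be passed to the solver without flattening.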
codecov[bot] commented 2 weeks ago

Codecov Report

Attention: Patch coverage is 98.00000% with 1 line in your changes missing coverage. Please review.

Project coverage is 96.54%. Comparing base (9fe940d) to head (4cbcf41). Report is 1 commit behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/galax/dynamics/_src/integrate/core.py | 97.50% | 1 Missing :warning: |
Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #461      +/-   ##
==========================================
- Coverage   96.60%   96.54%   -0.06%
==========================================
  Files          77       78       +1
  Lines        3030     3007      -23
==========================================
- Hits         2927     2903      -24
- Misses        103      104       +1
```
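The Coverage Diff figures are internally consistent. Assuming project coverage is hits divided by tracked lines (no partially covered lines are listed), hits plus misses equals lines and the percentages follow directly:

```python
# Figures copied from the Coverage Diff (main -> #461)
base = {"hits": 2927, "misses": 103, "lines": 3030}
head = {"hits": 2903, "misses": 104, "lines": 3007}


def coverage_pct(stats):
    """Percent of tracked lines that are hit (assumes no partial lines)."""
    return 100 * stats["hits"] / stats["lines"]


base_pct = round(coverage_pct(base), 2)  # 96.6
head_pct = round(coverage_pct(head), 2)  # 96.54
delta = round(head_pct - base_pct, 2)  # -0.06
```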


nstarman commented 2 weeks ago

Some diagnostics:

CODE

```python
import time

import jax
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import tqdm
from plum import convert

import coordinax as cx
import quaxed.numpy as jnp
from dataclassish import replace
from unxt import AbstractQuantity, Quantity

import galax.coordinates as gc
import galax.dynamics as gd
import galax.potential as gp

# Kepler potential and a single initial phase-space position at t = 0
pot = gp.KeplerPotential(m_tot=Quantity(1e12, "Msun"), units="galactic")
w0 = gc.PhaseSpacePosition(
    q=Quantity([8.0, 0.0, 0.0], "kpc"),
    p=Quantity([0.0, 230, 0.0], "km/s"),
    t=Quantity(0, "Myr"),
)

# Numbers of saved times to benchmark: 1, 2, then 10 log-spaced values up to ~1e8
numts = [1, 2, *jnp.geomspace(10, 1e8 + 1, num=10, dtype=int)]

integrator = gd.integrate.Integrator(diffeq_kw={"max_steps": 30_000})
w0_arr = w0.w(units=pot.units)  # convert the PSP to a plain array once, up front


def plot_timings(numts, times1, times2):
    """Log-log scatter of wall-clock time (s) against the number of saved times."""
    fig, ax = plt.subplots()
    ax.scatter(np.array(numts), np.array(times1) / 1e9, label="First run")
    ax.scatter(np.array(numts), np.array(times2) / 1e9, label="Second run")
    ax.set(xlabel="Number of timesteps", ylabel="Time (s)", xscale="log", yscale="log")
    ax.legend()
    ax.yaxis.set_major_locator(mpl.ticker.LogLocator(base=10.0))
    ax.yaxis.set_minor_locator(
        mpl.ticker.LogLocator(base=10.0, subs=np.arange(1.0, 10.0, 1.0), numticks=10)
    )
    ax.yaxis.set_major_formatter(mpl.ticker.LogFormatter())
    ax.yaxis.set_minor_formatter(mpl.ticker.LogFormatter(minor_thresholds=(2, 0)))
    ax.minorticks_on()
    return fig


# --- Benchmark 1: calling Integrator() directly ---
def time_integrator_pass():
    """One pass over `numts`, recording each integration's wall-clock time (ns)."""
    times = []
    for n in tqdm.tqdm(numts):
        time_start = time.time_ns()
        _ = integrator(
            pot._dynamics_deriv,
            w0_arr,
            Quantity(0.0, "Gyr"),
            Quantity(1.0, "Gyr"),
            units=pot.units,
            saveat=Quantity(jnp.linspace(0.3, 1, int(n)), "Gyr"),
        )
        times.append(time.time_ns() - time_start)
    return times


times1 = time_integrator_pass()  # first pass (includes any compilation)
times2 = time_integrator_pass()  # second pass over the same sizes
plot_timings(numts, times1, times2).show()

# --- Benchmark 2: calling gd.evaluate_orbit, first and second call per size ---
times1 = []
times2 = []
for n in tqdm.tqdm(numts):
    time_start = time.time_ns()
    _ = gd.evaluate_orbit(pot, w0_arr, Quantity(jnp.linspace(0.3, 1, int(n)), "Gyr"))
    times1.append(time.time_ns() - time_start)

    time_start = time.time_ns()
    _ = gd.evaluate_orbit(pot, w0_arr, Quantity(jnp.linspace(0.3, 1, int(n)), "Gyr"))
    times2.append(time.time_ns() - time_start)

plot_timings(numts, times1, times2).show()
```

Calling Integrator()

[Screenshot (2024-09-21): wall-clock time (s) vs. number of timesteps for Integrator(), first and second runs]

Calling evaluate_orbit

[Screenshot (2024-09-21): wall-clock time (s) vs. number of timesteps for evaluate_orbit, first and second runs]
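The diagnostics time each call with `time.time_ns()` and compare a first and a second run. Two JAX behaviors matter when reading such numbers: `jax.jit` traces and compiles once per distinct input shape (so, if the integrator is jitted, each new `saveat` length pays compilation on its first call), and JAX dispatches work asynchronously, so a timer should wait on the result with `jax.block_until_ready`. A minimal, generic sketch; the function `work` and the helper `time_call` are hypothetical and not part of galax:

```python
import time

import jax
import jax.numpy as jnp

trace_count = 0  # incremented only while JAX traces `work`


@jax.jit
def work(x):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on cached calls
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2


def time_call(fn, *args):
    """Return (seconds, output) for one call, blocking until the result is ready."""
    start = time.perf_counter()
    out = jax.block_until_ready(fn(*args))
    return time.perf_counter() - start, out


x = jnp.linspace(0.0, 1.0, 1_000_000)
t_first, _ = time_call(work, x)  # includes tracing and XLA compilation
t_second, out = time_call(work, x)  # same shape: reuses the cached executable
t_new_shape, _ = time_call(work, x[:10])  # new shape: traces and compiles again
```

Comparing `t_first` with `t_second` separates compilation cost from run cost, the same first-run/second-run idea used in the galax diagnostics.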