tedwards2412 / ripple

Differentiable Gravitational Waveforms with JAX

Sine Gaussian Waveform #23

Closed: ravioli1369 closed this 1 month ago

ravioli1369 commented 1 month ago

This PR adds the SineGaussian waveform to ripple, along with a detailed Python notebook showing the mismatch between the LALInference and JAX implementations.

mcoughlin commented 1 month ago

@tedwards2412 I hope you don't mind us hopping on board. We envision this being helpful for PE follow-up to Burst searches. @ravioli1369 is a student working with me, @ThibeauWouters, and others.

ThibeauWouters commented 1 month ago

Hi all,

Thanks a lot for the contributions, and great to see the extensive checks of the code! I quickly went over this and will jot down some comments from looking at the main .py file; @tedwards2412 can indicate whether he agrees or not:

ravioli1369 commented 1 month ago

I have made most of the changes, but when I removed 64-bit precision from my notebook and ran it again, I noticed a large change in the accuracy of the implementation. Following that, I added a comparison of the 64- and 32-bit precision waveforms against the LALInference implementation (see the code cells just below the markdown heading Vary each parameter independently by fixing others for the 64- and 32-bit cases, as well as the Conclusion at the very end). Is this sort of variation expected?

As for the jax.vmap implementation, I was initially trying to port the code 1:1 from the existing torch implementation of sine Gaussians in the ml4gw repo (https://github.com/ML4GW/ml4gw/blob/main/ml4gw/waveforms/sine_gaussian.py). I'll look into using vmap and hopefully commit that soon.
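
For reference, switching precision in JAX is a single global flag; a minimal sketch (the flag must be set before any arrays are created):

import jax
jax.config.update("jax_enable_x64", True)  # JAX defaults to float32

import jax.numpy as jnp
print(jnp.zeros(1).dtype)  # float64 with the flag on, float32 otherwise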

kazewong commented 1 month ago

Regarding float64 vs float32, I think one has to scale the signal appropriately to avoid loss of accuracy. I don't see why the sine Gaussian would need float64 accuracy, so it should be okay to refactor it into float32. @tedwards2412 can comment more on this.

As a side note going forward, I think we should start incorporating tests into the code base as it grows.

tedwards2412 commented 1 month ago

Sorry for the delay, this looks great so far! @mcoughlin no problem at all, I'm just happy that people are starting to find the code useful and want to contribute :)

A couple of comments:

  1. I think that to keep function signatures consistent across waveforms, we should always have gen_X_hphc(f, theta, kwargs), where f is the frequency grid and theta holds the parameters associated with that particular waveform. Can we change the function to have a similar structure here? I'm not super familiar with the SineGaussian waveform structure though, so I don't know the best way to do this. (A sketch follows this list.)
  2. You indeed shouldn't need to do any reshaping of the arrays: simply use vmap if you want to evaluate waveforms in parallel, and this will add the batch dimension for you.
  3. It's to be expected that dropping to float32 precision will truncate some of the accuracy. What's likely happening is that some quantities are going outside the total range supported by float32. For example in the match, when you multiply the two waveforms in the numerator, the product goes outside of the supported range and so clips to 0. Maybe you're finding something similar to this? It shouldn't be a problem as long as you can scale everything accordingly. We therefore don't want to fix float64 by default, in case the user wants to use float32. We might want to add a warning though, so that the user is aware of this?
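
To illustrate points 1 and 2, a minimal sketch with a hypothetical waveform X (toy expressions, not a real waveform model; kwargs omitted):

import jax
import jax.numpy as jnp

# gen_X_hphc(f, theta) takes the full frequency grid and a single
# parameter vector; vmap adds the batch dimension when needed.
def gen_X_hphc(f, theta):
    A, f0 = theta  # toy parameters
    hp = A * jnp.cos(2 * jnp.pi * f0 * f)
    hc = A * jnp.sin(2 * jnp.pi * f0 * f)
    return hp, hc

f = jnp.linspace(20.0, 1024.0, 2048)          # frequency grid
thetas = jnp.array([[1.0, 0.1], [2.0, 0.2]])  # shape (n_batch, n_theta)
hp, hc = jax.vmap(gen_X_hphc, in_axes=(None, 0))(f, thetas)
print(hp.shape)  # (n_batch, len(f)), no manual reshaping needed
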
ravioli1369 commented 1 month ago

I had a few doubts regarding this:

  1. I'm also not sure how to go about constructing the waveform in the frequency domain. The waveform was meant to be a 1:1 implementation of LALInference's, and is thus in the time domain. Converting ours to a frequency-domain implementation would make it quite different from how it is called in LAL, which could confuse those familiar with LAL.

  2. I'm a bit new to this, so to clarify: I can remove the reshaping lines and instead call the gen_SineGaussian function itself with vmap if I need to do batched computations?

  3. I'm not sure what is meant by scaling the waveforms, but I have done something to try and reduce the mismatch:

hcross_ripple.append(hcross[i][start:stop] * scale)
hplus_ripple.append(hplus[i][start:stop] * scale)
hcross_lal.append(hcross_ * scale)
hplus_lal.append(hplus_ * scale)

I pushed the results back to the notebook towards the end (https://github.com/ravioli1369/ripple/blob/sine-gaussian/notebooks/check_SineGaussian.ipynb). Is this correct, or should I have done something else?

ThibeauWouters commented 1 month ago

@ravioli1369 @tedwards2412 Perhaps it would be good to start dividing the ripple source code into frequency-domain and time-domain waveforms? I believe that time-domain waveforms will see more support in the future for other use cases as well. Thomas can indicate whether he agrees with this.

tedwards2412 commented 1 month ago
  1. Ahh sorry @ravioli1369, I didn't realize it was a time domain waveform. In this case, I still think the function signature should be gen_X_hphc(t, theta, kwargs) where t is the frequency grid. Overall, ripple is not meant to be a 1:1 recreation of LAL in JAX; we try to follow sensible design choices that are easy to use and provide some utility. The reason we like to package the parameters into a single theta is that you'd typically like to do autodiff with respect to all these parameters, and this provides a clean way to do that.
  2. Yeah, you should just remove the reshaping lines and instead do vmap(func), which returns a new function that can be given a grid of parameters of shape (n_batch, n_theta). Here is an example of how this works for a different waveform: https://github.com/tedwards2412/ripple/blob/f297a225b920b8fa46cf284fb9311498a1e4495a/test/old_tests/benchmark_waveform.py#L82
  3. What you did looks sensible to me. All I mean by scaling is that you multiply the waveform by an arbitrary constant to make sure all the values used in any computation stay within the range allowed by float32, and then scale back at the end if necessary. For example, in the match you have a quantity like h*h/PSD. If you multiply the individual h's by a factor of, say, 1e20 and the PSD by the square of that factor before calculating the match, the result is unchanged but all the intermediate numbers remain within the float32 range (see the sketch below). Does this make more sense?
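
To make the scaling in point 3 concrete, a toy sketch (this is not ripple's match code; inner and match are hypothetical helpers):

import jax.numpy as jnp

# toy noise-weighted inner product <a|b> and normalized match
def inner(a, b, psd, df):
    return 4.0 * df * jnp.sum((a * jnp.conj(b)).real / psd)

def match(h1, h2, psd, df):
    norm = jnp.sqrt(inner(h1, h1, psd, df) * inner(h2, h2, psd, df))
    return inner(h1, h2, psd, df) / norm

# Strains of order 1e-21 give products of order 1e-42, which underflow
# or lose precision in float32 (the normal range bottoms out near 1e-38).
# Scaling h by s and the PSD by s**2 keeps every intermediate in range
# and leaves inner, and hence the match, exactly unchanged:
s = 1e20
# match(s * h1, s * h2, s**2 * psd, df) == match(h1, h2, psd, df)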

@ThibeauWouters I agree that splitting the waveforms makes sense and will probably make things easier to track as things grow.

ravioli1369 commented 1 month ago
  1. Understood, I've made the necessary change.

    > t is the frequency grid

    I assume this is supposed to be the time grid?

  2. I had done a speed comparison between regular JAX, jit, and vmap here (https://github.com/ravioli1369/sine-gaussian/blob/main/speed-comparision.ipynb) and found that vmap is ~2 times slower than running the function without it. I'm not fully sure why that is, but I removed the reshape lines, tried the test again, and found similar results:

     Regular Jax: 10.9 ms ± 163 µs per loop (mean ± std. dev. of 100 runs, 10 loops each)
     Jax with JIT: 1.49 ms ± 47.3 µs per loop (mean ± std. dev. of 100 runs, 100 loops each)
     Using vmap: 19.8 ms ± 766 µs per loop (mean ± std. dev. of 100 runs, 10 loops each)
     Using vmap with JIT: 1.49 ms ± 47.4 µs per loop (mean ± std. dev. of 100 runs, 100 loops each)

     A thing to note is that the performance of vmap+jit and jitting the regular function is identical. Is it worth removing the reshaping in favor of vmap in this case?

  3. I understand how float32 could go out of range during the mismatch calculation, but I don't think that's what's happening here; the difference would just get scaled back to the original values if I scaled the waveforms again.

tedwards2412 commented 1 month ago
  1. Ok great, thanks, this looks much cleaner now. I've merged the current version. And yes, that was supposed to say time grid :)

  2. For the timing, I'm not exactly sure what is going on here, but a few things could be going wrong. Firstly, especially when running on a GPU, you need to make sure to call block_until_ready() to ensure that the computation has actually completed; see here for more details: https://jax.readthedocs.io/en/latest/faq.html. Overall though, vmap should be the default choice for vectorizing in JAX, and this shouldn't be done manually. It's much cleaner to just use vmap, and I suspect it will also be faster once the timings are sorted out. (See the minimal timing sketch after this list.)

  3. Ok, let me have a look in more detail when I have time next week, and I can see if it's clear what is going on. There will obviously be some loss in precision going to float32, but I would have guessed it's not an issue at current detector sensitivities. In my tests on the match I definitely found a reduction of a few orders of magnitude in accuracy, but it was still well below detectable levels.
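
To illustrate the timing point, a minimal benchmarking pattern (f is a hypothetical stand-in function) that avoids JAX's asynchronous-dispatch pitfall:

import time
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.sin(x) * jnp.cos(x)

x = jnp.ones(1_000_000)
f(x).block_until_ready()  # warm-up call triggers JIT compilation

start = time.perf_counter()
for _ in range(100):
    f(x).block_until_ready()  # block so we time the computation, not the dispatch
print("per call: %.3f ms" % ((time.perf_counter() - start) * 1000 / 100))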

ravioli1369 commented 1 month ago

I ran the timing benchmarks again with the block_until_ready() command and found the results to be similar to before: the vmapped version is ~2 times slower. This again has no impact on the jitted versions of both the vmapped and non-vmapped functions.

Regular Jax 13.5 ms ± 204 µs per loop (mean ± std. dev. of 100 runs, 10 loops each)

Jax with JIT 1.77 ms ± 24.9 µs per loop (mean ± std. dev. of 100 runs, 100 loops each)

Using vmap 25.2 ms ± 273 µs per loop (mean ± std. dev. of 100 runs, 10 loops each)

Using vmap with JIT 1.75 ms ± 41.3 µs per loop (mean ± std. dev. of 100 runs, 100 loops each)

The notebook (https://github.com/ravioli1369/sine-gaussian/blob/main/speed-comparision.ipynb) has more details on how I ran the benchmark. I even changed the output of the sine Gaussian function to return a single array so that I don't have to evaluate it in a list comprehension, but the results of that were also similar.

tedwards2412 commented 1 month ago

I tried it myself with the merged version of the SineGaussian waveform in ripple and found that vmap is instead quicker. Here is the code:

import jax
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
from ripplegw.waveforms import SineGaussian
import time
from jax import vmap

duration = 10.0
sampling_frequency = 4096
dt = 1 / sampling_frequency
times = jnp.arange(-duration / 2, duration / 2, dt)
print(times.shape)

n_waveforms = 1000

# parameter grids: one value per waveform
quality = jnp.linspace(3, 100, n_waveforms)
frequency = jnp.logspace(1, 3, n_waveforms)
hrss = jnp.logspace(-23, -6, n_waveforms)
phase = jnp.linspace(0, 2 * np.pi, n_waveforms)
eccentricity = jnp.linspace(0, 0.99, n_waveforms)

theta_ripple = np.array([quality, frequency, hrss, phase, eccentricity]).T

print(theta_ripple.shape)

@jax.jit
def waveform(theta):
    return SineGaussian.gen_SineGaussian_hphc(times, theta)

print("JIT compiling")
waveform(theta_ripple[0])[0].block_until_ready()
print("Finished JIT compiling")

start = time.time()
for t in theta_ripple:
    waveform(t)[0].block_until_ready()
end = time.time()

print("Ripple waveform call takes: %.6f ms" % ((end - start) * 1000 / n_waveforms))

# vmap the jitted waveform and warm it up once before timing
func = vmap(waveform)
func(theta_ripple)[0].block_until_ready()

start = time.time()
hp_batch = func(theta_ripple)[0].block_until_ready()
end = time.time()

print(
    "Vmapped ripple waveform call takes: %.6f ms" % ((end - start) * 1000 / n_waveforms)
)

print(hp_batch.shape)

My output gives:

> (1000, 5)
> JIT compiling
> Finished JIT compiling
> Ripple waveform call takes: 0.189530 ms
> Vmapped ripple waveform call takes: 0.071345 ms
> (1000, 40960)

Note that I ran this directly on my laptop's CPU, so the effect will be further magnified when using a GPU. Also, I had to remove the reshapes, which incorrectly add dimensions once you use vmap. The resulting array should have shape (batch_dimension, time_grid).

ravioli1369 commented 1 month ago

It does indeed look like vmap is faster than running through the parameters in a for loop. The way I had tested it was to send all the parameters into the function and call reshape inside it; this gave results that were faster than removing the reshape and running vmap, and I'm not sure why. The jitted versions (with and without reshape) give identical results, so I think it should be fine to leave it this way, although it does warrant some investigation into why vmap performs worse than reshaping.

tedwards2412 commented 1 month ago

I think this overall makes sense: once you add the jit, your manual reshaping is basically vectorizing the function by hand, so it should perform similarly to vmap + jit. Overall though, it's not good practice in JAX to do this kind of manual reshaping when you can instead use vmap :) Note that vmap doesn't do the jit for you, so jit is still required to make it fast!
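
For illustration, a toy sketch of the two styles being compared (a hypothetical function, not the SineGaussian code); once jitted, both compile to essentially the same batched program:

import jax
import jax.numpy as jnp

def single(theta, t):
    amp, f0 = theta
    return amp * jnp.sin(2 * jnp.pi * f0 * t)

t = jnp.linspace(0.0, 1.0, 4096)
thetas = jnp.stack([jnp.ones(100), jnp.linspace(1.0, 10.0, 100)], axis=1)

# manual vectorization: broadcast the batch axis by hand inside the function
manual = jax.jit(lambda th: th[:, 0:1] * jnp.sin(2 * jnp.pi * th[:, 1:2] * t[None, :]))

# vmap: write the single-waveform function and let JAX add the batch axis
auto = jax.jit(jax.vmap(lambda th: single(th, t)))

assert jnp.allclose(manual(thetas), auto(thetas))  # identical results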