ravioli1369 closed this pull request 1 month ago
@tedwards2412 I hope you don't mind us hopping on board. We envision this being helpful for PE follow-up to Burst searches. @ravioli1369 is a student working with me, @ThibeauWouters, and others.
Hi all,
Thanks a lot for the contributions, and great to see the extensive checks of the code! I quickly went over this and will jot down some comments from looking at the main py file, @tedwards2412 can indicate whether he agrees or not:
- `jax_enable_x64`: it has been decided before to allow users to use float16 or float32, see #10.
- Rename `gen_SineGaussian` to `gen_SineGaussian_hphc` to agree with the other waveform files, and also make sure to return in that order (first plus, then cross).
- `jax.vmap`: the manual reshaping lines should not be needed, since jax provides `jax.vmap` functionalities without these extra lines. It would be great if someone can confirm that, either from more jax knowledge or by simply running this as an experiment.

I have made most of the changes, but when I removed 64-bit precision from my notebook and ran it again, I noticed a large change in the accuracy of the implementation. Following that, I added a comparison between the 64- and 32-bit precision waveforms and how they compare to the LALInference implementation (see the code cells just below the markdown heading *Vary each parameter independently by fixing others* for each of the 64- and 32-bit cases, as well as the *Conclusion* at the very end).
Is this sort of variation expected?
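To make the comparison concrete, here is a minimal NumPy sketch (a toy sine-Gaussian with assumed parameter values, not the actual ripple implementation) of how large the float64/float32 evaluation difference can be:

```python
import numpy as np

# Toy sine-Gaussian: Gaussian envelope times a sinusoid; all parameter
# values below are assumed purely for illustration.
t = np.linspace(-0.5, 0.5, 4096)
f0, q, hrss = 100.0, 10.0, 1e-22
tau = q / (np.sqrt(2.0) * np.pi * f0)

# Same expression evaluated in double and in single precision.
h64 = hrss * np.exp(-((t / tau) ** 2)) * np.cos(2 * np.pi * f0 * t)

t32 = t.astype(np.float32)
h32 = (np.float32(hrss) * np.exp(-((t32 / np.float32(tau)) ** 2))
       * np.cos(np.float32(2 * np.pi * f0) * t32))

# Relative difference: orders of magnitude above float64 rounding (~1e-16),
# so a visible change in the mismatch when dropping to float32 is expected.
rel_err = np.max(np.abs(h64 - h32.astype(np.float64))) / np.max(np.abs(h64))
print(rel_err)
```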
As for the `jax.vmap` implementation, initially I was trying to port the code 1:1 from the existing torch implementation of sine Gaussians in the ml4gw repo (https://github.com/ML4GW/ml4gw/blob/main/ml4gw/waveforms/sine_gaussian.py). I'll look into using vmaps and hopefully commit that soon.
Regarding float64 vs float32, I think one has to scale the signal accordingly to avoid loss of accuracy. I don't see why the sine Gaussian would need float64 accuracy, so it should be okay to refactor it into float32. @tedwards2412 can comment more on this.
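One way to read "scaling" here, sketched with toy numbers (the hrss scale is assumed for illustration): strain amplitudes of ~1e-23 are representable in float32 on their own, but squared quantities such as inner products underflow.

```python
import numpy as np

# (1e-23)**2 = 1e-46 is below the smallest float32 subnormal (~1.4e-45),
# so every elementwise product underflows to zero.
h = np.full(10, 1e-23, dtype=np.float32)
print(np.sum(h * h))                  # 0.0

# Rescaling to O(1) before squaring, then undoing the scale in float64,
# keeps the computation in range.
scale = np.float32(1e23)
hs = h * scale                        # entries ~ 1.0
restored = float(np.sum(hs * hs)) / float(scale) ** 2
print(restored)                       # ~1e-45, the intended value
```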
As a side note going forward, I think we should start incorporating tests into the code base as it grows.
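As a sketch of what such a test could look like (with an inline toy waveform standing in for ripple's `gen_SineGaussian_hphc`; names and tolerances are illustrative):

```python
import jax.numpy as jnp

def toy_sine_gaussian(t, f0=100.0, q=10.0, amp=1.0):
    # Stand-in waveform: Gaussian envelope times a sinusoid.
    tau = q / (jnp.sqrt(2.0) * jnp.pi * f0)
    env = amp * jnp.exp(-((t / tau) ** 2))
    return env * jnp.cos(2 * jnp.pi * f0 * t), env * jnp.sin(2 * jnp.pi * f0 * t)

def test_shapes_and_decay():
    t = jnp.linspace(-1.0, 1.0, 4096)
    hp, hc = toy_sine_gaussian(t)
    # Output shapes should match the time grid.
    assert hp.shape == t.shape and hc.shape == t.shape
    # The envelope should peak near t = 0 and be negligible at the edges.
    assert jnp.abs(hp[0]) < 1e-6 and jnp.abs(hp[-1]) < 1e-6
    assert jnp.max(jnp.abs(hp)) > 0.5

test_shapes_and_decay()
print("ok")
```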
Sorry for the delay, this looks great so far! @mcoughlin no problem at all, I'm just happy that people are starting to find the code useful and want to contribute :)
A couple of comments:
I had a few doubts regarding this:
I'm also not sure how to go about constructing the waveform in the frequency domain. The waveform was meant to be a 1:1 implementation from LALInference, and thus is in the time domain. Converting ours to a frequency implementation would make it quite different from how it is called using LAL. This could lead to confusion for those familiar with LAL?
I'm a bit new to this, so to clarify: I can remove the reshaping lines and instead call the gen_SineGaussian function itself with a vmap if I need to do batched computations?
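Something like the following, with a stand-in function in place of the actual ripple one (names illustrative)?

```python
import jax
import jax.numpy as jnp

# A function written for a single parameter set...
def toy_waveform(t, theta):
    f0, amp = theta
    return amp * jnp.sin(2 * jnp.pi * f0 * t)   # shape (len(t),)

t = jnp.linspace(0.0, 1.0, 64)
thetas = jnp.array([[10.0, 1.0], [20.0, 0.5], [30.0, 2.0]])  # batch of 3

# ...can be batched with jax.vmap instead of reshaping inside it:
# map over theta (axis 0), broadcast the shared time grid.
batched = jax.vmap(toy_waveform, in_axes=(None, 0))
h = batched(t, thetas)
print(h.shape)  # (3, 64)
```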
I'm not sure what is meant by scaling the waveforms, but I have done something to try and reduce the mismatch.
```python
hcross_ripple.append(hcross[i][start:stop] * scale)
hplus_ripple.append(hplus[i][start:stop] * scale)
hcross_lal.append(hcross_ * scale)
hplus_lal.append(hplus_ * scale)
```
I pushed the results back to the notebook towards the end (https://github.com/ravioli1369/ripple/blob/sine-gaussian/notebooks/check_SineGaussian.ipynb). Is this correct, or should I have done something else?
@ravioli1369 @tedwards2412 Perhaps it might be good to start dividing up the ripple source code into frequency domain and time domain waveforms? I believe that time domain waveforms will get more supported in the future for other use cases as well. Thomas can indicate whether he agrees with this.
@ThibeauWouters I agree that splitting the waveforms makes sense and will probably make things easier to track as things grow.
Understood, I've made the necessary change.
> t is the frequency grid

I assume this is supposed to be the time grid?
I had done a speed comparison between regular jax, jit, and vmap here (https://github.com/ravioli1369/sine-gaussian/blob/main/speed-comparision.ipynb) and found that vmap is ~2 times slower than running the function without it. I'm not fully sure why that is, but I removed the reshape lines, tried the test again, and found similar results:
Regular Jax: 10.9 ms ± 163 µs per loop (mean ± std. dev. of 100 runs, 10 loops each)
Jax with JIT: 1.49 ms ± 47.3 µs per loop (mean ± std. dev. of 100 runs, 100 loops each)
Using vmap: 19.8 ms ± 766 µs per loop (mean ± std. dev. of 100 runs, 10 loops each)
Using vmap with JIT: 1.49 ms ± 47.4 µs per loop (mean ± std. dev. of 100 runs, 100 loops each)
One thing to note is that the performance of vmap+jit and of jitting the regular function is identical. Is it worth removing the reshaping in favor of vmap in this case?
I understand how the float32 could go out of range during the calculation of the mismatch, but I don't think that's what's happening here, and the difference would just get scaled back to the original values if I scaled the waveforms again.
Ok great, thanks this looks much cleaner now. I've merged the current version. And yes, that was supposed to say time grid :)
For the timing, I'm not exactly sure what is going on here but a few things could be going wrong. Firstly, especially when running on a GPU, you need to make sure to call block_until_ready() to ensure that the computation is actually complete. See here for more details: https://jax.readthedocs.io/en/latest/faq.html. Overall though, vmap should be the default choice for vectorizing in Jax and this shouldn't be done manually. It's much cleaner to just use vmap and I suspect that it will also be faster once the timings are sorted out.
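For reference, a minimal sketch of that timing pattern (toy function, not the ripple waveform): jit once, warm up outside the timed region, and block on the result before stopping the clock.

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.sin(x).sum()

x = jnp.arange(100_000, dtype=jnp.float32)
f(x).block_until_ready()            # warm-up call: includes compilation

start = time.perf_counter()
for _ in range(100):
    # block_until_ready() forces the asynchronously dispatched computation
    # to finish, so the timer measures real work rather than dispatch.
    f(x).block_until_ready()
elapsed = time.perf_counter() - start
print(f"{elapsed / 100 * 1e3:.3f} ms per call")
```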
Ok let me have a look in more detail when I have time next week and I can see if it's clear what is going on. There will obviously be some loss in precision going to float32 but I would have guessed it's not an issue at current detector sensitivities. In my tests on the match I definitely found a reduction of a few orders of magnitude in accuracy but it was still well below detectable levels.
I ran the timing benchmarks again with `block_until_ready()` and found the results to be similar to before: the vmapped version is ~2 times slower. Again, this has no impact on the jitted versions of both (vmapped and non-vmapped) functions.
Regular Jax 13.5 ms ± 204 µs per loop (mean ± std. dev. of 100 runs, 10 loops each)
Jax with JIT 1.77 ms ± 24.9 µs per loop (mean ± std. dev. of 100 runs, 100 loops each)
Using vmap 25.2 ms ± 273 µs per loop (mean ± std. dev. of 100 runs, 10 loops each)
Using vmap with JIT 1.75 ms ± 41.3 µs per loop (mean ± std. dev. of 100 runs, 100 loops each)
The notebook (https://github.com/ravioli1369/sine-gaussian/blob/main/speed-comparision.ipynb) has more details on how I ran the benchmark. I even changed the output of the sine Gaussian function to return a single array, so that I don't have to evaluate it in a list comprehension, but the results of that were also similar.
I tried it myself with the merged version of the SineGaussian waveform in ripple and found that vmap is instead quicker. Here is the code:
```python
import time

import jax
import numpy as np
import jax.numpy as jnp
from jax import vmap
from ripplegw.waveforms import SineGaussian

duration = 10.0
sampling_frequency = 4096
dt = 1 / sampling_frequency
times = jnp.arange(-duration / 2, duration / 2, dt)
print(times.shape)

n_waveforms = 1000
quality = jnp.linspace(3, 100, n_waveforms)
frequency = jnp.logspace(1, 3, n_waveforms)
hrss = jnp.logspace(-23, -6, n_waveforms)
phase = jnp.linspace(0, 2 * np.pi, n_waveforms)
eccentricity = jnp.linspace(0, 0.99, n_waveforms)
theta_ripple = np.array([quality, frequency, hrss, phase, eccentricity]).T
print(theta_ripple.shape)

@jax.jit
def waveform(theta):
    return SineGaussian.gen_SineGaussian_hphc(times, theta)

print("JIT compiling")
waveform(theta_ripple[0])[0].block_until_ready()
print("Finished JIT compiling")

# Time a python loop over the parameter sets.
start = time.time()
for t in theta_ripple:
    waveform(t)[0].block_until_ready()
end = time.time()
print("Ripple waveform call takes: %.6f ms" % ((end - start) * 1000 / n_waveforms))

# Time the vmapped version (warm-up call first to compile).
func = vmap(waveform)
func(theta_ripple)[0].block_until_ready()

start = time.time()
hp_batch = func(theta_ripple)[0].block_until_ready()
end = time.time()
print(
    "Vmapped ripple waveform call takes: %.6f ms" % ((end - start) * 1000 / n_waveforms)
)
print(hp_batch.shape)
```
My output gives:
> (1000, 5)
> JIT compiling
> Finished JIT compiling
> Ripple waveform call takes: 0.189530 ms
> Vmapped ripple waveform call takes: 0.071345 ms
> (1000, 40960)
Note I just did this directly on my laptop on the CPU so the effect will be further magnified when using a GPU. Also, I had to remove the reshapes which now incorrectly add dimensions once you use vmap. The resulting array should be shape (batch_dimension, time_grid).
It does indeed look like vmap is faster than looping over the parameters in a for loop. The way I had tested it was to send all the parameters into the function at once and call reshape inside it; that was faster than removing the reshape and running vmap. I'm not sure why this is happening. The jitted versions (with and without reshape) give identical results, so I think it should be fine to leave it this way, although it does warrant some investigation into why un-jitted vmap performs worse than reshaping.
I think this overall makes sense, once you add the jit and your manual reshaping I think this is basically manually vectorizing the function and so it should perform similarly to vmap + jit. Overall though, it's not good practice in Jax to add this kind of manual reshaping when you can instead use vmap :) vmapping doesn't do the jit for you, so this is required to make it fast!
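To illustrate the equivalence with a toy function (illustrative, not ripple's actual code): manual broadcasting with reshapes and `jax.vmap` compute the same batched result, and jit applies to either form.

```python
import jax
import jax.numpy as jnp

def toy(t, theta):
    f0, amp = theta
    return amp * jnp.cos(2 * jnp.pi * f0 * t)

t = jnp.linspace(0.0, 0.1, 128)
thetas = jnp.array([[50.0, 1.0], [75.0, 0.3]])       # batch of 2

# Manual vectorization: explicit reshapes plus broadcasting.
f0s, amps = thetas[:, 0:1], thetas[:, 1:2]           # shape (2, 1)
manual = amps * jnp.cos(2 * jnp.pi * f0s * t[None, :])

# The same computation through vmap over the parameter axis.
auto = jax.vmap(toy, in_axes=(None, 0))(t, thetas)

# Expected to agree to float rounding; both have shape (2, 128).
print(jnp.max(jnp.abs(manual - auto)))
```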
This PR adds the `SineGaussian` waveform to ripple, along with a detailed Python notebook showing the mismatch between the LALInference and Jax implementations.