ML4GW / aframe-legacy

Detecting binary black hole mergers in LIGO with neural networks
MIT License

Profile injection steps #17

Closed alecgunny closed 2 years ago

alecgunny commented 2 years ago

Investigate current injection timing breakdown to see what can be done online vs. during dataset generation

alecgunny commented 2 years ago

As referenced in https://github.com/ML4GW/BBHNet/issues/11#issuecomment-1063295567

alecgunny commented 2 years ago

Related to #10

wbenoit26 commented 2 years ago

I'm not sure how well these results will copy over, but the main takeaway is that >90% of the time to generate these signals comes from two lines: getting the time-domain strain for the given set of parameters, and performing the filtering. Given this, I think it makes sense to generate the training data with fixed times. The below is from simulating 2000 signals.



```
Total time: 47.3069 s
File: simulate.py
Function: generate_gw at line 13

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    13                                           @profile
    14                                           def generate_gw(
    15                                               sample_params,
    16                                               ifo,
    17                                               waveform_generator=None,
    18                                           ):
    19                                               """Generate gravitational-wave events
    20
    21                                               Arguments:
    22                                               - sample_params: dictionary of GW parameters
    23                                               - ifo: interferometer
    24                                               - waveform_generator: bilby.gw.WaveformGenerator with appropriate params
    25                                               """
    26
    27         4      12991.0   3247.8      0.0      sample_params = [
    28         2         27.0     13.5      0.0          dict(zip(sample_params, col)) for col in zip(*sample_params.values())
    29                                               ]
    30         2          6.0      3.0      0.0      n_sample = len(sample_params)
    31
    32         2          4.0      2.0      0.0      if waveform_generator is None:
    33                                                   waveform_generator = bilby.gw.WaveformGenerator(
    34                                                       duration=8,
    35                                                       sampling_frequency=16384,
    36                                                       frequency_domain_source_model=lal_binary_black_hole,
    37                                                       parameter_conversion=convert_to_lal_binary_black_hole_parameters,
    38                                                       waveform_arguments={
    39                                                           "waveform_approximant": "IMRPhenomPv2",
    40                                                           "reference_frequency": 50,
    41                                                           "minimum_frequency": 20,
    42                                                       },
    43                                                   )
    44
    45         2         19.0      9.5      0.0      sample_rate = waveform_generator.sampling_frequency
    46         2         19.0      9.5      0.0      waveform_duration = waveform_generator.duration
    47         2          6.0      3.0      0.0      waveform_size = int(sample_rate * waveform_duration)
    48
    49         2         57.0     28.5      0.0      signals = np.zeros((n_sample, waveform_size))
    50
    51         2     296875.0 148437.5      0.6      ifo = bilby.gw.detector.get_empty_interferometer(ifo)
    52         4       2837.0    709.2      0.0      b, a = sig.butter(
    53         2          4.0      2.0      0.0          N=8,
    54         2          6.0      3.0      0.0          Wn=waveform_generator.waveform_arguments["minimum_frequency"],
    55         2          3.0      1.5      0.0          btype="highpass",
    56         2         19.0      9.5      0.0          fs=waveform_generator.sampling_frequency,
    57                                               )
    58      4002      11829.0      3.0      0.0      for i, p in enumerate(sample_params):
    59
    60                                                   # For less ugly function calls later on
    61      4000      12250.0      3.1      0.0          ra = p["ra"]
    62      4000       7598.0      1.9      0.0          dec = p["dec"]
    63      4000      10613.0      2.7      0.0          geocent_time = p["geocent_time"]
    64      4000       7490.0      1.9      0.0          psi = p["psi"]
    65
    66                                                   # Generate signal in IFO
    67      4000   32093796.0   8023.4     67.8          polarizations = waveform_generator.time_domain_strain(p)
    68      4000      95928.0     24.0      0.2          signal = np.zeros(waveform_size)
    69     12000      47026.0      3.9      0.1          for mode, polarization in polarizations.items():
    70                                                       # Get ifo response
    71      8000    1600375.0    200.0      3.4              response = ifo.antenna_response(ra, dec, geocent_time, psi, mode)
    72      8000   11388709.0   1423.6     24.1              signal += response * sig.filtfilt(b, a, polarization)
    73
    74                                                   # Total shift = shift to trigger time + geometric shift
    75      4000      10765.0      2.7      0.0          dt = waveform_duration / 2.0
    76      4000     299508.0     74.9      0.6          dt += ifo.time_delay_from_geocenter(ra, dec, geocent_time)
    77      4000     546601.0    136.7      1.2          signal = np.roll(signal, int(np.round(dt * sample_rate)))
    78
    79      4000     861554.0    215.4      1.8          signals[i] = signal
    80
    81         2          4.0      2.0      0.0      return signals
```
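As a quick check on the ">90%" figure, the two dominant lines can be summed straight from the table. This is only arithmetic on the numbers above. The one assumption is that the Time column is in microseconds (line_profiler's default unit), which is consistent with the 47.3069 s total but isn't printed in the pasted output.

```python
# Figures copied from the profile above.
TOTAL_SECONDS = 47.3069
TIMER_UNIT = 1e-6  # assumption: microseconds, not shown in the pasted output

line_times = {
    "time_domain_strain": 32093796.0,  # profile line 67
    "filtfilt": 11388709.0,  # profile line 72
}


def share_of_total(ticks: float) -> float:
    """Fraction of the total runtime spent on one profiled line."""
    return ticks * TIMER_UNIT / TOTAL_SECONDS


shares = {name: share_of_total(t) for name, t in line_times.items()}
combined = sum(shares.values())

for name, share in shares.items():
    print(f"{name}: {share:.1%}")
print(f"combined: {combined:.1%}")
```

The per-line shares reproduce the % Time column (67.8 and 24.1), and together they account for roughly 92% of the runtime.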
alecgunny commented 2 years ago

Really quick, can you edit that comment to wrap those results in 3 tick marks instead of 1?

But ok yeah in any case it definitely seems like at least for now it makes sense to bake the sky/time parameters into the dataset and leave this as a future avenue of exploration if we can't fit well enough on the dataset we generate.

The one thing that occurs to me is, does the polarization generation depend on the sky/time parameters? If not, it at least seems possible that we could save the polarizations as a dataset (with an entry for each one of the keys in the polarizations dict mapping to an array of shape (num_waveforms, sample_rate * 8)) and then, at data loading time, apply the filter in an array-like fashion (with axis=-1) to each batch of sampled polarizations before sampling sky/time parameters and using them to compute and apply the IFO responses to these filtered polarizations. Though maybe applying the filtering to the 1-second samples would create too many artifacts.
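A rough sketch of the "filter a whole batch along axis=-1" idea, assuming the polarizations are already stored as arrays of shape (num_waveforms, num_samples). The function name `highpass_batch` and the 2048 Hz sample rate are made up for illustration. It also swaps the `(b, a)` filter from the profiled code for second-order sections (`sosfiltfilt`), which is my substitution rather than anything the current code does, and it filters full 8-second waveforms rather than 1-second samples.

```python
import numpy as np
from scipy import signal as sig


def highpass_batch(polarizations, sample_rate, minimum_frequency=20.0, order=8):
    """Zero-phase high-pass filter every waveform along the last (time) axis.

    `polarizations` maps a mode name ("plus", "cross") to an array whose
    last axis is time, e.g. shape (num_waveforms, num_samples).
    """
    # Assumption: SOS form instead of the (b, a) form in the profiled code,
    # because it is better conditioned for a high-order, low-cutoff filter.
    sos = sig.butter(
        order, minimum_frequency, btype="highpass", fs=sample_rate, output="sos"
    )
    return {
        mode: sig.sosfiltfilt(sos, batch, axis=-1)
        for mode, batch in polarizations.items()
    }


# Stand-in data: white noise instead of real waveforms.
rng = np.random.default_rng(0)
sample_rate = 2048  # hypothetical; the profiled code uses 16384
batch = {
    "plus": rng.standard_normal((4, sample_rate * 8)),
    "cross": rng.standard_normal((4, sample_rate * 8)),
}
filtered = highpass_batch(batch, sample_rate)
print({mode: arr.shape for mode, arr in filtered.items()})
```

Filtering the batch in one call gives the same result as filtering each row separately, so the per-waveform Python loop around the filter goes away.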

wbenoit26 commented 2 years ago

Fixed!

Yes, I believe that those parameters would affect the polarizations, so we wouldn't be able to manage around that step.

alecgunny commented 2 years ago

Got it sounds good, thanks for looking into this

EthanMarx commented 2 years ago

@wbenoit26 @alecgunny So the polarizations themselves don't depend on the sky/time parameters. How the two polarization modes (h-plus and h-cross) are projected onto the detectors depends on the sky localization parameters and the GPS time.

This step is taking the longest because it is the step that produces the raw waveform, i.e. the step we want to do in bulk beforehand. (Just got confirmation from Bilby experts.)

So, I'm reading this as: we generate purely raw waveforms (no times, ra, dec, or gpstimes associated), and do the projection at train time.
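For what the train-time projection could look like on a batch, here is a minimal numpy sketch that mirrors the loop body in the profiled code (`signal += response * polarization`, then `np.roll`). The function `project_batch` and its argument layout are hypothetical. The antenna responses and delays are taken as precomputed inputs, which in practice would come from something like `ifo.antenna_response` and `ifo.time_delay_from_geocenter`.

```python
import numpy as np


def project_batch(polarizations, responses, delays, sample_rate):
    """Combine stored polarizations into strain for one interferometer.

    polarizations: mode -> array of shape (batch, num_samples)
    responses:     mode -> array of shape (batch,), antenna response per waveform
    delays:        array of shape (batch,), total time shift in seconds
    """
    # Weighted sum over modes, broadcasting each response over the time axis.
    signal = sum(
        responses[mode][:, None] * polarizations[mode] for mode in polarizations
    )
    shifts = np.round(delays * sample_rate).astype(int)
    # np.roll is circular, same as in the profiled code.
    return np.stack([np.roll(row, shift) for row, shift in zip(signal, shifts)])


# Toy numbers, chosen only so the result is easy to check by hand.
sample_rate = 4
polarizations = {
    "plus": np.array([[1.0, 0.0, 0.0, 0.0]]),
    "cross": np.array([[0.0, 2.0, 0.0, 0.0]]),
}
responses = {"plus": np.array([0.5]), "cross": np.array([-1.0])}
delays = np.array([0.25])  # exactly one sample at 4 Hz
strain = project_batch(polarizations, responses, delays, sample_rate)
print(strain)
```

In the toy case the weighted sum is `[0.5, -2, 0, 0]`, and the one-sample shift moves it to `[0, 0.5, -2, 0]`.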

wbenoit26 commented 2 years ago

Oh, awesome. I'd thought that because the function took the full set of parameters, it was making use of at least the sky parameters somehow.

The filtering time is still kind of long, but maybe short enough? From the above timings, it looks to be 5-6 ms per waveform. I'll also dig up the email chain where we decided that we needed the filtering step and confirm it needs to happen this way.

Edit: Never mind, we can just do the filtering ahead of time and store those instead.

alecgunny commented 2 years ago

Ok this is all super interesting. Are the sampling mechanisms for time/sky parameters pretty straightforward? I.e. do we have to use Bilby to sample from some specific prior or can we just do some uniform sampling on the relevant ranges?

EthanMarx commented 2 years ago

It should be straightforward. We can use bilby or do it manually.
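A sketch of the "do it manually" option, assuming we want an isotropic sky: ra uniform on [0, 2π), dec drawn so that sin(dec) is uniform on [-1, 1], and psi uniform on [0, π). The function name is made up, and whether these ranges match the priors used elsewhere in the project is an assumption.

```python
import numpy as np


def sample_sky_parameters(num_samples, rng):
    """Draw isotropic sky positions and uniform polarization angles."""
    ra = rng.uniform(0.0, 2.0 * np.pi, num_samples)
    # Uniform in sin(dec) gives equal probability per unit solid angle;
    # sampling dec itself uniformly would over-weight the poles.
    dec = np.arcsin(rng.uniform(-1.0, 1.0, num_samples))
    psi = rng.uniform(0.0, np.pi, num_samples)
    return {"ra": ra, "dec": dec, "psi": psi}


rng = np.random.default_rng(0)
params = sample_sky_parameters(1000, rng)
print({name: values.shape for name, values in params.items()})
```

psi is included because the profiled code passes it to `ifo.antenna_response` along with ra and dec.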

wbenoit26 commented 2 years ago

The sky parameters are just ra and dec, right? Those are sampled and stored at the same time as the rest of the parameters. And then the geocent time should match up to the background time.

EthanMarx commented 2 years ago

To your first point yes. I'm still trying to evaluate the implications of injecting with a gpstime that is not equivalent to the background time used. I think as long as we inject H1 and L1 at the same gpstime regardless of background time used, we should be okay. Erik will have a concrete answer.