dingo-gw / dingo

Dingo: Deep inference for gravitational-wave observations

dingo inference discussion #38

Closed max-dax closed 1 year ago

max-dax commented 2 years ago

My suggestion for the inference script is as follows:

dingo_inference 
  --model /path/to/model.pt
  --GPS_times 1126259462.4 1128678900.4
  --dataset_file /path/to/event_dataset.hdf5

Extensions

With these considerations, the interface would look as follows.

dingo_inference 
  --model /path/to/model.pt 
  --model_gnpe_init /path/to/pose_model.pt
  --events GW150914 GW151012
  --dataset_file /path/to/event_dataset.hdf5
max-dax commented 2 years ago

@stephengreen @jonaswildberger what do you think?

mpuerrer commented 2 years ago

Note that GW event names will only become available after the offline searches have done their work in calculating reliable significance estimates for interesting triggers and the boxes have been opened and checked for data segments. If we want to analyze data at an earlier stage, before this information becomes available, there won't be any nice GW event names, only GPS times or GraceDB numbers G....... (see e.g. https://gracedb.ligo.org/latest/).

jonaswildberger commented 2 years ago

I think this interface makes perfect sense. I have 2 comments/questions:

  • Could some inconsistent settings in a given dataset_file be made consistent like interpolating or coarsening PSDs to a different resolution? However, it might not be worth this effort if a re-download is sufficiently fast

  • I think specifying either GPS times or an event name should both be possible and the script should infer which one was passed. Could we use a database for the mapping of event names to GPS times instead of maintaining this ourselves?

mpuerrer commented 2 years ago

I think this interface makes perfect sense. I have 2 comments/questions:

  • Could some inconsistent settings in a given dataset_file be made consistent like interpolating or coarsening PSDs to a different resolution? However, it might not be worth this effort if a re-download is sufficiently fast

I would tend to avoid touching the data and rather download the correct data instead, since modifying it could introduce small differences.

  • I think specifying either GPS times or an event name should both be possible and the script should infer which one was passed. Could we use a database for the mapping of event names to GPS times instead of maintaining this ourselves?

Good idea. I'm not sure there is such a database. GraceDB does not list GW names.

max-dax commented 2 years ago

Thanks for joining the discussion, Michael. I was not aware you had time for it, otherwise I would of course also have tagged you.

If we want to analyze data at an earlier stage, before this information becomes available, there won't be any nice GW event names, only GPS times or GraceDB numbers

Good point. This is the major reason why imo the GPS times should be the 'fundamental' property, i.e. the keys of the dataset. My suggested extension to event names would work on top of that: there could be a user-defined dictionary mapping event names to GPS times in the metadata. The first time the user analyses an event via ... --event GW150914 1126259462.4 the GPS time needs to be specified, but after that, the user can just call ... --event GW150914. This would be convenient, since the event names are easier to remember than the GPS times.
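
A minimal sketch of how such a mapping could be kept in the HDF5 metadata (the resolve_event helper and the "event_names" attribute are hypothetical, not existing dingo code):

import json
import h5py

def resolve_event(dataset_file, name, gps_time=None):
    # Return the GPS time for `name`; register the mapping on first use.
    with h5py.File(dataset_file, "a") as f:
        mapping = json.loads(f.attrs.get("event_names", "{}"))
        if gps_time is not None:                  # first call: --event GW150914 1126259462.4
            mapping[name] = float(gps_time)
            f.attrs["event_names"] = json.dumps(mapping)
        return mapping[name]                      # later calls: --event GW150914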

I would tend to avoid touching the data and rather download the correct data instead, since modifying it could introduce small differences.

Yes, at least for everything that has an impact on the data (such as windowing). There are a few settings that can safely be applied afterwards, like f_min. But I agree that it's probably best to not interpolate PSDs.

I think specifying either GPS times or an event name should both be possible and the script should infer which one was passed. Could we use a database for the mapping of event names to GPS times instead of maintaining this ourselves?

Yes, this would be a nice extension.

mpuerrer commented 2 years ago

Thanks for joining the discussion, Michael. I was not aware you had time for it, otherwise I would of course also have tagged you.

No worries Max!

If we want to analyze data at an earlier stage, before this information becomes available, there won't be any nice GW event names, only GPS times or GraceDB numbers

Good point. This is the major reason why imo the GPS times should be the 'fundamental' property, i.e. the keys of the dataset. My suggested extension to event names would work on top of that: there could be a user-defined dictionary mapping event names to GPS times in the metadata. The first time the user analyses an event via ... --event GW150914 1126259462.4 the GPS time needs to be specified, but after that, the user can just call ... --event GW150914. This would be convenient, since the event names are easier to remember than the GPS times.

I agree that GW event names are a nice convenience for us humans to remember some of the events we analyze. However, given the way analyses are done within the LVK, if you want to analyze events that have been vetted to some degree by detchar and highlighted by searches, the place to look is Superevents on GraceDB (roughly the union of what was found by different searches at the same trigger time).

Therefore, we should at some point add support to run on a specified Superevent number. GraceDB contains a mapping of the G-number to GPS times. (In fact, contrary to what I said earlier, there is also a way to search for GW event names once they have been input -- see https://gracedb.ligo.org/documentation/queries.html -- which would then provide a mapping to GPS times through this service).

It appears that the way to do this is not to use the GraceDB API directly, but to go through the package https://ligo-gracedb.readthedocs.io/en/latest/. Full access will require us to specify LIGO credentials properly.
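
For reference, a rough sketch of resolving a superevent to a GPS time with that package (method and field names assumed from the public GraceDB REST API; full access requires proper LIGO credentials):

from ligo.gracedb.rest import GraceDb

client = GraceDb()                                 # defaults to the public GraceDB API URL
superevent = client.superevent("S190425z").json()  # query a superevent by its S-number
gps_time = superevent["t_0"]                       # trigger GPS time of the superevent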

max-dax commented 2 years ago
def perform_inference(model, GPS_time, event_dataset=None):
    raw_data = get_data(GPS_time, event_dataset)                  # download or load event strain + PSDs
    domain_data = data_to_domain(raw_data, model_data_settings)   # convert to the model's data domain
    nn_data = to_nn_input(domain_data)                            # transform to the NN input format
    samples = GNPE_inference(nn_data, model)                      # run (G)NPE sampling
    return samples
stephengreen commented 2 years ago

I think we discussed this a fair bit on Friday, but just to summarize, I believe we should think independently about the four components listed above, rather than as the high-level pipeline. Thinking further, I believe this should be three steps:

  1. Download and save the event data
  2. Process the data (convert it to the data domain of the model)
  3. Inference: I combined the last two steps since preparing NN-formatted data has to alternate with GNPE transforms.

One use case would be analyzing an entire run, so I see how the DingoDataset for multiple events could be useful. We also often want to analyze a single event as it comes in, so this should be able to handle that as well. Another possibility could be to have a single-event dataset, and use a list of these for multiple events. I'm not really sure what is best.

Also, the "process data" step could be implemented as a transform stored in the dataset, in the usual way. Maybe even aspects of the "inference" step could be set up this way as well, but one would have to figure out how to take care of the looping.

Finally, most of the settings can be taken from the model settings, saved in PosteriorModel.metadata. This includes the GNPE settings; however the number of iterations would have to be specified at inference time.

Anyway, I would suggest working first on the "inference" step, because then we can get up and running on injections. All the data download and saving is a bit of a pain. But up to you.

stephengreen commented 2 years ago

This also connects to the discussion in #14 about where to implement the truncation at f_min. It would be good to not have to consider whether frequency-domain data starts from f=0 or f=f_min when designing transforms and datasets. Also other codes such as Bilby expect frequency arrays to start from 0, so I would suggest we move the low-f truncation to the final step before plugging into the network, as in the old code.
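
As a toy illustration of that ordering (values assumed, nothing dingo-specific): keep arrays defined from f = 0 for the transforms and for Bilby compatibility, and truncate only as the final step before the network.

import numpy as np

delta_f, f_max, f_min = 0.125, 1024.0, 20.0
frequencies = np.arange(0.0, f_max + delta_f, delta_f)   # frequency array starting at f = 0
strain = np.zeros(len(frequencies), dtype=complex)       # placeholder frequency-domain strain
idx_min = int(round(f_min / delta_f))                    # first retained bin
strain_nn = strain[idx_min:]                             # low-f truncation right before the network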

max-dax commented 2 years ago

Thanks for the comments, Stephen. I agree with most of what you said.

  • Inference: I combined the last two steps since preparing NN-formatted data has to alternate with GNPE transforms.

This is not quite right. The GNPE transformations are applied to batched data and performed on the GPU. This allows implementing the time shifts as massive matrix multiplications, which is very fast. Otherwise I believe inference would be much slower (I have not checked by how much recently, but I remember thinking about this carefully in the research code).

This step could be broken down as:

  1. Prepare NN input:
     a. WhitenAndScaleStrain
     b. RepackageStrainsAndASDS
  2. GNPE inference (if using GNPE; otherwise straight NPE inference):
     a. sample initial pose with the init model
     b. blur the pose to obtain the pose proxy
     c. time-shift the batch (this works on a batch that is prepared for NPE, including truncation + repackaging)
     d. sample parameters, conditioned on the shifted batch and the pose proxies

Steps b-d are repeated either for a fixed number of iterations, or until a specified convergence criterion is fulfilled (e.g., JSD(proxy_n || proxy_n-1) < threshold).
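
A rough sketch of this loop, with all steps passed in as hypothetical callables (none of these names exist in the code base):

def gnpe_sample(nn_batch, sample_pose_init, blur_pose, time_shift,
                sample_parameters, extract_pose, num_iterations=30):
    pose = sample_pose_init(nn_batch)                # 2a: initial pose from the init model
    for _ in range(num_iterations):                  # repeat 2b-2d (or until convergence)
        proxy = blur_pose(pose)                      # 2b: blur pose to obtain the pose proxy
        shifted = time_shift(nn_batch, proxy)        # 2c: time-shift the NPE-ready batch
        samples = sample_parameters(shifted, proxy)  # 2d: sample, conditioned on batch + proxies
        pose = extract_pose(samples)                 # feed the new pose into the next iteration
    return samples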

I therefore think that splitting this step up makes sense. The logic is that the transformation to_nn_input prepares the data into the format required for standard NPE -- it does everything except for GNPE. The inference step then only adds the GNPE functionality on top of this. In my opinion this is the clearest separation. A related advantage is that for 2a (sampling initial pose parameters), we don't need any further transformation, but can just take the output of 1.

Let me know what you think.

stephengreen commented 2 years ago

I think this makes sense if we need the speed that comes from doing all the calculations on the GPU. However, the manipulations (in particular, the time translation) will be somewhat intricate to code up, as they rely on the particular form of a sample that is passed to the embedding network. If we change the form of the embedding network, or the data domain, we have to update this. It also means we can't use some of the already-existing transforms that we use for training, so we lose out on maintainability and we end up coding some things twice.

I haven't thought about it a lot, but I'm not entirely sure that we need this speed. For each parameter sample and GNPE step, we are talking about an O(waveform length) calculation on the CPU, plus the GPU/CPU transfer. Would these not be swamped by the passes through the network, which involve O(300) matrix multiplications for each sample and GNPE step, even if the network pass is on the GPU?

I realize that the CPU part is not parallelizable like during training, since the GNPE steps are not independent of each other. But the parameter samples are independent, so depending on how many samples we want, we could use a DataLoader to batch the parameter samples and achieve some parallelization. What do you think?

stephengreen commented 2 years ago

Update: I checked @jonaswildberger's profiling, which shows ProjectOnDetectors takes ~ 5e-4 s / waveform. So if we want 50k samples, this adds up to 25 s / GNPE step, which is obviously too long. I'm not sure how much we can get this down.

Another idea might be to ensure the existing time-translation code can act on batched torch tensors. It might not be too bad to (1) unpack the embedding network representation, (2) apply the time translation using the existing code, and (3) repack for the embedding network. This could all be done staying on the GPU. (PyTorch has support for complex numbers.)

max-dax commented 2 years ago

I will look into this in more detail tomorrow, but for now just the short answer:

Another idea might be to ensure the existing time-translation code can act on batched torch tensors.

Yes, this is how I think it should be done. In fact, I already added some placeholder code to the domain (https://github.com/dingo-gw/dingo-devel/blob/main/dingo/gw/domains.py, lines 211-223) at the very beginning. Essentially, if domain.time_translate automatically detects whether the data is batched or not, we can just reuse the transformations. We only have to make sure that the proxies are named correctly.

stephengreen commented 2 years ago

Okay makes sense. In fact, I don't see a reason why we can't just use torch tensors throughout the transforms (for both training and inference). They have all the same functionality as numpy arrays, except they can be put on the GPU for inference. Or we can write code that is indifferent to torch tensors / numpy arrays.

I think if we carefully write the transforms (e.g., take slices counting backwards in indices, use ..., etc.) then we don't need to re-write them for batched data.
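
For instance, a small numpy illustration (shapes assumed) of how "..." and last-axis indexing make the same slice work for a single sample and a batch:

import numpy as np

data_single = np.zeros((2, 3, 8193))      # (num_det, channel, f_bins)
data_batch = np.zeros((64, 2, 3, 8193))   # (batch_size, num_det, channel, f_bins)
idx_min = 160                             # example truncation index

# The same slice works in both cases: '...' absorbs any leading batch dimension,
# and indexing the last axis avoids assumptions about the number of axes.
truncated_single = data_single[..., idx_min:]
truncated_batch = data_batch[..., idx_min:]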

max-dax commented 2 years ago

I think if we carefully write the transforms (e.g., take slices counting backwards in indices, use ..., etc.) then we don't need to re-write them for batched data.

I think we do. The batched data is not complex; instead, the real and imaginary parts are separated into channels 0 and 1. So applying time shifts to a batch will look like this:

    # assumes torch and numpy (as np) are imported, and that
    #   y_in has shape (batch_size, num_det, 3, f_bins) with channels [Re, Im, ASD],
    #   detector_times has shape (batch_size, num_det),
    #   sample_frequencies has shape (f_bins,)
    y_out = torch.empty_like(y_in)

    # get local phases exp(-2*pi*i*f*t) per detector, split into cos and sin
    cos_txf = torch.empty((batch_size, num_det, f_bins))
    sin_txf = torch.empty((batch_size, num_det, f_bins))
    for idx_det in range(num_det):
        txf_det = torch.outer(detector_times[:, idx_det], sample_frequencies)
        cos_txf[:, idx_det, :] = torch.cos(-2 * np.pi * txf_det)
        sin_txf[:, idx_det, :] = torch.sin(-2 * np.pi * txf_det)

    # apply time shift: rotate (Re, Im) by the phase, leave the ASD channel unchanged
    y_out[:, :, 0, :] = cos_txf * y_in[:, :, 0, :] - sin_txf * y_in[:, :, 1, :]
    y_out[:, :, 1, :] = sin_txf * y_in[:, :, 0, :] + cos_txf * y_in[:, :, 1, :]
    y_out[:, :, 2, :] = y_in[:, :, 2, :]
max-dax commented 2 years ago

Regarding the use of DingoDataset, I feel that in its current form it is not a good way to handle the event data. DingoDataset is written for static datasets; it is not flexible regarding the addition of new data, which we would want for the event data.

stephengreen commented 2 years ago

Regarding the time-translations, as we discussed in the call, we could either use real and imaginary parts, or the torch complex datatype. Both would work, although it may be slightly simpler to implement the complex datatype. As long as the method is written in an overloaded way it should be good.
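
A rough sketch of what such an overloaded method could look like with the complex datatype (names, shapes and signature assumed, not dingo's actual API); broadcasting over leading dimensions handles both a single sample and a batch:

import numpy as np
import torch

def time_translate(data, dt, sample_frequencies):
    # data: complex tensor of shape (..., num_det, f_bins); dt: tensor of shape (..., num_det)
    f = torch.as_tensor(sample_frequencies, dtype=torch.float32)  # (f_bins,)
    phase = -2 * np.pi * dt.unsqueeze(-1) * f                     # (..., num_det, f_bins)
    return data * torch.polar(torch.ones_like(phase), phase)      # multiply by exp(i * phase)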

We also discussed the DingoDataset on the call. I agree it is really written for static datasets. You could try to extend it to non-static datasets, or we could use one dataset for each event, or we could have a completely new dataset type.

stephengreen commented 2 years ago

Keep in mind also that gwpy has caching, and LSC-connected machines will have local copies of entire runs. But I agree it could be convenient for us to have a dataset of events.