dingo-gw / dingo

Dingo: Deep inference for gravitational-wave observations

Fixing batching in GNPE inference code #85

Closed · nihargupte-ph closed 2 years ago

nihargupte-ph commented 2 years ago

Looks like currently the batching only applies to `model.sample(*x, batch_size=batch_size)` but not to the rest of the transform pipeline (i.e., the pre- and post-GNPE transformations `gnpe_transforms_pre` and `gnpe_transforms_post`). By batching these as well, we are able to sample a larger total number of points.
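
For reference, a minimal sketch of what batching the full pipeline could look like (the function and argument names mirror the discussion above, but the exact signatures are assumptions, not dingo's real API):

```python
import torch

def sample_with_batched_transforms(model, context, gnpe_transforms_pre,
                                   gnpe_transforms_post, num_samples,
                                   batch_size):
    # Hypothetical sketch: apply the pre/post GNPE transforms per batch,
    # not just the flow sampling itself.
    samples = []
    for start in range(0, num_samples, batch_size):
        n = min(batch_size, num_samples - start)
        # Run the pre-GNPE transforms on this batch only.
        x = gnpe_transforms_pre(context, n)
        # Draw n samples from the flow, as before.
        y = model.sample(*x, batch_size=n)
        # Apply the post-GNPE transforms per batch as well, instead of
        # once over the full (potentially huge) sample tensor.
        samples.append(gnpe_transforms_post(y))
    return torch.cat(samples, dim=0)
```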

Note: there is still a small memory issue remaining. Namely, batched or not, since all the sampled tensors are stored on the GPU, it will run out of memory. The fix is to send the sampled tensors to `.cpu()` inside the batch loop and transfer them back afterwards. I was trying to implement an option for this but it was too clunky. Instead, maybe the memory wrapper (i.e., sampling with GNPE multiple times) is better. It is slower, though, since we have to run through the GNPE loop more times.
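
A sketch of that offload pattern, assuming PyTorch tensors (the loop structure is illustrative, not dingo's actual code):

```python
import torch

def sample_offloaded(model, x, num_samples, batch_size, device="cuda"):
    # Hypothetical sketch: keep only one batch of samples on the GPU.
    batches = []
    for start in range(0, num_samples, batch_size):
        n = min(batch_size, num_samples - start)
        y = model.sample(*x, batch_size=n)
        # Move the batch off the GPU right away, so peak GPU memory is
        # bounded by a single batch rather than the full sample set.
        batches.append(y.cpu())
    samples = torch.cat(batches, dim=0)
    # Transfer back to the device only when a downstream stage needs it.
    return samples.to(device)
```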

nihargupte-ph commented 2 years ago

Deleting the saving of the batch waveforms seems to have fixed the issue! Now it runs properly. I think what I said earlier about the tensors getting sent to the GPU regardless was wrong, since it works now.

I was trying to figure out how to implement the batching outside the GNPE loop, but couldn't quite work it out. Maybe I'm missing some understanding there. If the batching happens outside, that would mean multiple GNPE loops per dataset, right? Is that better than multiple batches per GNPE loop?
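
For concreteness, here is a toy sketch of the two orderings I mean (the stand-in functions are hypothetical, just enough to make the loop structure run):

```python
import torch

# Toy stand-ins; in dingo these would be the flow's sample call and the
# GNPE proxy transforms.
def gnpe_transforms_pre(batch):
    return batch + 0.1 * torch.randn_like(batch)

def model_sample(proxies):
    return proxies + torch.randn_like(proxies)

batches = list(torch.randn(50_000, 2).split(10_000))
num_gnpe_iterations = 30

# Option A: batches inside the GNPE loop (current approach). One GNPE
# loop per dataset; each iteration sweeps over all batches.
samples_a = batches
for _ in range(num_gnpe_iterations):
    samples_a = [model_sample(gnpe_transforms_pre(b)) for b in samples_a]

# Option B: batching outside the GNPE loop. Each batch runs its own
# full GNPE loop, i.e. multiple GNPE loops per dataset.
samples_b = []
for b in batches:
    for _ in range(num_gnpe_iterations):
        b = model_sample(gnpe_transforms_pre(b))
    samples_b.append(b)
```

If the GNPE iterations act independently per sample, the two orderings should give equivalent samples, so the trade-off would mainly be memory footprint versus loop overhead.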

I can also put those dict comprehensions into functions if it makes the code more readable.

nihargupte-ph commented 2 years ago

[Images: Comparison_1 and Comparison_2]

Running `main` and `memory-hotfix`: 5,000 samples in the first image, 10,000 samples in the second. Also merged `get_corrected_sky_position` with `sample_posterior_of_injection`.