nihargupte-ph closed this 2 years ago
Deleted the saving of the batch waveforms, which seems to have fixed the issue! Now it runs properly. What I said earlier about the tensors getting sent to the GPU regardless was wrong, since it works now.
I was trying to figure out how to implement the batching outside the GNPE loop but couldn't quite work it out; maybe I'm missing some understanding there. If the batching happens outside, that would mean multiple GNPE loops per dataset, right? Is that better than multiple batches per GNPE loop?
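For concreteness, the two orderings being discussed could be sketched like this (pure pseudocode with hypothetical names; `step` stands in for one GNPE iteration over a batch, and `gnpe_iterations` for the loop length — neither is the actual dingo API):

```python
# Option A: batching inside the GNPE loop (the current approach):
# every GNPE iteration processes all batches before the next iteration starts.
def sample_batched_inside(step, batches, gnpe_iterations):
    state = list(batches)
    for _ in range(gnpe_iterations):
        state = [step(b) for b in state]  # all batches advance together
    return state

# Option B: batching outside the GNPE loop:
# each batch runs its own complete GNPE loop ("multiple GNPE loops per dataset").
def sample_batched_outside(step, batches, gnpe_iterations):
    out = []
    for b in batches:
        for _ in range(gnpe_iterations):
            b = step(b)  # this batch finishes all iterations before the next batch starts
        out.append(b)
    return out

# With a step that depends only on its own batch, both orderings agree:
step = lambda x: x + 1
assert sample_batched_inside(step, [0, 10], 3) == sample_batched_outside(step, [0, 10], 3)
```

Under this toy model the results are identical; the practical difference is which tensors have to be live at once, which is what the memory discussion below is about.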
I can also put those dict comprehensions into functions if it makes the code more readable.
Running main and memory-hotfix: 5_000 samples in the first image, 10_000 samples in the second. Merged `get_corrected_sky_position` with `sample_posterior_of_injection`.
Looks like the batching currently only applies to

```python
model.sample(*x, batch_size=batch_size)
```

but not to the rest of the transform pipeline (i.e. the pre- and post-GNPE transformations `gnpe_transforms_pre` and `gnpe_transforms_post`). By batching these as well, we are able to sample a larger total number of points.

Note: there is still a small memory issue remaining. Regardless of batching, since all the sampled tensors are stored on the GPU, it will eventually run out of memory. The fix is to send the sampled tensors to `.cpu()` while looping over the batches and transfer them back as needed. I tried to implement an option for this but it was too clunky. Instead, maybe the memory wrapper (i.e. sampling with GNPE multiple times) is better, although it is slower since we have to run through the GNPE loop more times.
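A minimal sketch of the `.cpu()` offload idea (all names here are hypothetical; `sample_fn` stands in for `model.sample` plus the surrounding transforms): draw each batch on the device, move it to the CPU immediately, and concatenate at the end, so only one batch ever occupies device memory.

```python
import torch

def batched_sample(sample_fn, num_samples, batch_size):
    """Draw num_samples in chunks, offloading each chunk to the CPU.

    sample_fn(n) is assumed to return an (n, dim) tensor on the
    sampling device; only one batch stays on that device at a time.
    """
    chunks = []
    remaining = num_samples
    while remaining > 0:
        n = min(batch_size, remaining)
        with torch.no_grad():
            batch = sample_fn(n)      # lives on the GPU (or wherever sample_fn puts it)
        chunks.append(batch.cpu())    # free device memory right away
        remaining -= n
    return torch.cat(chunks, dim=0)   # full result assembled on the CPU

# Toy usage: a fake "model" that just draws standard normals.
samples = batched_sample(lambda n: torch.randn(n, 15), 10_000, 4096)
```

The trade-off versus the memory wrapper is that this keeps a single pass through the GNPE loop but pays for host/device transfers per batch, rather than re-running the whole loop.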