ML4GW / aframev2

Detecting binary black hole mergers in LIGO with neural networks
MIT License

Generalize `ModelCheckpoint` jit tracing for frequency and time domain use cases #140

Closed (EthanMarx closed 4 months ago)

EthanMarx commented 4 months ago

Currently, `ModelCheckpoint` fails when tracing the model on spectrogram inputs, because it assumes the input is 1D.

We can generalize this by inferring whether we are in the 2D or 1D regime by inspecting the pl_module.model:

def on_train_end(self, trainer, pl_module):
    ...
    if isinstance(pl_module.model, SupervisedTimeDomainArchitecture):
        sample_input = torch.randn(1, datamodule.num_ifos, kernel_size)
    else:
        sample_input = torch.randn(1, datamodule.num_ifos, dim1, dim2)
    ...
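
A minimal sketch of this first option, written so it runs without torch: it only computes the shape that would be passed to `torch.randn`. The classes below are stand-ins, `SpectrogramArchitecture` and `spec_shape` are hypothetical names used for illustration, and the numbers are arbitrary.

```python
from typing import Tuple


class SupervisedTimeDomainArchitecture:
    """Stand-in for the real time-domain architecture base class."""


class SpectrogramArchitecture:
    """Hypothetical 2D (spectrogram) architecture, for illustration only."""


def sample_input_shape(
    model: object,
    num_ifos: int,
    kernel_size: int,
    spec_shape: Tuple[int, int],
) -> Tuple[int, ...]:
    """Return the shape of a single-example batch to trace the model with."""
    if isinstance(model, SupervisedTimeDomainArchitecture):
        # 1D regime: (batch, channels, time)
        return (1, num_ifos, kernel_size)
    # 2D regime: (batch, channels, dim1, dim2)
    return (1, num_ifos, *spec_shape)


# Arbitrary example values, not taken from the project configuration
time_shape = sample_input_shape(
    SupervisedTimeDomainArchitecture(), 2, 3072, (64, 128)
)
spec_shape = sample_input_shape(SpectrogramArchitecture(), 2, 3072, (64, 128))
print(time_shape)
print(spec_shape)
```

In the callback, the returned tuple would be unpacked as `torch.randn(*shape)`. The downside of this approach is that the callback still needs to know `dim1` and `dim2` from somewhere.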

Or, we can infer the input size directly from the dataloader:

def on_train_end(self, trainer, pl_module):
    ...
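    # a DataLoader is iterable but not an iterator, so wrap it in iter()
    batch = next(iter(trainer.datamodule.train_dataloader()))
    # run the datamodule's augmentation so the shape matches the model input
    batch = trainer.datamodule.augment(batch)
    sample_input = batch[0]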
    ...
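
One caveat when pulling a batch this way, assuming `train_dataloader()` returns a standard `torch.utils.data.DataLoader`: the loader is iterable but is not itself an iterator, so `next()` has to be applied to `iter(loader)`. The sketch below shows the difference with a hypothetical stand-in class (`FakeLoader`), so it runs without torch.

```python
class FakeLoader:
    """Hypothetical stand-in for a DataLoader: iterable, not an iterator."""

    def __init__(self, batches):
        self.batches = batches

    def __iter__(self):
        return iter(self.batches)


def first_batch(loader):
    """Fetch the first batch from any iterable loader."""
    return next(iter(loader))


loader = FakeLoader([("batch0",), ("batch1",)])

try:
    # Fails: FakeLoader defines __iter__ but not __next__
    next(loader)
    direct_next_works = True
except TypeError:
    direct_next_works = False

batch = first_batch(loader)
print(direct_next_works, batch)
```

Compared with the `isinstance` check, this keeps the callback agnostic to the input domain, at the cost of loading one batch at the end of training.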