Closed EthanMarx closed 4 months ago
Currently the `ModelCheckpoint` will fail at tracing the model when using spectrogram inputs, since it assumes the input is 1D.

We can generalize this by inferring whether we are in the 2D or 1D regime by inspecting `pl_module.model`:
```python
def on_train_end(self, trainer, pl_module):
    ...
    if isinstance(pl_module.model, SupervisedTimeDomainArchitecture):
        sample_input = torch.randn(1, datamodule.num_ifos, kernel_size)
    else:
        sample_input = torch.randn(1, datamodule.num_ifos, dim1, dim2)
    ...
```
Or, we can infer the input size directly from the dataloader:
```python
def on_train_end(self, trainer, pl_module):
    ...
    # train_dataloader() returns a DataLoader, which is an iterable,
    # not an iterator, so wrap it in iter() before calling next()
    batch = next(iter(trainer.datamodule.train_dataloader()))
    batch = trainer.datamodule.augment(batch)
    sample_input = batch[0]
    ...
```
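For illustration, here is a minimal self-contained sketch of the dataloader-based approach. The dataset, shapes, and model below are hypothetical stand-ins (not this project's classes); the point is that the sample input's shape comes straight from the loader, so the same tracing code handles 1D and 2D inputs without any `isinstance` checks:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical spectrogram-like data: (batch, ifos, freq, time)
dataset = TensorDataset(torch.randn(8, 2, 64, 32))
loader = DataLoader(dataset, batch_size=4)

# Grab one batch and slice off a single example to use as the
# tracing input; its shape is inferred, never hard-coded
(batch,) = next(iter(loader))
sample_input = batch[:1]

# Stand-in model; any module accepting this input shape would do
model = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(2 * 64 * 32, 1),
)
traced = torch.jit.trace(model, sample_input)
print(tuple(sample_input.shape))
```

Swapping the dataset for 1D time-series data of shape `(batch, ifos, kernel_size)` would require no changes to the tracing logic, which is the advantage of this approach over branching on the architecture type.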