TorchDSP / torchsig

TorchSig is an open-source signal processing machine learning toolkit based on the PyTorch data handling pipeline.
MIT License

Fix example Jupyter notebook #70

Closed (jvincent131 closed this issue 1 year ago)

jvincent131 commented 1 year ago

The example notebook 03_example_widebandsig53_dataset gives an error: AttributeError: module 'torchsig.transforms.transforms' has no attribute 'Spectrogram'

sei-cabidi commented 1 year ago

I ran into this issue as well. According to issues #60 and #47, the authors are aware of the problem. As a workaround, I switched to the v0.2.0 branch and changed the first two cells to the following. Note that the ST. prefix is removed from Spectrogram, Normalize, and DescToBBoxSignalDict, since the wildcard imports below bring those names in directly.

import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

import torchsig
import torchsig.models
from torchsig.models.spectrogram_models import detr_b0_nano
import torchsig.transforms as ST
from torchsig.transforms.transforms import *
from torchsig.transforms.signal_processing.sp import *
from torchsig.transforms.expert_feature.eft import *
from torchsig.transforms.target_transforms.target_transforms import *
from torchsig.datasets.wideband_sig53 import WidebandSig53

and, for the second cell:

# Specify WidebandSig53 Options
root = 'wideband_sig53/'
train = True
impaired = False
fft_size = 512
num_classes = 1

transform = Compose([
    Spectrogram(nperseg=fft_size, noverlap=0, nfft=fft_size, mode='complex'),
    Normalize(norm=np.inf, flatten=True),
])

target_transform = Compose([
    DescToBBoxSignalDict(),
])

# Instantiate the training WidebandSig53 Dataset
wideband_sig53_train = WidebandSig53(
    root=root,
    train=train,
    impaired=impaired,
    transform=transform,
    target_transform=target_transform,
    regenerate=False,
    use_signal_data=True,
    gen_batch_size=1,
    use_gpu=True,
)

# Instantiate the validation WidebandSig53 Dataset
train = False
wideband_sig53_val = WidebandSig53(
    root=root,
    train=train,
    impaired=impaired,
    transform=transform,
    target_transform=target_transform,
    regenerate=False,
    use_signal_data=True,
    gen_batch_size=1,
    use_gpu=True,
)

# Retrieve a sample and print out information
idx = 0
data, label = wideband_sig53_val[idx]
print("Training Dataset length: {}".format(len(wideband_sig53_train)))
print("Validation Dataset length: {}".format(len(wideband_sig53_val)))
print("Data shape: {}".format(data.shape))
print("Label: {}".format(label))

Hopefully this can help you make some progress until a fix is pushed.
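
If the notebook has to run against more than one torchsig layout, one option is to look a transform up by name across several candidate modules instead of relying on wildcard imports. The following is only a rough sketch: resolve_attr and CANDIDATES are hypothetical names introduced here, not part of torchsig, and which module actually defines Spectrogram in a given version is an assumption based on the import paths listed in the first cell.

```python
import importlib


def resolve_attr(name, module_paths):
    """Return the first attribute called `name` found in `module_paths`.

    Modules that cannot be imported, or that lack the attribute, are skipped.
    (Hypothetical helper for illustration; not part of torchsig.)
    """
    for path in module_paths:
        try:
            module = importlib.import_module(path)
        except ImportError:
            continue  # this module path does not exist in the installed version
        if hasattr(module, name):
            return getattr(module, name)
    raise AttributeError(
        "{!r} not found in any of: {}".format(name, ", ".join(module_paths))
    )


# Candidate locations copied from the imports in the first cell. Which one
# holds a given transform depends on the installed version (assumption).
CANDIDATES = [
    "torchsig.transforms.transforms",
    "torchsig.transforms.signal_processing.sp",
    "torchsig.transforms.target_transforms.target_transforms",
]

# Usage (requires torchsig to be installed):
# Spectrogram = resolve_attr("Spectrogram", CANDIDATES)
# DescToBBoxSignalDict = resolve_attr("DescToBBoxSignalDict", CANDIDATES)
```

Compared with the wildcard imports, this fails with a message listing every module it searched, which makes it easier to see which layout is installed.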