warner-benjamin / fastxtend

Train fastai models faster (and other useful tools)
https://fastxtend.benjaminwarner.dev
MIT License

Spectrogram.__init__() got an unexpected keyword argument 'norm' #25

Closed by kevinbird15 4 months ago

kevinbird15 commented 4 months ago

fastxtend version: 0.1.7
torchaudio version: 2.1.2+cpu

When trying to use fastxtend's audio module, I ran into the following stack trace:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[46], line 1
----> 1 Spectrogram()

File /opt/conda/lib/python3.10/site-packages/fastcore/transform.py:43, in _TfmMeta.__call__(cls, *args, **kwargs)
     41     getattr(cls,n).add(f)
     42     return f
---> 43 obj = super().__call__(*args, **kwargs)
     44 # _TfmMeta.__new__ replaces cls.__signature__ which breaks the signature of a callable
     45 # instances of cls, fix it
     46 if hasattr(obj, '__call__'): obj.__signature__ = inspect.signature(obj.__call__)

File /opt/conda/lib/python3.10/site-packages/fastxtend/audio/data.py:81, in Spectrogram.__init__(self, n_fft, win_length, hop_length, pad, window_fn, power, normalized, wkwargs, center, pad_mode, onesided, norm)
     79 else:
     80     self.multiple = False
---> 81     self.spec = tatfms.Spectrogram(n_fft=self.n_fft, win_length=self.win_length, hop_length=self.hop_length,
     82                                    pad=self.pad, window_fn=self.window_fn, power=self.power,
     83                                    normalized=self.normalized, wkwargs=self.wkwargs, center=self.center,
     84                                    pad_mode=self.pad_mode, onesided=self.onesided, norm=self.norm)
     86     self._attrs = {k:v for k,v in self._get_attrs().items()}

TypeError: Spectrogram.__init__() got an unexpected keyword argument 'norm'

tatfms is the alias for torchaudio.transforms, but I checked the parameters of torchaudio.transforms.Spectrogram as far back as 0.7.0 and could not find a version where norm was one of them. Assuming that is true, we may just need to remove norm=self.norm.
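
One way to confirm this kind of mismatch without reading release notes is to inspect the constructor signature directly. The following is a minimal sketch, not code from fastxtend or torchaudio: `unsupported_kwargs` is a hypothetical helper and `StandInSpectrogram` is a hypothetical stand-in class with a made-up parameter list, used so the example runs without torchaudio installed.

```python
import inspect


def unsupported_kwargs(func, kwargs):
    """Return the sorted keyword names in `kwargs` that `func` would reject."""
    params = inspect.signature(func).parameters
    # A **kwargs catch-all accepts any keyword, so nothing is unsupported.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return []
    # Only these two kinds of parameter can be passed by keyword.
    by_keyword = (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                  inspect.Parameter.KEYWORD_ONLY)
    accepted = {name for name, p in params.items() if p.kind in by_keyword}
    return sorted(k for k in kwargs if k not in accepted)


class StandInSpectrogram:
    """Hypothetical stand-in; the parameter list is illustrative only."""
    def __init__(self, n_fft=400, power=2.0, normalized=False):
        self.n_fft, self.power, self.normalized = n_fft, power, normalized


# Keywords a wrapper intends to forward, including the suspect one.
forwarded = {"n_fft": 1024, "power": 2.0, "norm": None}
print(unsupported_kwargs(StandInSpectrogram, forwarded))
```

With torchaudio installed, the same helper could be pointed at the real class instead of the stand-in to see which forwarded keywords it would reject.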

kevinbird15 commented 4 months ago

Here is the updated class, which seems to work (at least it no longer raises the TypeError). I ended up removing norm from the signature entirely, since it was only used in the norm=self.norm arguments passed to tatfms.Spectrogram.

class Spectrogram(DisplayedTransform):
    "Convert a `TensorAudio` into one or more `TensorSpec`"
    order = 75
    def __init__(self,
        n_fft:Listified[int]=1024,
        win_length:Listified[int]|None=None,
        hop_length:Listified[int]|None=None,
        pad:Listified[int]=0,
        window_fn:Listified[Callable[..., Tensor]]=torch.hann_window,
        power:Listified[float]=2.,
        normalized:Listified[bool]=False,
        wkwargs:Listified[dict]|None=None,
        center:Listified[bool]=True,
        pad_mode:Listified[str]="reflect",
        onesided:Listified[bool]=True,
    ):
        super().__init__()
        listify_store_attr()
        attrs = {k:v for k,v in getattr(self,'__stored_args__',{}).items() if k not in ['size', 'mode']}
        # self.resize = size is not None
        if is_listy(self.n_fft):
            self.specs, self._attrs = nn.ModuleList(), []
            self.len, self.multiple = len(self.n_fft), True
            for i in range(self.len):
                self.specs.append(tatfms.Spectrogram(n_fft=self.n_fft[i], win_length=self.win_length[i],
                                                     hop_length=self.hop_length[i], pad=self.pad[i],
                                                     window_fn=self.window_fn[i], power=self.power[i],
                                                     normalized=self.normalized[i], wkwargs=self.wkwargs[i],
                                                     center=self.center[i], pad_mode=self.pad_mode[i],
                                                     onesided=self.onesided[i]))

                self._attrs.append({k:v[i] for k,v in self._get_attrs().items()})
        else:
            self.multiple = False
            self.spec = tatfms.Spectrogram(n_fft=self.n_fft, win_length=self.win_length, hop_length=self.hop_length,
                                           pad=self.pad, window_fn=self.window_fn, power=self.power,
                                           normalized=self.normalized, wkwargs=self.wkwargs, center=self.center,
                                           pad_mode=self.pad_mode, onesided=self.onesided)

            self._attrs = {k:v for k,v in self._get_attrs().items()}

    def encodes(self, x:TensorAudio):
        if self.multiple:
            specs = []
            for i in range(self.len):
                specs.append(TensorSpec.create(self.specs[i](x), settings=self._attrs[i]))
            return tuple(specs)
        else:
            return TensorSpec.create(self.spec(x), settings=self._attrs)

    def to(self, *args, **kwargs):
        device, *_ = torch._C._nn._parse_to(*args, **kwargs)
        if self.multiple:
            self.specs.to(device)
        else:
            self.spec.to(device)

    def _get_attrs(self):
        return {k:v for k,v in getattr(self,'__dict__',{}).items() if k in getattr(self,'__stored_args__',{}).keys()}

warner-benjamin commented 4 months ago

Looks like it was something from MelSpectrogram that got added to Spectrogram. Could you check that #26 resolves this issue?
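
A parameter that leaks from one wrapper into another can be spotted by diffing the two constructor signatures. The sketch below is illustrative only: `extra_parameters` is a hypothetical helper, and both classes are hypothetical stand-ins with shortened, made-up parameter lists rather than the real torchaudio transforms.

```python
import inspect


def extra_parameters(source, target):
    """Return the sorted parameter names that `source` has and `target` lacks."""
    src = set(inspect.signature(source).parameters)
    tgt = set(inspect.signature(target).parameters)
    return sorted(src - tgt)


class StandInMelSpectrogram:
    """Hypothetical stand-in for a mel spectrogram transform."""
    def __init__(self, n_fft=400, n_mels=128, norm=None):
        self.n_fft, self.n_mels, self.norm = n_fft, n_mels, norm


class StandInSpectrogram:
    """Hypothetical stand-in for a plain spectrogram transform."""
    def __init__(self, n_fft=400, power=2.0):
        self.n_fft, self.power = n_fft, power


# Names listed here must not be forwarded to the plain transform.
print(extra_parameters(StandInMelSpectrogram, StandInSpectrogram))
```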

The audio section is the least tested part of fastxtend (both in projects and library tests). At the moment I don't have time or interest in improving that.

Rest of the library has good test coverage, though.

kevinbird15 commented 4 months ago

I just tested, and #26 does fix the issue above. I am going to try using fastxtend for a Kaggle competition, so I may run into more issues. I will do my best to resolve them myself before bringing them up :)