jenellefeather / tfcochleagram

TensorFlow integration with mcdermottlab/pycochleagram
MIT License

Model works only with 2-second-long audio? #2

Open T4phage76 opened 1 week ago

T4phage76 commented 1 week ago

Description

Hey, I'm using this model to process some of my own audio samples. I ran into an issue (see below) when feeding the model audio samples longer than 2 seconds (Fs = 20 kHz). If I manually take the first 40000 (2 sec * 20000 Hz) samples of my audio with sample_audio[:, :40000] and run the model on that 40000-sample chunk, it works. Does this model only accept 2-second clips? Is there any way to use a clip of arbitrary length? Thanks! The error report and my code are listed below.
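A minimal sketch of cropping or zero-padding a waveform to the 2 s input length (40000 samples at 20 kHz), assuming a NumPy array with time on the last axis. `fit_to_length` is a hypothetical helper for illustration, not part of cochdnn. It also shows why the slice needs a colon before the stop index.

```python
import numpy as np

SR = 20000                       # sample rate used in this thread (Hz)
MODEL_DUR_SECS = 2               # duration the cochdnn models were trained on
N_SAMPLES = SR * MODEL_DUR_SECS  # 40000 samples

def fit_to_length(sound, n_samples=N_SAMPLES):
    """Crop or zero-pad the last axis of `sound` to exactly n_samples.

    Hypothetical helper for illustration; not part of cochdnn.
    """
    n = sound.shape[-1]
    if n >= n_samples:
        # Keep the first n_samples: [..., :n_samples], not [..., n_samples]
        return sound[..., :n_samples]
    # Zero-pad only the time (last) axis, at the end
    pad = [(0, 0)] * (sound.ndim - 1) + [(0, n_samples - n)]
    return np.pad(sound, pad)

long_clip = np.random.randn(1, 10 * SR)  # 10 s of noise, shape (1, 200000)
print(fit_to_length(long_clip).shape)    # (1, 40000)
print(long_clip[:, 40000].shape)         # (1,): a single sample, not a 2 s chunk
```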

The error

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-44-4c3400905f22> in <cell line: 1>()
      1 with torch.no_grad():
----> 2         (predictions, rep, layer_returns), orig_image = model(sound_example, with_latent=True)

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1530             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531         else:
-> 1532             return self._call_impl(*args, **kwargs)
   1533 
   1534     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1539                 or _global_backward_pre_hooks or _global_backward_hooks
   1540                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541             return forward_call(*args, **kwargs)
   1542 
   1543         try:

/content/cochdnn/robustness/attacker.py in forward(self, inp, target, make_adv, with_latent, fake_relu, no_relu, with_image, **attacker_kwargs)
    339             if no_relu and fake_relu:
    340                 raise ValueError("Options 'no_relu' and 'fake_relu' are exclusive")
--> 341             output = self.model(preproc_inp, with_latent=with_latent,
    342                                     fake_relu=fake_relu, no_relu=no_relu)
    343         else:

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1530             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531         else:
-> 1532             return self._call_impl(*args, **kwargs)
   1533 
   1534     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1539                 or _global_backward_pre_hooks or _global_backward_hooks
   1540                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541             return forward_call(*args, **kwargs)
   1542 
   1543         try:

/content/cochdnn/robustness/audio_models/custom_modules.py in forward(self, input, **kwargs)
    134     def forward(self, input, **kwargs):
    135         for module in self:
--> 136             input = module(input, **kwargs)
    137         return input
    138 

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1530             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531         else:
-> 1532             return self._call_impl(*args, **kwargs)
   1533 
   1534     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1539                 or _global_backward_pre_hooks or _global_backward_hooks
   1540                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541             return forward_call(*args, **kwargs)
   1542 
   1543         try:

/content/cochdnn/robustness/audio_models/custom_modules.py in forward(self, x, with_latent, fake_relu, no_relu)
     53     def forward(self, x, with_latent=False, fake_relu=False, no_relu=False):
     54         # print(self.full_rep)
---> 55         x, _ = self.full_rep(x, None)
     56         return x
     57 

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1530             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531         else:
-> 1532             return self._call_impl(*args, **kwargs)
   1533 
   1534     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1539                 or _global_backward_pre_hooks or _global_backward_hooks
   1540                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541             return forward_call(*args, **kwargs)
   1542 
   1543         try:

/content/cochdnn/robustness/audio_functions/audio_transforms.py in forward(***failed resolving arguments***)
    173         if foreground_wav is not None:
    174             foreground_wav = foreground_wav
--> 175             foreground_rep, background_rep = self.rep(foreground_wav, None)
    176             if self.compression is not None:
    177                 foreground_rep, background_rep = self.compression(foreground_rep, None)

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1530             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531         else:
-> 1532             return self._call_impl(*args, **kwargs)
   1533 
   1534     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1539                 or _global_backward_pre_hooks or _global_backward_hooks
   1540                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541             return forward_call(*args, **kwargs)
   1542 
   1543         try:

/content/cochdnn/robustness/audio_functions/audio_transforms.py in forward(***failed resolving arguments***)
    269 
    270         if foreground_wav is not None:
--> 271             foreground_coch = self.Cochleagram(foreground_wav)
    272         else:
    273             foreground_coch = None

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1530             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531         else:
-> 1532             return self._call_impl(*args, **kwargs)
   1533 
   1534     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1539                 or _global_backward_pre_hooks or _global_backward_hooks
   1540                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541             return forward_call(*args, **kwargs)
   1542 
   1543         try:

/usr/local/lib/python3.10/dist-packages/chcochleagram/cochleagram.py in forward(self, x, return_latent)
     66 
     67     def forward(self, x, return_latent=False):
---> 68         subbands = self.compute_subbands(x)
     69         envelopes = self.envelope_extraction(subbands)
     70 

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1530             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531         else:
-> 1532             return self._call_impl(*args, **kwargs)
   1533 
   1534     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1539                 or _global_backward_pre_hooks or _global_backward_hooks
   1540                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541             return forward_call(*args, **kwargs)
   1542 
   1543         try:

/usr/local/lib/python3.10/dist-packages/chcochleagram/cochleagram.py in forward(self, x)
    115 
    116         if self.apply_in_fourier:
--> 117             x = self._apply_filt_in_fourier(x)
    118         else:
    119             x = self._apply_filt_in_time(x)

/usr/local/lib/python3.10/dist-packages/chcochleagram/cochleagram.py in _apply_filt_in_fourier(self, x)
    134                x_fft = ch.fft(x, dim=-1).unsqueeze_(-2) # Add channel dim
    135             x_fft = ch.view_as_real(x_fft)
--> 136         filtered_signal = self._complex_multiplication(x_fft, self.coch_filters)
    137 
    138         return filtered_signal

/usr/local/lib/python3.10/dist-packages/chcochleagram/cochleagram.py in _complex_multiplication(self, t1, t2)
    146       real1, imag1 = [a.squeeze(-1) for a in ch.split(t1, 1, dim=-1)]
    147       real2, imag2 = [a.squeeze(-1) for a in ch.split(t2, 1, dim=-1)]
--> 148       return ch.stack([real1 * real2 - imag1 * imag2, real1 * imag2 + imag1 * real2], dim = -1)

RuntimeError: The size of tensor a (100001) must match the size of tensor b (20001) at non-singleton dimension 3
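The two sizes in this error are consistent with a real FFT taken over the whole signal, which returns n // 2 + 1 frequency bins for n samples. A 10 s clip at 20 kHz has 200000 samples (100001 bins), while a 2 s clip has 40000 samples (20001 bins), which appears to be the length the cochleagram filters were built for. This is an inference from the numbers, not from reading the chcochleagram source. A quick check with NumPy:

```python
import numpy as np

SR = 20000  # Hz, as in the issue

def n_rfft_bins(n_samples):
    # A real FFT of n samples returns n // 2 + 1 frequency bins
    return n_samples // 2 + 1

for secs in (2, 10):
    n = secs * SR
    bins = np.fft.rfft(np.zeros(n)).shape[-1]
    assert bins == n_rfft_bins(n)
    print(f"{secs} s -> {n} samples -> {bins} rfft bins")
# 2 s -> 40000 samples -> 20001 rfft bins
# 10 s -> 200000 samples -> 100001 rfft bins
```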

My implementation

import importlib.util
import os

import numpy as np
import torch

# Load the model
model_dir = '/content/cochdnn/model_directories/kell2018_word_speaker_audioset'

build_network_spec = importlib.util.spec_from_file_location("build_network",
                        os.path.join(model_dir, 'build_network.py'))
build_network = importlib.util.module_from_spec(build_network_spec)
build_network_spec.loader.exec_module(build_network)

model, ds, all_layers = build_network.main(return_metamer_layers=True)

# Load the audio (load_audio_wav_resample is defined elsewhere, not shown in this snippet)
test_audio, SR = load_audio_wav_resample('/content/stim.wav', DUR_SECS=10, resample_SR=20000)

def preproc_sound_np(sound):
    # Remove the DC offset and scale to an RMS of 0.1
    sound = sound - np.mean(sound)
    sound = sound / np.sqrt(np.mean(sound**2)) * 0.1
    # Add a batch dimension and move to the GPU
    sound = np.expand_dims(sound, 0)
    sound = torch.from_numpy(sound).float().cuda()
    return sound

sound_example = preproc_sound_np(test_audio)

# Feed the audio to the model (a 10 s clip here, i.e. 200000 samples at 20 kHz)
with torch.no_grad():
    (predictions, rep, layer_returns), orig_image = model(sound_example, with_latent=True)

THANK YOU SO MUCH!

jenellefeather commented 1 week ago

Thanks for the issue! I think this was possibly meant for the cochdnn repo (https://github.com/jenellefeather/cochdnn) rather than this TensorFlow cochleagram repo?

That said, the models in that repository were all trained on 2-second sounds, and both the architecture and the cochleagram are built for that length. If you want activations for longer sounds, there are things you can do (remove the final fully connected layer, and modify the cochleagram so that it can take 10-second inputs), but the predictions you get out won't be meaningful.
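To illustrate why the cochleagram itself has to be rebuilt for longer inputs: when filtering is done in the Fourier domain, the filter bank is sampled on the FFT frequency grid of one particular signal length. The sketch below uses simple rectangular band-pass masks as a stand-in (the real chcochleagram filters are ERB-spaced and shaped differently), and `make_filterbank` / `filter_in_fourier` are hypothetical helpers, not chcochleagram functions. It only shows that a bank built for 40000 samples cannot multiply the spectrum of a 200000-sample signal, while one rebuilt for the new length can. Removing the final fully connected layer is not covered here.

```python
import numpy as np

SR = 20000  # Hz

def make_filterbank(n_samples, sr, edges_hz):
    """Hypothetical stand-in for a Fourier-domain filter bank.

    Returns an array of shape (n_filters, n_samples // 2 + 1) holding
    rectangular band-pass masks on the rfft frequency grid.
    """
    freqs = np.fft.rfftfreq(n_samples, d=1.0 / sr)
    lows, highs = edges_hz[:-1], edges_hz[1:]
    return ((freqs >= lows[:, None]) & (freqs < highs[:, None])).astype(float)

def filter_in_fourier(x, filters):
    """Band-pass x (shape (..., n)) with every filter; returns (..., n_filters, n)."""
    x_fft = np.fft.rfft(x, axis=-1)[..., None, :]  # add a filter axis
    if x_fft.shape[-1] != filters.shape[-1]:
        raise ValueError(f"spectrum has {x_fft.shape[-1]} bins, filters have {filters.shape[-1]}")
    return np.fft.irfft(x_fft * filters, n=x.shape[-1], axis=-1)

edges = np.array([50.0, 200.0, 800.0, 3200.0, 8000.0])  # 4 toy bands
x_10s = np.random.randn(1, 10 * SR)

filters_2s = make_filterbank(2 * SR, SR, edges)  # shape (4, 20001)
try:
    filter_in_fourier(x_10s, filters_2s)
except ValueError as err:
    print(err)  # spectrum has 100001 bins, filters have 20001

filters_10s = make_filterbank(10 * SR, SR, edges)   # rebuilt for the new length
print(filter_in_fourier(x_10s, filters_10s).shape)  # (1, 4, 200000)
```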

T4phage76 commented 1 week ago

Hi Jenelle,

Thanks for your reply. Yes, this issue should go to the cochdnn repo (https://github.com/jenellefeather/cochdnn). I apologize for opening it in the wrong place; I probably had too many GitHub pages open at the same time. Should I open a new issue there and link to this one?

Regarding the issue itself, I can certainly keep my stimuli under 2 seconds. I will probably try modifying the model to get the activations (activations are indeed what I want), but I'm not sure it makes practical sense to use the activations of 10-second audio from this altered model, since it was trained, and shown to be behaviorally and neurally predictive, only with 2-second audio.
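One way to keep every input at the trained length is to cut a long stimulus into consecutive 2 s segments and run each one through the model separately. A minimal NumPy sketch, assuming a 1-D waveform at 20 kHz; `split_into_chunks` is a hypothetical helper, and whether to drop, pad, or overlap the final partial segment is left to the user.

```python
import numpy as np

SR = 20000
CHUNK = 2 * SR  # 40000 samples, the length the models were trained on

def split_into_chunks(wav, chunk=CHUNK, pad_last=True):
    """Split a 1-D waveform into non-overlapping chunks of `chunk` samples.

    Returns an array of shape (n_chunks, chunk). The final partial chunk is
    zero-padded if pad_last is True, otherwise dropped.
    """
    n_full, remainder = divmod(len(wav), chunk)
    if remainder and pad_last:
        wav = np.pad(wav, (0, chunk - remainder))
        n_full += 1
    return wav[: n_full * chunk].reshape(n_full, chunk)

wav = np.random.randn(int(10.5 * SR))                # 10.5 s of noise
print(split_into_chunks(wav).shape)                  # (6, 40000)
print(split_into_chunks(wav, pad_last=False).shape)  # (5, 40000)
```

Each row could then go through preproc_sound_np before being passed to the model. Note that zero-padding the last chunk changes what the RMS normalization sees, so dropping the partial chunk may be the simpler choice.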

Another approach that might work for longer audio is to use /cochdnn/robustness/audio_models/kell2018.py, but load it with the weights from /kelletal2018/network/weights.

What do you think? Thank you so much!

jenellefeather commented 1 week ago

All of the models in the cochdnn repo are trained on 2-second sounds, so you will have the same 2-second issue with all of them. The activations of the convolutional layers are probably still fine for longer sounds, apart from differences in how the convolutions handle the boundaries, which could change things slightly. You could run some tests and see. Good luck!
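A toy version of such a test, assuming a single 1-D convolution with zero ("same") padding as a stand-in for one convolutional layer (the real models apply 2-D convolutions to the cochleagram, plus pooling and nonlinearities, so their boundary regions are wider). Convolve a 2 s signal on its own and as the first 2 s of a longer signal, then compare. The outputs agree everywhere except within half a kernel width of the 2 s cut, where the short input sees zero padding and the long input sees real samples.

```python
import numpy as np

rng = np.random.default_rng(0)
SR = 20000
kernel = rng.standard_normal(9)          # toy 9-tap filter, half-width 4
long_sig = rng.standard_normal(10 * SR)  # 10 s signal
short_sig = long_sig[: 2 * SR]           # its first 2 s

# 'same' mode zero-pads so the output length equals the input length
out_short = np.convolve(short_sig, kernel, mode="same")
out_long = np.convolve(long_sig, kernel, mode="same")[: 2 * SR]

diff = np.abs(out_short - out_long)
half = len(kernel) // 2
print(np.allclose(diff[:-half], 0))  # True: identical away from the cut
print(diff[-half:].max() > 0)        # True: differs near the 2 s boundary
```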

T4phage76 commented 1 week ago

Gotcha! Thanks!