SamsungLabs / SummaryMixing

This repository implements SummaryMixing, a simpler, faster and much cheaper replacement for self-attention in automatic speech recognition (see: https://arxiv.org/abs/2307.07421). The code is ready to be used with the SpeechBrain toolkit.

Valid step generates a RuntimeError #8

Open · Craya opened 2 months ago

Craya commented 2 months ago

Dear Team,

I want to compare the ASR results we obtained with the wav2vec2 and Whisper architectures against your SummaryMixing one.

We are performing a custom ASR training; our dataset consists of 95,000 records for training, 16,000 for validation and 17,000 for testing.

The training step ran successfully with the following parameters (A100 40 GB GPU):

However, at the epoch 1 validation step, we got the following error:

speechbrain.utils.epoch_loop - Going into epoch 1
  0%|          | 0/8660 [00:00<?, ?it/s]/usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:5109: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.
  warnings.warn(
 44%|████▍     | 3803/8660 [15:03<18:27,  4.39it/s, train_loss=93.3]speechbrain.utils.checkpoints - Saving an empty state_dict for <torch.cuda.amp.grad_scaler.GradScaler object at 0x7fe0401657c0> in /data/outputs/save/CKPT+2024-05-31+14-19-54+00/scaler.ckpt.
 88%|████████▊ | 7586/8660 [30:12<04:18,  4.15it/s, train_loss=80.9]speechbrain.utils.checkpoints - Saving an empty state_dict for <torch.cuda.amp.grad_scaler.GradScaler object at 0x7fe0401657c0> in /data/outputs/save/CKPT+2024-05-31+14-35-03+00/scaler.ckpt.
100%|██████████| 8660/8660 [34:15<00:00,  4.21it/s, train_loss=78.7]
  0%|          | 0/1291 [00:00<?, ?it/s]
speechbrain.core - Exception:
Traceback (most recent call last):
  File "train.py", line 442, in <module>
    asr_brain.fit(
  File "/usr/local/lib/python3.8/dist-packages/speechbrain/core.py", line 1556, in fit
    self._fit_valid(valid_set=valid_set, epoch=epoch, enable=enable)
  File "/usr/local/lib/python3.8/dist-packages/speechbrain/core.py", line 1462, in _fit_valid
    loss = self.evaluate_batch(batch, stage=Stage.VALID)
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/speechbrain/core.py", line 1345, in evaluate_batch
    out = self.compute_forward(batch, stage=stage)
  File "train.py", line 68, in compute_forward
    enc_out, pred = self.modules.Transformer(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/data_parallel.py", line 183, in forward
    return self.module(*inputs[0], **module_kwargs[0])
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/speechbrain/lobes/models/transformer/TransformerASR.py", line 381, in forward
    encoder_out, _ = self.encoder(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/speechbrain/lobes/models/transformer/Branchformer.py", line 480, in forward
    output, attention = enc_layer(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/speechbrain/lobes/models/transformer/Branchformer.py", line 277, in forward
    x2 = self._forward_cnn_branch(x2)
  File "/usr/local/lib/python3.8/dist-packages/speechbrain/lobes/models/transformer/Branchformer.py", line 293, in _forward_cnn_branch
    x = self.convolution_branch(x)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/speechbrain/lobes/models/transformer/Branchformer.py", line 94, in forward
    x = self.csgu(x)  # (B, T, D//2)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/speechbrain/lobes/models/convolution.py", line 99, in forward
    x2 = self.conv(x2)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/speechbrain/nnet/CNN.py", line 428, in forward
    x = self._manage_padding(
  File "/usr/local/lib/python3.8/dist-packages/speechbrain/nnet/CNN.py", line 480, in _manage_padding
    x = F.pad(x, padding, mode=self.padding_mode)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/functional.py", line 4495, in pad
    return torch._C._nn.pad(input, pad, mode, value)
RuntimeError: Argument #4: Padding size should be less than the corresponding input dimension, but got: padding (15, 15) at dimension 2 of input [12, 1536, 13]

What's wrong?

Thanks for your support.

Craya commented 2 months ago

We've made other attempts, but the results are the same.

Any idea, @TParcollet?

Fabien.

shucongzhang commented 2 months ago

@Craya Hello, thank you for your interest in this paper. The problem is that, for very short input sentences, the CNN branch of the Branchformer needs a padding size larger than the sequence length. This is therefore an issue with the Branchformer architecture itself. My suggestion is to filter out very short sequences from your dataset if you want to use Branchformer.
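The last line of the traceback is the check PyTorch applies to reflect-mode padding: the sequence is mirrored around its edge sample without repeating it, so each side can borrow at most (length minus 1) values. The traceback shows SpeechBrain's `CNN.py` calling `F.pad(x, padding, mode=self.padding_mode)`, and the message is consistent with that mode being `"reflect"`. The pure-Python sketch below reproduces the rule with a hypothetical `reflect_pad_1d` helper (not SpeechBrain or PyTorch code). It assumes the CNN branch uses an odd kernel of size 31 with "same" padding and dilation 1; that kernel size is inferred from `padding (15, 15)`, not read from the configuration.

```python
from typing import List, Sequence


def reflect_pad_1d(x: Sequence[float], left: int, right: int) -> List[float]:
    """Pad a 1-D sequence by mirroring it around its first and last samples.

    The edge samples are not repeated, so at most len(x) - 1 values can be
    borrowed on each side.
    """
    n = len(x)
    if left >= n or right >= n:
        # Same condition that PyTorch reports as "Padding size should be less
        # than the corresponding input dimension".
        raise ValueError(
            f"padding ({left}, {right}) must be less than input length {n}"
        )
    prefix = [x[i] for i in range(left, 0, -1)]    # x[left], ..., x[1]
    suffix = [x[n - 2 - i] for i in range(right)]  # x[n-2], ..., x[n-1-right]
    return prefix + list(x) + suffix


def same_padding(kernel_size: int, dilation: int = 1) -> int:
    """Per-side padding that keeps the length unchanged for an odd kernel."""
    return dilation * (kernel_size - 1) // 2


if __name__ == "__main__":
    print(reflect_pad_1d([1, 2, 3, 4], 2, 2))  # [3, 2, 1, 2, 3, 4, 3, 2]

    pad = same_padding(kernel_size=31)  # 15, matching "padding (15, 15)"
    try:
        # 13 time steps, as in the failing input [12, 1536, 13]
        reflect_pad_1d(list(range(13)), pad, pad)
    except ValueError as err:
        print(err)  # padding (15, 15) must be less than input length 13
```

With only 13 frames, a 15-frame reflection on each side is impossible, which matches the failing batch; under the same assumptions, any utterance that reaches the CNN branch with at least 16 frames pads without error.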

Conformers do not have this issue. In fact, we have a Conformer SummaryMixing W2V2 model and will release the paper and the code soon. If your project is not very urgent, you could check our Conformer SummaryMixing code once it is released. Alternatively, if you would prefer to implement it yourself so you can run experiments immediately, please feel free to ask me questions.

I hope this is helpful.

Shucong

Craya commented 2 months ago

Thanks @shucongzhang for your clear answer.

Our dataset consists of audio clips between 0.5 s and 10 s. When you suggest filtering out "very short sentences", do you have a value in mind for this minimum duration?
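A rough minimum can be estimated from the numbers in the traceback. Reflect padding of 15 frames per side needs at least 16 frames at the input of the CNN branch. Converting that back to seconds requires two assumptions that are not stated in this thread: a feature hop of 10 ms and a total time reduction of 4 in the convolutional front-end before the Branchformer. Both are labeled as assumptions in the code and should be replaced with the values from your own hyperparameter file.

```python
def min_encoder_frames(kernel_size: int = 31, dilation: int = 1) -> int:
    """Smallest time length that reflect 'same' padding accepts.

    For an odd kernel, 'same' padding adds dilation * (kernel_size - 1) // 2
    frames per side, and reflect padding needs pad < length.
    """
    pad = dilation * (kernel_size - 1) // 2
    return pad + 1


def min_duration_seconds(
    kernel_size: int = 31,
    dilation: int = 1,
    downsampling: int = 4,  # ASSUMPTION: total time reduction of the CNN front-end
    hop_ms: float = 10.0,  # ASSUMPTION: feature hop length in milliseconds
) -> float:
    """Rough shortest utterance (in seconds) the CNN branch can pad.

    Ignores edge effects of the strided front-end convolutions and of the
    feature extractor, so it is an estimate, not an exact cutoff.
    """
    input_frames = min_encoder_frames(kernel_size, dilation) * downsampling
    return input_frames * hop_ms / 1000.0


if __name__ == "__main__":
    print(min_encoder_frames())              # 16
    print(round(min_duration_seconds(), 2))  # 0.64
```

Under the same assumptions, the 13 encoder frames of the failing batch correspond to about 13 x 4 x 10 ms = 0.52 s, which is consistent with the 0.5 s lower end of the dataset. Because the estimate ignores edge effects, a threshold with some margin, such as the 1 s or 2 s suggested in the reply, is safer than the bare estimate.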

We will definitely check your Conformer SummaryMixing W2V2 release. Will it also be in this Git repository?

Thanks.

shucongzhang commented 2 months ago

@Craya No worries. I would suggest trying a minimum length of 1 s or 2 s. But again, I would recommend using Conformers if short utterances make up a large portion of your dataset.
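Filtering can be done once, directly on the data manifests. The sketch below assumes SpeechBrain-style CSV manifests with a `duration` column in seconds; the column names in the example (`ID`, `duration`, `wav`, `wrd`) and the file names are hypothetical, so adapt them to your own data preparation. `filter_manifest` is an illustrative helper, not part of SpeechBrain.

```python
import csv
import io
from typing import Dict, Iterable, List, TextIO


def filter_short(
    rows: Iterable[Dict[str, str]], min_seconds: float, key: str = "duration"
) -> List[Dict[str, str]]:
    """Keep only rows whose duration (in seconds) is >= min_seconds."""
    return [row for row in rows if float(row[key]) >= min_seconds]


def filter_manifest(
    src: TextIO, dst: TextIO, min_seconds: float = 1.0, key: str = "duration"
) -> int:
    """Copy a CSV manifest from src to dst, dropping too-short utterances.

    Returns the number of rows kept. Assumes a header row that contains
    the duration column named by `key` (hypothetical default: "duration").
    """
    reader = csv.DictReader(src)
    kept = filter_short(reader, min_seconds, key)
    writer = csv.DictWriter(dst, fieldnames=reader.fieldnames)
    writer.writeheader()
    writer.writerows(kept)
    return len(kept)


if __name__ == "__main__":
    # Tiny in-memory manifest; in practice, open the real train/valid/test CSVs.
    src = io.StringIO(
        "ID,duration,wav,wrd\n"
        "utt1,0.5,utt1.wav,yes\n"
        "utt2,3.2,utt2.wav,hello world\n"
    )
    dst = io.StringIO()
    print(filter_manifest(src, dst, min_seconds=1.0))  # 1
```

Since the crash happened in the validation loop, the threshold has to be applied to the validation (and test) data as well, not only to training; duration filters in SpeechBrain recipes are often applied to the training set only. Inside a recipe, the same effect can usually be obtained without rewriting the CSVs, for example with `DynamicItemDataset.filtered_sorted(key_min_value={"duration": 1.0})`. Either way, validation and test scores will no longer cover the full sets, which matters when comparing against the wav2vec2 and Whisper results.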

The W2V2 code should also be in this repo. Please let me know if there is anything else I can help with.