k2-fsa / icefall

https://k2-fsa.github.io/icefall/
Apache License 2.0

k2SSL: a Faster and Better Framework for Self-Supervised Speech Representation Learning #1500

Closed yfyeung closed 1 month ago

yfyeung commented 3 months ago

In this PR, we decoupled HuBERT from fairseq, making it independent of the fairseq library while maintaining full equivalence with the original pre-training logic (model architecture, data normalization, masking strategy, loss computation...). We compared the outputs of several layers to verify this equivalence. We also support the checkpoints from fairseq (hubert_base_ls960, hubert_large_ll60k, hubert_xtralarge_ll60k).

Then, we optimized the pre-training loss, significantly reducing peak memory usage and even slightly improving performance. Unfortunately, this change made half-precision training of the original HuBERT unstable.

We therefore adopted ScaledAdam as the optimizer and Eden as the scheduler, and replaced the Transformer encoder with the Zipformer encoder. This further reduced peak memory usage and improved performance while remaining stable in half precision.

kobenaxie commented 3 months ago

Hi @yfyeung ,

yfyeung commented 3 months ago
  • How do we get the k-means file to train the Zipformer-based HuBERT pre-training model?

For LibriSpeech, we directly use the k-means labels from hubert_base_ls960.

  • Can we use fbank as the model input, as in w2v-BERT?

Yes, you can replace the ConvFeatureExtractionModel with the Conv2dSubsampling.
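When swapping the front end, one thing that may be worth checking is the frame rate, since the k-means labels were produced at the rate of the original convolutional extractor. The sketch below computes how many frames a stack of unpadded 1-D convolutions emits. The (kernel, stride) pairs are an assumption: they are the fairseq HuBERT base defaults as commonly documented, and the configuration in this PR may differ.

```python
# Sketch only. CONV_LAYERS is an assumed (kernel, stride) spec matching the
# commonly documented fairseq HuBERT base defaults; it is not read from this PR.
CONV_LAYERS = [(10, 5)] + [(3, 2)] * 4 + [(2, 2)] * 2


def num_output_frames(num_samples, layers=CONV_LAYERS):
    """Number of frames produced by a stack of unpadded 1-D convolutions."""
    length = num_samples
    for kernel, stride in layers:
        # Standard output-length formula for a convolution without padding.
        length = (length - kernel) // stride + 1
    return length


def total_stride(layers=CONV_LAYERS):
    """Product of all strides, i.e. input samples consumed per output frame."""
    stride = 1
    for _, s in layers:
        stride *= s
    return stride


if __name__ == "__main__":
    print(total_stride())            # samples per frame
    print(num_output_frames(16000))  # frames for one second of 16 kHz audio
```

Under these assumptions one frame covers 320 samples (20 ms at 16 kHz), so an fbank front end would presumably need to arrive at the same rate for the existing labels to line up.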

kafan1986 commented 3 months ago

@yfyeung What is the approximate increase in WER, training time, and inference time if k2SSL is used with, say, HuBERT base?

danpovey commented 1 month ago

Guys, I just noticed this, it seems like a great contribution. I'd rather not have these things wait so long... let me merge it now and if we have any changes we want, we can do them later on.

teowenshen commented 1 month ago

Hi there @yfyeung , first of all thank you for creating this SSL recipe!

I tried running your zipformer/ code, but my model diverged at epoch 33 and pretraining ended with a "Grad scale is small" error.

Throughout pretraining before the divergence, I noticed my grad scale tended to fluctuate between 0.125 and 2.

Did you face the same issues?
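For readers unfamiliar with the grad scale mentioned above, here is a minimal sketch of dynamic loss scaling as used in fp16 training in general. The class name, the short growth interval, and the `min_scale` threshold are all hypothetical choices for illustration; they are not icefall's or PyTorch's actual settings.

```python
# Toy model of dynamic loss scaling. All constants are illustrative
# assumptions, not the values used by icefall or torch.cuda.amp.GradScaler.
class ToyGradScaler:
    def __init__(self, init_scale=1.0, growth_factor=2.0, backoff_factor=0.5,
                 growth_interval=3, min_scale=1e-5):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.min_scale = min_scale
        self.good_steps = 0  # consecutive steps without inf/nan gradients

    def update(self, found_inf):
        """Adjust the scale after one optimizer step and return it."""
        if found_inf:
            # Overflow: the step is skipped and the scale is reduced.
            self.scale *= self.backoff_factor
            self.good_steps = 0
        else:
            self.good_steps += 1
            if self.good_steps == self.growth_interval:
                # A run of clean steps: try a larger scale again.
                self.scale *= self.growth_factor
                self.good_steps = 0
        if self.scale < self.min_scale:
            # Hypothetical stand-in for a "grad scale is too small" abort.
            raise RuntimeError(f"grad scale is too small: {self.scale}")
        return self.scale


if __name__ == "__main__":
    scaler = ToyGradScaler()
    for overflow in [False, False, False, True, True, False]:
        print(scaler.update(found_inf=overflow))
```

With growth and backoff factors of 2 and 0.5, the scale only ever takes power-of-two multiples of its initial value, which is consistent with values such as 0.125 and 2. A scale that keeps shrinking means overflows are happening more often than clean steps.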

EDIT: I was also wondering if you tried toggling the loss reduction to mean instead of sum. Maybe that will stabilise training?

My commands. I adapted the batch size to my setup, maintaining the same accum_grad * max_duration * world_size.

# pretraining
python zipformer/pretrain.py \
    --world-size 4 \
    --use-fp16 1 \
    --num-epochs 50 \
    --manifest-dir data/raw \
    --max-duration 350 \
    --accum-grad 2 \
    --exp-dir zipformer/exp2/pretrain
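The product mentioned above can be worked out with a few lines of Python. This is only a sketch: the helper name is hypothetical, and it assumes `--max-duration` is the number of seconds of audio per batch on each GPU.

```python
# Hypothetical helper. Assumes --max-duration is seconds of audio per batch
# per GPU, so the product is the audio seen per optimizer step.
def audio_seconds_per_step(world_size, max_duration, accum_grad):
    return world_size * max_duration * accum_grad


if __name__ == "__main__":
    # Flags from the command above.
    print(audio_seconds_per_step(world_size=4, max_duration=350, accum_grad=2))
    # Flags from the author's pre-training command posted later in this thread.
    print(audio_seconds_per_step(world_size=8, max_duration=600, accum_grad=1))
```

With the flag values as they appear in this thread, the two products come out different, so it may be worth double-checking which reference configuration was being matched.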

As per your explanation, I used the same 500 k-means labels from simple_kmeans.

yfyeung commented 1 month ago

Hi, hope this message finds you well.

My training command is as follows:

./zipformer/pretrain.py \
  --world-size 8 \
  --num-epochs 291 \
  --start-epoch 1 \
  --use-fp16 1 \
  --exp-dir zipformer/exp_pretrain \
  --full-libri 1 \
  --max-duration 600 \
  --accum-grad 1 \
  --do-normalize 0 \
  --mask-prob 0.8 \
  --dropout-input 0.1 \
  --dropout-features 0.1 \
  --feature-grad-mult 0.1 \
  --untie-final-proj 1 \
  --num-encoder-layers 2,2,3,4,3,2 \
  --feedforward-dim 512,768,1024,1536,1024,768 \
  --encoder-dim 192,256,448,768,448,192 \
  --encoder-unmasked-dim 192,192,256,256,256,192 \
  --base-lr 0.045

EDIT: I was also wondering if you tried toggling the loss reduction to mean instead of sum. Maybe that will stabilise training?

Regarding toggling the loss reduction from sum to mean to stabilize training: mean reduction is typically used for multi-GPU simulations to ensure uniform scaling, while sum reduction is preferred for larger batch sizes because it helps stabilize the gradient estimate. Switching to mean is not a good way to optimize for both large batch sizes and multi-GPU setups at the same time.
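To make the difference between the two reductions concrete, here is a toy example with a single scalar weight and a squared-error loss. It is a generic illustration under made-up inputs, not the loss used in this recipe: with sum reduction the gradient grows with the number of frames in the batch, while with mean reduction it does not.

```python
# Toy illustration with a scalar weight w and per-frame loss (w*x - y)**2.
# This is not the recipe's loss; the inputs below are made up.
def grad(w, xs, ys, reduction="sum"):
    """Gradient of the reduced loss with respect to w."""
    g = sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys))
    if reduction == "mean":
        g /= len(xs)
    elif reduction != "sum":
        raise ValueError(f"unknown reduction: {reduction}")
    return g


if __name__ == "__main__":
    xs, ys = [1.0, 2.0, 3.0], [2.0, 4.0, 6.0]
    # Same data repeated twice stands in for a batch that is twice as large.
    print(grad(1.0, xs, ys, "sum"), grad(1.0, xs * 2, ys * 2, "sum"))
    print(grad(1.0, xs, ys, "mean"), grad(1.0, xs * 2, ys * 2, "mean"))
```

Because the sum-reduced gradient scales with batch content, changing `--max-duration` also changes the gradient magnitude seen by the fp16 loss scaler.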

Fine-tuning command is:

./zipformer/finetune.py \
  --world-size 8 \
  --num-epochs 222 \
  --start-epoch 1 \
  --use-fp16 1 \
  --exp-dir zipformer/exp_finetune \
  --pretrained-dir zipformer/exp_pretrain/epoch-291.pt \
  --full-libri 0 \
  --max-duration 600 \
  --accum-grad 1 \
  --do-normalize 0 \
  --mask-prob 0.65 \
  --mask-channel-prob 0.5 \
  --mask-channel-length 64 \
  --feature-grad-mult 0.0 \
  --num-encoder-layers 2,2,3,4,3,2 \
  --feedforward-dim 512,768,1024,1536,1024,768 \
  --encoder-dim 192,256,448,768,448,192 \
  --encoder-unmasked-dim 192,192,256,256,256,192 \
  --base-lr 0.002

Decoding uses greedy search to identify the top K candidates based on two key parameters: --epoch and --avg:

for ((epoch=100; epoch<=222; epoch+=1)); do
  for ((avg=1; avg<=$epoch-1; avg+=1)); do
    ./zipformer/decode.py \
        --epoch $epoch \
        --avg $avg \
        --exp-dir ./zipformer/exp_finetune \
        --do-normalize 0 \
        --max-duration 1000 \
        --decoding-method greedy_search \
        --num-encoder-layers 2,2,3,4,3,2 \
        --feedforward-dim 512,768,1024,1536,1024,768 \
        --encoder-dim 192,256,448,768,448,192 \
        --encoder-unmasked-dim 192,192,256,256,256,192
  done
done
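As a rough guide to the cost of this search, the sketch below enumerates the (epoch, avg) pairs visited by the nested loop above and picks the best K settings from a table of results. The function names are hypothetical, and how the WER of each run is collected is not shown in this thread, so `wer_by_setting` is an assumed mapping that the reader would have to fill in.

```python
# Sketch with hypothetical helper names; mirrors the bounds of the shell loop.
def decode_grid(first_epoch=100, last_epoch=222):
    """All (epoch, avg) pairs with first_epoch <= epoch <= last_epoch
    and 1 <= avg <= epoch - 1."""
    return [(epoch, avg)
            for epoch in range(first_epoch, last_epoch + 1)
            for avg in range(1, epoch)]


def top_k(wer_by_setting, k):
    """Return the k settings with the lowest WER.

    wer_by_setting maps (epoch, avg) -> WER; filling it in is left to the
    reader, since the thread does not show how results are gathered.
    """
    return sorted(wer_by_setting, key=wer_by_setting.get)[:k]


if __name__ == "__main__":
    print(len(decode_grid()))  # number of greedy-search decoding runs
```

The full grid comes to 19,680 decoding runs, which is why narrowing the range of epochs or averages early can save a lot of time.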

Then use modified beam search on these top K candidates:

epoch=
avg=
./zipformer/decode.py \
      --epoch $epoch \
      --avg $avg \
      --exp-dir ./zipformer/exp_finetune \
      --do-normalize 0 \
      --max-duration 1000 \
      --decoding-method modified_beam_search \
      --beam-size 8 \
      --num-encoder-layers 2,2,3,4,3,2 \
      --feedforward-dim 512,768,1024,1536,1024,768 \
      --encoder-dim 192,256,448,768,448,192 \
      --encoder-unmasked-dim 192,192,256,256,256,192

teowenshen commented 1 month ago

I see! Thanks for the explanation!

Meanwhile, can you share your finetuning and decoding commands as well?

yfyeung commented 1 month ago

Sure, I have updated my comment. You can prune the search over --epoch and --avg as you go rather than decoding every combination.

danpovey commented 1 month ago

@teowenshen is there any chance you can run from --start-epoch=33 with the --inf-check=True option, assuming pretrain.py supports these options like train.py, and show us the log? If the options are not there we should add them. I want to see where the inf grad is coming from; maybe we can fix it with more info.

danpovey commented 1 month ago

Also, @yfyeung we normally have a README.md and/or RESULTS.md that show typical sequences of training and testing commands, and associated results. Is there any chance of adding those? Is a link to a paper going to come later?

teowenshen commented 1 month ago

I want to see where the inf grad is coming from, maybe we can fix it with more info.

Yes, please find the logs for epoch 33 as attached.

librispeech_SSL_zipformer_pretrain_ep33_infcheck.txt

I couldn't run --print-diagnostics 1 due to this error:

Error getting eigenvalues, trying another method.
Error getting eigenvalues, trying another method.
Error getting eigenvalues, trying another method.
Error getting eigenvalues, trying another method.
/workspace/icefall/icefall/diagnostics.py:255: UserWarning: ComplexHalf support is experimental and many operators don't support it yet. (Triggered internally at /opt/conda/conda-bld/pytorch_1695392067780/work/aten/src/ATen/EmptyTensor.cpp:31.)
  eigs, _ = torch.linalg.eig(stats)
/workspace/icefall/icefall/diagnostics.py:255: UserWarning: ComplexHalf support is experimental and many operators don't support it yet. (Triggered internally at /opt/conda/conda-bld/pytorch_1695392067780/work/aten/src/ATen/EmptyTensor.cpp:31.)
  eigs, _ = torch.linalg.eig(stats)
/workspace/icefall/icefall/diagnostics.py:255: UserWarning: ComplexHalf support is experimental and many operators don't support it yet. (Triggered internally at /opt/conda/conda-bld/pytorch_1695392067780/work/aten/src/ATen/EmptyTensor.cpp:31.)
  eigs, _ = torch.linalg.eig(stats)
/workspace/icefall/icefall/diagnostics.py:255: UserWarning: ComplexHalf support is experimental and many operators don't support it yet. (Triggered internally at /opt/conda/conda-bld/pytorch_1695392067780/work/aten/src/ATen/EmptyTensor.cpp:31.)
  eigs, _ = torch.linalg.eig(stats)
Traceback (most recent call last):
  File "/mnt/host/icefall-k2ssl/egs/librispeech/SSL/zipformer/pretrain.py", line 1380, in <module>
    main()
  File "/mnt/host/icefall-k2ssl/egs/librispeech/SSL/zipformer/pretrain.py", line 1371, in main
    mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
  File "/opt/conda/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 246, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")
  File "/opt/conda/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 202, in start_processes
    while not context.join():
  File "/opt/conda/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 163, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/workspace/icefall/icefall/diagnostics.py", line 248, in print_diagnostics
    eigs, _ = torch.linalg.eigh(stats)
RuntimeError: "linalg_eigh_cuda" not implemented for 'Half'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 74, in _wrap
    fn(i, *args)
  File "/mnt/host/icefall-k2ssl/egs/librispeech/SSL/zipformer/pretrain.py", line 1276, in run
    diagnostic.print_diagnostics()
  File "/workspace/icefall/icefall/diagnostics.py", line 517, in print_diagnostics
    self.diagnostics[k].print_diagnostics()
  File "/workspace/icefall/icefall/diagnostics.py", line 255, in print_diagnostics
    eigs, _ = torch.linalg.eig(stats)
RuntimeError: torch.linalg.eig: input tensor should not contain infs or NaNs.

danpovey commented 1 month ago

For diagnostics you need to disable fp16 and halve the batch size.

danpovey commented 1 month ago

The error was unusual: it was an infinity in the forward pass. This is because you used the wav2vec2 frontend, which doesn't have any balancers or similar code to stop large values appearing. ScaledAdam can make large values appear faster than Adam would, although even with Adam they'll appear eventually unless steps are taken to stop it.

   x = conv(x)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/container.py", line 215, in forward
    input = module(input)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1581, in _call_impl
    hook_result = hook(self, args, result)
  File "/workspace/icefall/icefall/hooks.py", line 41, in forward_hook
    raise ValueError(
ValueError: The sum of module.feature_extractor.conv_layers.2.0.output is not finite: tensor([[[  -6.5234,   -6.5078,   -6.6094,  ...,   -6.5820,   -6.5469,
            -6.5469],
         [  -0.7900,   -0.6479,   -0.5444,  ...,   -0.9287,   -0.9971,
            -0.9380],
         [  -7.3672,   -8.1250,   -8.5938,  ...,   -7.8672,   -7.9023,
            -7.8047],

Anyway, this PR https://github.com/k2-fsa/icefall/pull/1593 should fix the issue without causing any model incompatibility. I haven't tested it though.
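The kind of check that raised the ValueError in the traceback above can be sketched in plain Python. This is not the code in icefall/hooks.py; the function names are hypothetical, and half precision is emulated with the standard `struct` module, assuming that values too large for fp16 should become infinity as they would on the GPU.

```python
import math
import struct


def to_fp16(x):
    """Round x to IEEE half precision, mapping overflow to +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except (OverflowError, struct.error):
        # struct refuses to pack values beyond the fp16 range; emulate the
        # hardware behaviour of saturating to infinity instead.
        return math.copysign(math.inf, x)


def check_finite(name, values):
    """Raise if the sum of a layer's outputs is inf or nan."""
    total = sum(values)
    if not math.isfinite(total):
        raise ValueError(f"The sum of {name} is not finite")


if __name__ == "__main__":
    activations = [to_fp16(v) for v in (-6.5, 1000.0, 70000.0)]
    try:
        check_finite("toy_layer.output", activations)
    except ValueError as err:
        print(err)
```

The largest finite fp16 value is 65504, so a front end whose activations are allowed to grow without bound will eventually trip a check like this one.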

yfyeung commented 1 month ago

Also, @yfyeung we normally have a README.md and/or RESULTS.md that show typical sequences of training and testing commands, and associated results. Is there any chance of adding those? Is a link to a paper going to come later?

Sure, I will add those after the anonymity period ends, including the model checkpoint/tensorboard/pre-training logs/fine-tuning logs/decoding logs, and RESULTS.md. And if things go well, also a link to the paper.