pytorch / audio

Data manipulation and transformation for audio signal processing, powered by PyTorch
https://pytorch.org/audio
BSD 2-Clause "Simplified" License

Questions about HuBERT/Wav2Vec2 pre-training #2947

arlofaria opened this issue 1 year ago (status: Open)

arlofaria commented 1 year ago

🐛 Describe the bug

[This is not really a bug report, more a request for clarification/discussion...]

I'm training a HuBERT BASE-sized model, in its first iteration targeting 100 k-means clusters of MFCC features, on a custom 1000-hour dataset (i.e., similar in size to LibriSpeech). See the intriguing plots below:

[Image: TensorBoard plots of the training loss and the masked/unmasked unit-prediction accuracies]

Note that the triangles in the TensorBoard plots indicate NaN values. I'm perplexed as to why these cause the loss curve and the unit-prediction accuracies to have a "hiccup" around step 180K and then seem to recover by step 200K. My hypothesis is that it's related to the special treatment of the feature penalty and layer normalization in BASE-sized models.

As noted in the Wav2Vec2 paper:

For the smaller Librispeech dataset, we regularize the model by applying an L2 penalty to the activations of the final layer of the feature encoder and scale down the gradients for the encoder by a factor of 10. We also use a slightly different encoder architecture where we do not use layer normalization, and instead of normalizing the raw waveform, the output of the first encoder layer is normalized.

Having inspected the HuBERT code, I think it's worth clarifying that the L2 "feature penalty" is in fact always included in the loss function, scaled up by a hardcoded factor of 10x, regardless of dataset or model size and irrespective of any masking. By contrast, the 10x downscaling of the feature-encoder gradients and the specialized layer normalization are enabled only in the configurations for BASE-sized models. So perhaps a more straightforward interpretation is that the feature penalty is effectively unscaled for BASE models and 10x upscaled for LARGE/XLARGE models. Am I reading that correctly?
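To make that reading concrete, here is a plain-Python sketch (no autograd) of the gradient the L2 feature penalty sends into the feature encoder. It assumes the penalty is the mean of squared feature activations times a fixed weight of 10, and that it is computed on features that pass through a gradient-scaling layer (identity forward, gradient multiplied by `grad_mult` backward) with `grad_mult=0.1` for BASE and `1.0` for LARGE. The function and parameter names are illustrative, not TorchAudio's.

```python
def penalty_and_grad(features, feature_weight=10.0, grad_mult=1.0):
    """Return (weighted penalty, gradient w.r.t. each feature value).

    penalty = feature_weight * mean(x**2), so
    d(penalty)/dx_i = feature_weight * 2 * x_i / N.
    grad_mult models a layer that is the identity in the forward pass
    but multiplies the gradient by grad_mult in the backward pass.
    """
    n = len(features)
    penalty = feature_weight * sum(x * x for x in features) / n
    grads = [grad_mult * feature_weight * 2.0 * x / n for x in features]
    return penalty, grads


features = [0.5, -1.0, 2.0, 0.0]

# Hypothetical BASE-style setting: penalty weight 10, encoder gradients scaled by 0.1.
base_pen, base_grads = penalty_and_grad(features, feature_weight=10.0, grad_mult=0.1)
# Hypothetical LARGE-style setting: penalty weight 10, no gradient scaling.
large_pen, large_grads = penalty_and_grad(features, feature_weight=10.0, grad_mult=1.0)
# Reference: unweighted penalty, no gradient scaling.
_, unit_grads = penalty_and_grad(features, feature_weight=1.0, grad_mult=1.0)

print(base_pen == large_pen)  # the loss value itself is identical
print([round(b / u, 6) for b, u in zip(base_grads, unit_grads) if u])  # effective 1x
print([round(g / u, 6) for g, u in zip(large_grads, unit_grads) if u])  # effective 10x
```

One caveat for this interpretation: under the same assumption, the 0.1 factor also scales the gradient that the prediction loss sends into the encoder, so relative to the prediction loss the penalty keeps its 10x weight; what changes is its absolute contribution to the encoder update.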

So my first question: what changes might be suggested to avoid the loss "hiccup" that I've observed? Should I try adjusting the feature penalty scale and/or enabling standard layer normalization in the BASE model configuration?

My second concern is the drop in unmasked accuracy, which seems to start even before the learning rate reaches its warmed-up peak at step 20K. I suspect this is because the implementation of the HuBERT loss function gives no weight to the unmasked logits. The HuBERT paper explored weightings of 0.0, 0.5, and 1.0 and found that it was generally best to give zero weight to the unmasked loss component, especially when the targets are relatively low quality in terms of phonemic correlation. However, I wonder: might it be worthwhile to consider a small but non-zero weighting, say 0.1, for the unmasked loss, to prevent the under-fitting dip seen in these plots?
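The weighting under discussion can be sketched in plain Python: per-frame cross-entropy is summed separately over masked and unmasked frames and then combined with weights. The names `masked_weight` and `unmasked_weight` and the toy logits are illustrative assumptions, not the recipe's actual code.

```python
import math


def cross_entropy_sum(logits, targets):
    """Sum over frames of -log softmax(logits)[target], computed stably."""
    total = 0.0
    for row, target in zip(logits, targets):
        m = max(row)
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_z - row[target]
    return total


def weighted_prediction_loss(logits_m, targets_m, logits_u, targets_u,
                             masked_weight=1.0, unmasked_weight=0.0):
    """Weighted sum of masked and unmasked losses, before any length normalization."""
    loss_m = cross_entropy_sum(logits_m, targets_m)
    loss_u = cross_entropy_sum(logits_u, targets_u)
    return masked_weight * loss_m + unmasked_weight * loss_u


# Toy data: 3 cluster labels, 2 masked frames and 2 unmasked frames.
logits_m = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]
targets_m = [0, 2]
logits_u = [[0.0, 0.0, 0.0], [1.0, -1.0, 0.5]]
targets_u = [1, 1]

loss_zero = weighted_prediction_loss(logits_m, targets_m, logits_u, targets_u, unmasked_weight=0.0)
loss_small = weighted_prediction_loss(logits_m, targets_m, logits_u, targets_u, unmasked_weight=0.1)
print(loss_zero, loss_small)
```

With `unmasked_weight=0.0` the unmasked logits contribute nothing to the loss (and hence no gradient), which is the behavior described above; a value such as 0.1 adds a scaled copy of the unmasked loss on top of the masked one.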

I also wonder about the length normalization when combining the masked and unmasked losses. The current TorchAudio implementation seems to first combine the weighted masked and unmasked losses, each of which is summed over logits of a different length depending on the masking parameters. It then adds the feature penalty, which is averaged over all frames irrespective of masking and later scaled by the number of masked logits. Finally, it normalizes the overall sum of losses by the number of masked logits. By contrast, the fairseq implementation normalizes by the number of masked plus unmasked logits (i.e., the full sequence length) whenever the unmasked loss weight is non-zero. Should the TorchAudio implementation be updated to match fairseq?
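The two normalization schemes, as described in this thread (not copied from either code base), can be compared with a small sketch. The inputs are already-summed losses and a feature penalty averaged over all frames; the numbers in the example are toy values, not results from any experiment.

```python
def total_loss_torchaudio_style(loss_m, loss_u, feat_pen, n_masked, n_unmasked,
                                masked_weight=1.0, unmasked_weight=0.0,
                                feature_weight=10.0):
    """Normalization as described for the TorchAudio recipe: the frame count
    used for normalization is always the number of masked frames
    (n_unmasked is accepted only to keep the signatures parallel)."""
    sample_size = n_masked
    loss = masked_weight * loss_m + unmasked_weight * loss_u
    loss += feature_weight * feat_pen * sample_size
    return loss / sample_size


def total_loss_fairseq_style(loss_m, loss_u, feat_pen, n_masked, n_unmasked,
                             masked_weight=1.0, unmasked_weight=0.0,
                             feature_weight=10.0):
    """Normalization as described for fairseq: unmasked frames are added to
    the frame count only when their loss weight is non-zero."""
    sample_size = 0
    loss = 0.0
    if masked_weight > 0:
        loss += masked_weight * loss_m
        sample_size += n_masked
    if unmasked_weight > 0:
        loss += unmasked_weight * loss_u
        sample_size += n_unmasked
    loss += feature_weight * feat_pen * sample_size
    return loss / sample_size


# Toy numbers: losses summed over 80 masked and 20 unmasked frames,
# and a mean squared feature activation of 0.5.
args = dict(loss_m=160.0, loss_u=40.0, feat_pen=0.5, n_masked=80, n_unmasked=20)
for w in (0.0, 0.1):
    ta = total_loss_torchaudio_style(**args, unmasked_weight=w)
    fs = total_loss_fairseq_style(**args, unmasked_weight=w)
    print(f"unmasked_weight={w}: torchaudio-style={ta:.4f}, fairseq-style={fs:.4f}")
```

With a zero unmasked weight both schemes agree; as soon as the weight is non-zero they diverge, and the size of the gap depends on the ratio of masked to unmasked frames.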

Moreover, would it be a sensible improvement to instead normalize the masked and unmasked losses by their respective lengths before their weighted summation, and also to compute the feature penalty according to the weighting of the masked and unmasked losses (e.g., unmasked feature frames should not contribute to the penalty if the unmasked weight is zero)? The advantage is that the weighting becomes decoupled from the masking parameters, so this hyperparameter is easier to tune independently.
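One possible reading of this proposal is sketched below; it is a hypothetical variant, not code from TorchAudio or fairseq. Each loss is turned into a per-frame mean before weighting, and the feature penalty is a weighted per-frame mean over masked and unmasked frames, so frames whose loss weight is zero do not contribute to it. Dividing the penalty by the summed weights is a design choice of this sketch.

```python
def total_loss_length_normalized(loss_m, loss_u, pen_m, pen_u, n_masked, n_unmasked,
                                 masked_weight=1.0, unmasked_weight=0.0,
                                 feature_weight=10.0):
    """Hypothetical length-normalized variant of the HuBERT loss.

    loss_m, loss_u: prediction losses summed over masked / unmasked frames.
    pen_m, pen_u:   squared feature activations summed over masked / unmasked
                    frames (each frame's value already averaged over channels).
    """
    use_unmasked = unmasked_weight > 0 and n_unmasked > 0
    # Per-frame means, so the result does not depend on how many frames are masked.
    pred = masked_weight * loss_m / n_masked
    pen = masked_weight * pen_m / n_masked
    total_weight = masked_weight
    if use_unmasked:
        pred += unmasked_weight * loss_u / n_unmasked
        pen += unmasked_weight * pen_u / n_unmasked
        total_weight += unmasked_weight
    # Keep the penalty a weighted per-frame mean; zero-weight frames are excluded.
    pen /= total_weight
    return pred + feature_weight * pen


# Same per-frame statistics, two different masked/unmasked splits (toy numbers).
a = total_loss_length_normalized(160.0, 40.0, 40.0, 20.0, 80, 20, unmasked_weight=0.1)
b = total_loss_length_normalized(120.0, 80.0, 30.0, 40.0, 60, 40, unmasked_weight=0.1)
print(a, b)
```

Both calls return the same value even though the masked/unmasked split differs, which is the decoupling argued for above. Note that this changes what a given `unmasked_weight` means, so values are not directly comparable with the existing recipe.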

Versions

I'm using a slightly modified local fork of the main branch. The principal change is to refactor the training_step to (re-)enable automatic_optimization=True in Lightning (specifically for the gradient accumulation functionality, see #2918), rather than having a manual backward step.

nateanl commented 1 year ago

Hi @arlofaria, thanks for sharing these questions, which I think are very inspiring.

Regarding question 1: the advantage of using group normalization instead of layer normalization is faster training; at the same time, the gradients can be unstable, as you observed in the loss curve. I think you can use layer normalization in the BASE model to stabilize training. NOTE: you also need to enable normalizing the waveform before feeding it to the HuBERT model (see https://github.com/pytorch/audio/pull/2873).
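Waveform normalization here means scaling each utterance to zero mean and unit variance before the feature extractor. A plain-Python sketch of that operation follows; in PyTorch the same computation on a 1-D tensor can be written as `torch.nn.functional.layer_norm(wav, wav.shape)`. Where exactly the recipe applies it is covered by the linked PR, and `normalize_waveform` is a hypothetical helper name.

```python
import math


def normalize_waveform(wav, eps=1e-5):
    """Zero-mean, unit-variance normalization over all samples of one utterance.

    Matches parameter-free layer normalization over the whole waveform:
    biased (population) variance, with eps added for numerical stability.
    """
    n = len(wav)
    mean = sum(wav) / n
    var = sum((x - mean) ** 2 for x in wav) / n
    scale = 1.0 / math.sqrt(var + eps)
    return [(x - mean) * scale for x in wav]


normalized = normalize_waveform([0.1, -0.3, 0.5, 0.2, -0.1])
print(normalized)
```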

Regarding question 2: I actually tried a 0.001 weight on unmasked frames, and it did help increase the unmasked accuracy. However, it doesn't seem to help when fine-tuning the pre-trained model on the ASR task. In both experiments (0 weight and 0.001 weight) I didn't encounter the "hiccup" in the loss curve, but it's worth trying to see whether it works in your case.

Regarding question 3: this is true. Since the HuBERT paper found that 0 weight for unmasked frames is optimal, I hardcoded sample_size to be the number of masked frames. I can make this more flexible by having the loss function return two sample sizes, with the unmasked one set to 0 when its weight is 0.

Regarding question 4: I think the feature penalty loss is meant to make the feature extraction layers sparser or to avoid overflow, so both masked and unmasked frames can help with that. Normalizing the masked and unmasked losses by their respective lengths sounds like a good idea. Would you like to run another experiment to see whether it benefits training? If so, we can add it to the current recipe. Thanks!

arlofaria commented 1 year ago

Thanks for the helpful replies!

  1. I'll try enabling full layer normalization for the BASE-sized model, hopefully improving training stability while not degrading training speed too much. Thanks for the note about enabling waveform normalization as well, as that would have been quite easy for me to overlook.
  2. My intuition is that an unmasked weight of 0.001 is rather close to zero, so I might start by trying 0.1 instead... but this value will have a somewhat different interpretation due to (4) below...
  3. Sounds good!
  4. I'll run this experiment and let you know how it goes! However, note that this change to length-normalize the masked and unmasked losses before their weighted combination (instead of after) will modify the semantics of these weight hyper-parameters, so it won't be comparable to prior experiments that FAIR or you might have run.

And one more question, which seems like it may be related to layer normalization:

  5. The fairseq Wav2Vec2 configurations disable conv_bias for BASE models but enable it for LARGE; meanwhile, conv_bias is always disabled for all sizes of HuBERT models. Is there some rationale for that difference?

arlofaria-zoom commented 1 year ago

Just following up on a few of these questions:

1. Regarding the training instability with NaN loss and the "hiccup" in the loss curve: I figured this out, and it had nothing to do with layer normalization. Instead, I found that using Lightning's Trainer(precision="bf16") instead of Trainer(precision=16) resolved the problem. I'm not sure whether the same problem would occur if I hadn't refactored the training step to enable automatic optimization (automatic_optimization=True) and use Lightning's half-precision AMP instead of the manual backward approach implemented on the main branch.

I don't have a clear explanation for the "hiccup" shape in the curves, but I think it might be an artifact of Lightning's logging: perhaps it appropriately zeros the gradients or skips the batches when encountering a NaN loss, but does not correctly update the batch size when logging the accumulated metric?

2. Regarding the drop in unmasked accuracy: curiously, this does seem to be affected by layer normalization. Replacing group_norm with layer_norm (while also implementing waveform normalization) results in unmasked-accuracy curves that increase monotonically. I'm not yet sure whether that improves downstream ASR accuracy.

4. Regarding the weighted combination of masked and unmasked losses: I've confirmed that using a non-zero weight for the unmasked loss is indeed detrimental for downstream ASR, particularly for BASE-sized models trained in iterations 1 and 2 with relatively poor targets. However, I found that an unmasked weight of 0.1 (with the masked and unmasked losses each length-normalized before combination) provided a very slight improvement (or perhaps a statistically insignificant difference) for a LARGE-sized model at iteration 3.

5. Regarding the bias vectors for the CNN feature extractor: I found that having bias vectors was critical for stability when using Trainer(precision=16), but they are not necessary when using Trainer(precision="bf16"). This result, along with the first point above, suggests to me that there might be something sub-optimal about how Lightning implements AMP training with half-precision floats; but it's kind of a non-issue if you have hardware that supports bfloat16. See also: https://github.com/Lightning-AI/lightning/pull/5359

And lastly another question:

6. Regarding examples/hubert and examples/ssl ... ... is there a plan to maintain both, or is the latter going to replace the former? Also, is there a plan for newer SSL recipes to be added, e.g. Data2Vec2?

nateanl commented 1 year ago

Yeah, bf16 is more robust than standard 16-bit (fp16) precision, although it is natively supported only on Ampere-architecture (and newer) GPUs. My guess about turning off the bias in the CNN feature extractor is that it may help avoid value overflow in the weights during optimization.
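One plausible contributor, not confirmed in this thread, is dynamic range: float16 has 5 exponent bits (largest finite value 65504, smallest subnormal about 6e-8), while bfloat16 keeps float32's 8 exponent bits at the cost of precision (7 stored mantissa bits instead of 10). Standard fp16 mixed precision relies on loss scaling to keep gradients inside that narrow range. The plain-Python sketch below rounds values into each format; float16 uses the standard library's `struct` `'e'` format, and bfloat16 is emulated by rounding a float32 to its top 16 bits (round-to-nearest-even), which is an assumption of this sketch rather than how any particular kernel is implemented.

```python
import math
import struct


def to_fp16(x):
    """Round a Python float to IEEE 754 half precision; overflow maps to +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


def to_bf16(x):
    """Round a Python float to bfloat16 (top 16 bits of float32, ties to even).

    Assumes x is within float32 range; NaN handling is omitted for brevity.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round to nearest, ties to even
    bits = (bits >> 16) << 16            # drop the low 16 bits
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


# A large activation overflows in fp16 and a tiny gradient underflows to zero,
# while bf16 represents both (with coarser precision).
for value in [1.0, 70000.0, 1e-8]:
    print(f"{value:>10g}  fp16={to_fp16(value)!r:>10}  bf16={to_bf16(value)!r}")
```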

The plan for the examples/ssl recipe is to make it flexible for users to customize each component (loss function, data module, model architecture, etc.). New recipes will be added based on the ssl example; I will start with a Wav2Vec2 recipe using Conformer as the encoder. I don't have the bandwidth to implement Data2Vec right now, but I would like to hear opinions from you and other recipe users. If there is enough interest in a specific SSL recipe, we can add it to the planned work :)

arlofaria-zoom commented 1 year ago

I would certainly be very interested in a Data2Vec(2) recipe!

One of the major practical drawbacks of the HuBERT approach is the bottleneck of dumping large hidden activations to disk storage before kmeans clustering; the multi-iteration training is also slightly inconvenient. AFAICT, Data2Vec seems to avoid those problems by keeping everything in memory and doing a single training; it seems that Data2Vec2 incorporates some very significant speedups and incremental accuracy improvements also. That's particularly helpful in comparison to HuBERT, where I've had to explore various shortcuts to scale experiments.

I think the main concern is whether the approach is still being refined; if so, it might be best to wait a bit. For example, is there currently a Data2Vec3 being researched?

A Wav2Vec2 recipe would also be nice. A specific use case to consider is further pre-training, where a well-trained large model is used as the initial checkpoint. For HuBERT, this required a bit of hacking: I had to find the original k-means model and also patch in a few components of the Fairseq models that were missing from their repackaged bundles in TorchAudio.

However, I largely chose to explore HuBERT instead of Wav2Vec2 because of the availability of the recipe in TorchAudio, which seemed a bit easier to follow and perhaps more actively maintained than what's published in Fairseq. Also, I'm a big fan of its use of Lightning, and would encourage structuring future recipes to continue using it, ideally with automatic optimization enabled.

nateanl commented 1 year ago

> I would certainly be very interested in a Data2Vec(2) recipe!

Data2vec 2.0 is indeed promising, as it outperforms both HuBERT and WavLM. Let me check with them to see whether there is a new iteration.

I will look into the details when I have more bandwidth. It seems the key components are the teacher-student training and the loss functions; the model architecture is pretty much the same as Wav2Vec2.
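The teacher-student training mentioned here can be sketched in a few lines, following the general description in the Data2Vec papers rather than any TorchAudio code. The teacher's parameters are an exponential moving average (EMA) of the student's; the regression targets are the average of the teacher's top-k layer outputs on the unmasked input; and the student is trained to predict those targets at masked positions. This is a toy, plain-Python illustration: per-frame scalars stand in for feature vectors, the decay is fixed rather than annealed, the per-layer normalization the papers apply before averaging is omitted, and squared error stands in for the papers' regression loss. All names are hypothetical.

```python
def ema_update(teacher, student, decay):
    """EMA of parameters: teacher <- decay * teacher + (1 - decay) * student."""
    return [decay * t + (1.0 - decay) * s for t, s in zip(teacher, student)]


def average_top_k(layer_outputs, k):
    """Target = element-wise mean of the last k layer outputs (lists of per-frame values)."""
    top = layer_outputs[-k:]
    return [sum(vals) / k for vals in zip(*top)]


def masked_mse(student_pred, target, mask):
    """Mean squared error over masked positions only."""
    diffs = [(p - t) ** 2 for p, t, m in zip(student_pred, target, mask) if m]
    return sum(diffs) / len(diffs)


# Toy example: 3 teacher layers over 2 frames; the student predicts frame 2 only.
teacher_layers = [[1.0, 1.0], [2.0, 4.0], [4.0, 8.0]]
target = average_top_k(teacher_layers, k=2)
loss = masked_mse([0.0, 5.0], target, mask=[False, True])
new_teacher = ema_update([0.0, 1.0], [1.0, 1.0], decay=0.9)
print(target, loss, new_teacher)
```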