YUCHEN005 / MIR-GAN

Code for paper "MIR-GAN: Refining Frame-Level Modality-Invariant Representations with Adversarial Network for Audio-Visual Speech Recognition"

Gradient Explosion in Discriminator During Training #1

Open cyzhung opened 5 months ago

cyzhung commented 5 months ago

I am experiencing repeated gradient overflow during training of the discriminator model using Fairseq. Training starts normally but soon runs into gradient explosions.

Environment details:

Fairseq version: 1.0.0a0+f814da2
PyTorch version: 2.3.0+cu118

The training process starts with normal logging but then runs into gradient overflow problems as shown below:

[2024-04-25 10:01:21,847][fairseq_cli.train][INFO] - Start iterating over samples
... [other logs] ...
[2024-04-25 10:11:33,668][fairseq.trainer][INFO] - NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 0.5
... [additional gradient overflow notifications] ...
[2024-04-25 10:12:07,621][fairseq.trainer][INFO] - NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 0.000244140625

2024-04-25 10:38:38 | WARNING | fairseq.nan_detector | NaN detected in output of encoder.w2v_model.modal_discriminator.sigmoid, shape: torch.Size([16, 87, 1]), backward
Exception in thread Thread-3:

FloatingPointError: Minimum loss scale reached (0.0001). Your loss is probably exploding. Try lowering the learning rate, using gradient clipping or increasing the batch size.
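The NaN being reported at the discriminator's sigmoid output is consistent with half-precision saturation: in float16, sigmoid rounds to exactly 1.0 (or 0.0) for only moderately large logits, so any subsequent log() in a BCE-style loss hits log(0). A minimal demonstration of the numeric effect (not code from this repository):

```python
import torch

# float16 has ~3 decimal digits of precision near 1.0, so sigmoid(20),
# which is 1 - 2e-9 in exact arithmetic, rounds to exactly 1.0.
p = torch.sigmoid(torch.tensor([20.0])).half()
print(p.item())  # prints 1.0

# A BCE term such as -log(1 - p) then becomes -log(0) = inf,
# and its backward pass produces NaN gradients.
loss = -torch.log((1.0 - p).float())
print(loss.item())  # prints inf
```

Once one batch produces inf/NaN, the fp16 loss scaler keeps halving the scale until it hits the 0.0001 floor, which is exactly the failure sequence in the log above.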

I have tried the following, but neither has stabilized the training process:

1. Disabling FP16 training: switched to full precision to check whether precision issues were causing the instability.
2. Reducing the learning rate: significantly lowered the learning rate to see if the model would become more stable.
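One remedy the error message itself suggests but that is not on the list above is gradient clipping. A minimal sketch of clipping before the optimizer step (the training-step function and the max_norm value are assumptions to adapt, not this repository's actual code):

```python
import torch
import torch.nn as nn

def discriminator_step(model, loss, optimizer, max_norm=1.0):
    """Hypothetical training step: clip gradients before updating.

    Clipping rescales the total gradient norm to at most max_norm,
    so a single exploding batch cannot blow up the parameters.
    """
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    optimizer.step()

# Tiny demo with a stand-in model and a deliberately huge loss.
model = nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss = (model(torch.randn(4, 2)) * 1e6).sum()
discriminator_step(model, loss, optimizer)
```

In Fairseq the same effect is available through the `--clip-norm` training option, which applies to all trainable parameters.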

How can I fix this problem?

chaufanglin commented 4 days ago

Hi, I encountered the same error. Have you solved it? It also shows many NaNs in the modal discriminator:

[fairseq.nan_detector][INFO] - gradients: {'encoder.w2v_model.modal_discriminator.proj1.weight': tensor([[nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],
        ...,
        [nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan]], device='cuda:0', dtype=torch.float16),
 'encoder.w2v_model.modal_discriminator.proj1.bias': tensor([nan, nan, nan, nan, nan, nan, nan, nan, ...
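Since the NaN originates at the sigmoid output in fp16, a common fp16-safe rework is to drop the explicit sigmoid from the discriminator head and compute the loss with BCEWithLogitsLoss, which fuses sigmoid and log via log-sum-exp and never evaluates log(0). A sketch under that assumption (the head class below is a hypothetical stand-in, not this repository's modal_discriminator):

```python
import torch
import torch.nn as nn

class ModalDiscriminatorHead(nn.Module):
    """Hypothetical stand-in for the discriminator head:
    returns raw logits, with no trailing nn.Sigmoid()."""

    def __init__(self, dim):
        super().__init__()
        self.proj1 = nn.Linear(dim, 1)

    def forward(self, x):
        return self.proj1(x)  # raw logits, sigmoid is folded into the loss

# BCEWithLogitsLoss is numerically stable even when logits saturate in fp16.
criterion = nn.BCEWithLogitsLoss()

head = ModalDiscriminatorHead(8)
feats = torch.randn(16, 87, 8)           # shape mirrors the log: [16, 87, ...]
labels = torch.randint(0, 2, (16, 87, 1)).float()
loss = criterion(head(feats), labels)    # finite even for extreme logits
```

With plain sigmoid + log, a logit of 1000 would give log(0) = -inf; BCEWithLogitsLoss instead evaluates it in closed form and stays finite.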