Closed: Butterfly-c closed this issue 2 years ago.
When did it turn into nan or inf: at the beginning or in the middle of training? Could you please upload the training log here? Thanks!
At the beginning of training (pruned_loss_scaled = 0) the loss turns into nan. After 10000 updates (num_updates), pruned_loss_scaled is set to 0.1 and the loss turns into inf.
Sorry, something went wrong when I uploaded the log.
Do you have any sequences with U > T, i.e. where the number of tokens in the transcript is greater than the number of frames?
The subsampling rate is 4, due to 2 max-pooling layers, so the number of tokens U is unlikely to be greater than T.
I put some logs here:
```
epoch 3 ; loss inf; num updates 16100 ; lr 0.000704907
epoch 3 ; loss 1.13339; num updates 16200 ; lr 0.000702728
epoch 3 ; loss 1.13215; num updates 16300 ; lr 0.000700569
epoch 3 ; loss inf; num updates 16400 ; lr 0.000698043
```
What iteration did the loss become inf on, and what kind of model were you using?
The loss becomes inf at epoch 2, when pruned_loss_scaled is set to 0.1.
The ConformerTransducer model is configured as follows:
- Encoder: 2 VGG blocks + 12 Conformer layers + 1 LSTMP + 1 LayerNorm
- Decoder: 2 LSTM layers + dropout
- Joiner: configured as in k2
The other joiner settings are as follows:
- lm_only_scale = 0.25
- am_only_scale = 0
- prune_range = 4
- simple_loss_scale = 0.5
pruned_loss_scaled follows this schedule:
- pruned_loss_scaled = 0 if num_updates <= 10000
- pruned_loss_scaled = 0.1 if 10000 < num_updates <= 20000
- pruned_loss_scaled = 1 if num_updates > 20000
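To make the schedule concrete, here is a minimal sketch in plain Python. The way the two losses are combined (simple_loss_scale * simple_loss + scheduled weight * pruned_loss) is an assumption modeled on common pruned RNN-T recipes, and pruned_loss_scale and total_loss are hypothetical helpers, not code from this thread. One thing the sketch makes visible: in IEEE floating point, 0 * inf is nan. So if the pruned loss were already inf during warm-up, a weight of 0 would turn it into nan, while a weight of 0.1 would leave it inf. That would match the symptoms reported here, but it is only a hypothesis.

```python
def pruned_loss_scale(num_updates: int) -> float:
    """Warm-up weight for the pruned loss, following the schedule above."""
    if num_updates <= 10000:
        return 0.0
    if num_updates <= 20000:
        return 0.1
    return 1.0


def total_loss(simple_loss: float, pruned_loss: float, num_updates: int,
               simple_loss_scale: float = 0.5) -> float:
    # Hypothetical combination of the two losses (an assumption, see above).
    # Note: 0.0 * float("inf") is nan, not 0, so a zero weight does NOT
    # hide an infinite pruned loss.
    return (simple_loss_scale * simple_loss
            + pruned_loss_scale(num_updates) * pruned_loss)
```

For example, total_loss(1.0, float("inf"), 5000) is nan, while total_loss(1.0, float("inf"), 15000) is inf.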
Can you dump the inputs of the batches that lead to the inf loss, so we can use them to debug this issue? Thanks.
@pkufool perhaps it was not obvious to him how to do this? Also, @Butterfly-c, are you using fp16 / half-precision for training? It can be tricky to tune a network to perform OK with fp16. One possibility is to detect inf in the loss, e.g. by comparing (loss - loss) to 0, and skip the update and print a warning. If you have any utterances in your training set whose transcripts are too long for the utterance length, those could lead to inf loss. It's possible that the model is training OK if the individual losses on most batches stay finite, even though the overall loss may be infinite. Cases with too-long transcripts will generate infinite loss but will not generate infinite gradients.
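A minimal sketch of the "detect inf and skip the update" suggestion, assuming a standard PyTorch loop where loss is a scalar tensor and optimizer is a torch.optim optimizer. is_finite_loss and maybe_step are hypothetical helper names. The (loss - loss) != 0 trick works because inf - inf and nan - nan are both nan, and nan compares unequal to everything; for tensors, torch.isfinite(loss) is an equivalent check.

```python
import warnings


def is_finite_loss(loss) -> bool:
    # (loss - loss) is 0 for any finite value but nan for inf or nan,
    # and nan != 0 is True. Works for Python floats and 0-dim torch tensors.
    return not bool(loss - loss != 0)


def maybe_step(loss, optimizer, step: int) -> bool:
    """Back-propagate and update only if the loss is finite (hypothetical helper).

    `loss` is assumed to be a scalar torch tensor and `optimizer` a
    torch.optim.Optimizer; both are placeholders for your own training loop.
    Returns True if an update was applied.
    """
    optimizer.zero_grad()
    if not is_finite_loss(loss):
        warnings.warn(f"step {step}: non-finite loss {float(loss)}, skipping update")
        return False
    loss.backward()
    optimizer.step()
    return True
```

Skipping only the offending batches keeps a few bad utterances from stopping training, while the warnings still show how often it happens.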
@Butterfly-c Suppose you used pruned loss like this:
```python
import fast_rnnt

# Shapes (C = vocab size): encoder_out [B, T, C], decoder_out [B, S + 1, C],
# symbols [B, S] (int64), boundary [B, 4] with rows [0, 0, S_b, T_b].
simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_smoothed(
    lm=decoder_out,
    am=encoder_out,
    symbols=symbols,
    termination_symbol=blank_id,
    lm_only_scale=lm_scale,
    am_only_scale=am_scale,
    boundary=boundary,
    reduction="sum",
    return_grad=True,
)

# ranges : [B, T, prune_range]
ranges = fast_rnnt.get_rnnt_prune_ranges(
    px_grad=px_grad,
    py_grad=py_grad,
    boundary=boundary,
    s_range=prune_range,
)

# am_pruned : [B, T, prune_range, C]
# lm_pruned : [B, T, prune_range, C]
am_pruned, lm_pruned = fast_rnnt.do_rnnt_pruning(
    am=encoder_out, lm=decoder_out, ranges=ranges
)

# logits : [B, T, prune_range, C]
logits = joiner(am_pruned, lm_pruned)

pruned_loss = fast_rnnt.rnnt_loss_pruned(
    logits=logits,
    symbols=symbols,
    ranges=ranges,
    termination_symbol=blank_id,
    boundary=boundary,
    reduction="sum",
)
```
You can dump the bad cases as follows:
```python
import torch

# (x - x) != 0 is True only when x is inf or nan, so these save the inputs
# of any batch whose loss is not finite.
if simple_loss - simple_loss != 0:
    simple_input = {"encoder_out": encoder_out, "decoder_out": decoder_out,
                    "symbols": symbols, "boundary": boundary}
    torch.save(simple_input, "simple_bad_case.pt")
if pruned_loss - pruned_loss != 0:
    pruned_input = {"logits": logits, "ranges": ranges,
                    "symbols": symbols, "boundary": boundary}
    torch.save(pruned_input, "pruned_bad_case.pt")
```
Thanks for your kind reply! I have decoded a model from epoch 4, and the decoding result is OK. But I'm still confused about the inf loss. max_frame is set to 2500 (i.e. 25s) in my training setup. How long does a transcript have to be to count as too long?
Thanks for your suggestion. I'm trying to upload pruned_bad_case.pt so you can debug the inf issue; it will take me some time.
We have compared two models trained separately with warp-transducer and with fast-rnnt, but the GPU usage does not decrease significantly.
Roughly, the training times of the two models are as follows:

| loss | times_per_update |
|---|---|
| warp-transducer | 7m40s |
| fast-rnnt | 6m40s |
Both models above were trained on 2 machines with 4 V100-32G GPUs each (i.e. 8 GPUs). Are there any suggestions for accelerating training?
Regarding which transcripts count as too long: if the sentence is broken into BPE tokens, it is "too long" when the number of BPE tokens is larger than the number of acoustic frames (after subsampling) for that sentence.
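A sketch of that check, assuming the frame count after the two stride-2 max-pooling layers is num_frames // 4 (the exact value depends on the padding and ceil_mode settings of the VGG blocks). frames_after_subsampling and is_too_long are hypothetical helpers. With the 2500-frame (25s) limit mentioned above, this gives 625 encoder frames, so under this rule only a transcript with more than 625 tokens would be too long.

```python
def frames_after_subsampling(num_frames: int, factor: int = 4) -> int:
    # Assumption: two stride-2 max-pooling layers with floor rounding,
    # i.e. (n // 2) // 2 == n // 4. Padding or ceil_mode can change this.
    return num_frames // factor


def is_too_long(num_frames: int, num_tokens: int, factor: int = 4) -> bool:
    """True if the transcript has more tokens than subsampled frames."""
    return num_tokens > frames_after_subsampling(num_frames, factor)
```

Running this over the training manifest before training is a cheap way to find (and drop) such utterances.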
- What is your vocabulary size?
- What is your batch size? And how much data does each batch contain (i.e., what is the total duration)?
- Is your GPU usage over 90%? (You can get this information with watch -n 0.5 nvidia-smi.)
- What is the value of prune_range?

Also, what do you mean by "the GPU usage does not decrease significantly"?
Here is some of my configuration:
1. The vocabulary size is 8245, which contains 6726 Chinese characters, 1514 BPE subwords and 5 special symbols.
2. The batch size is 5000 frames (i.e. 50s).
3. With "watch -n 0.5 nvidia-smi", the peak Volatile GPU-Util is over 90%, but most of the time it is between 80% and 90%.
4. prune_range is 4.
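For context on why vocabulary size, batch size and prune_range matter here: the joiner output is usually the largest tensor in RNN-T training. Below is a rough element count using the full RNN-T logit shape [B, T, U + 1, C] and the pruned shape [B, T, prune_range, C] from the shape comments in the fast_rnnt code earlier in this thread. The batch split (5 utterances of 10s) and the 40 tokens per utterance are hypothetical numbers chosen only to fit a 50s batch; joiner_elements is a hypothetical helper.

```python
from typing import Optional


def joiner_elements(batch: int, frames: int, tokens: int, vocab: int,
                    prune_range: Optional[int] = None) -> int:
    """Element count of the joiner logits.

    Full RNN-T logits:   [B, T, U + 1, C]
    Pruned RNN-T logits: [B, T, prune_range, C]
    T and U are the padded (maximum) lengths in the batch.
    """
    s = tokens + 1 if prune_range is None else prune_range
    return batch * frames * s * vocab


# Hypothetical batch: 5 utterances x 10 s = 50 s, 1000 input frames each,
# i.e. 1000 // 4 = 250 encoder frames, and 40 tokens per utterance.
B, T, U, C = 5, 1000 // 4, 40, 8245
full = joiner_elements(B, T, U, C)
pruned = joiner_elements(B, T, U, C, prune_range=4)
print(f"full:   {full:,} elements, {full * 4 / 1e9:.2f} GB in fp32")
print(f"pruned: {pruned:,} elements, {pruned * 4 / 1e6:.1f} MB in fp32")
```

Under these assumptions the full logits need about 1.69 GB in fp32 versus about 165 MB for the pruned ones, a factor of (U + 1) / prune_range. This only shrinks the joiner and loss part; the encoder and decoder cost stays the same.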
As shown in this paper https://arxiv.org/abs/2206.13236, the peak GPU usage of fast_rnnt is far below that of warp-transducer, and the training time is also greatly reduced. But when fast_rnnt is run in our environment, the training time is not reduced as much as expected. With the same batch size (50s), the training times are as follows:

| loss | times_per_update |
|---|---|
| warp-transducer | 7m40s |
| fast-rnnt | 6m40s |
Finally, I have another question about the training time. According to the paper, the training time per batch of optimized transducer is over 4 times that of fast_rnnt, but the training time per epoch of optimized transducer is only about 2 times that of fast_rnnt.
I really appreciate your reply.
I think the comparisons in the paper may have just been for the core RNN-T loss. It does not count the neural net forward, which would not be affected by speedups in the loss computation.
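This is Amdahl's law. The sketch below is purely illustrative: if the loss took a fraction p of the step time and only the loss got s times faster, the overall speedup is 1 / ((1 - p) + p / s). Plugging in the figures from the question (about 4x for the loss, about 2x per epoch) implies the loss was roughly two thirds of the original step time. These are back-of-the-envelope numbers, not measurements; overall_speedup and implied_loss_fraction are hypothetical helpers.

```python
def overall_speedup(loss_fraction: float, loss_speedup: float) -> float:
    """Amdahl's law: only the loss part of each step gets faster."""
    return 1.0 / ((1.0 - loss_fraction) + loss_fraction / loss_speedup)


def implied_loss_fraction(overall: float, loss_speedup: float) -> float:
    """Invert Amdahl's law: what fraction of step time was the loss?

    From 1/overall = (1 - p) + p/s it follows that
    p = (1 - 1/overall) / (1 - 1/s). Requires loss_speedup != 1.
    """
    return (1.0 - 1.0 / overall) / (1.0 - 1.0 / loss_speedup)


print(implied_loss_fraction(2.0, 4.0))          # paper-style figures
print(implied_loss_fraction(460 / 400, 4.0))    # 7m40s -> 6m40s, assumed 4x loss speedup
```

With the times reported in this thread (7m40s vs 6m40s, about 1.15x) and an assumed 4x loss speedup, the same formula would put the loss at roughly 17% of the step time. The actual share would need profiling.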
Thanks for your reply; that resolves my confusion.
Following your suggestion, I saved some bad cases. Interestingly, most of the 'ranges' tensors are all zeros.
Will the training loss become inf when the input and output lengths are unbalanced (i.e. the input is far smaller than the output)? Can you give some explanation?
After filtering the training data as follows, the inf problem occurs less often:
1. label_len > 2
2. feat_len // label_len > 30
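The two rules can be written as a filter. This is a sketch assuming they are conditions an utterance must satisfy to be kept, with feat_len counted in input frames (before subsampling) and label_len in tokens; keep_utterance and the example utterances are hypothetical.

```python
def keep_utterance(feat_len: int, label_len: int) -> bool:
    # Rule 1 drops utterances with 1-2 tokens (e.g. background music with a
    # single-symbol label); rule 2 requires more than 30 input frames per token.
    # label_len > 2 is checked first, so the division never sees label_len == 0.
    return label_len > 2 and feat_len // label_len > 30


# Hypothetical utterances, for illustration only.
utterances = [
    {"id": "music", "feat_len": 1256, "label_len": 1},
    {"id": "normal", "feat_len": 1000, "label_len": 20},
    {"id": "dense", "feat_len": 500, "label_len": 20},
]
kept = [u["id"] for u in utterances if keep_utterance(u["feat_len"], u["label_len"])]
print(kept)
```

Only "normal" survives: "music" fails rule 1, and "dense" has 25 frames per token, which fails rule 2.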
Due to network limitations, I will share pruned_bad_case.pt later.
For example, when the training sample is background music, the label is only one symbol. The shapes of lm and am are as follows: decoder_out [1, 2, 8245], encoder_out [1, 314, 8245].
Does only one sequence have a single symbol, or do all the sequences in the batch have only one symbol? Thanks, this is very valuable information for us.
Based on 40 pruned_bad_case.pt files, every bad case is one where all the sequences in the batch have only one symbol, and the 'ranges' tensors are all zeros.
OK, thanks! That's it. I think our code does not handle S == 1 properly; I will try to fix it.
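Until a fixed fast_rnnt version is installed, one hypothetical workaround is to detect the degenerate case reported here (every utterance in the batch has a single symbol, so S == 1) and skip the pruned loss for that batch. Because 0 * inf is nan in floating point, skipping the computation is safer than multiplying the pruned loss by a zero weight. batch_is_degenerate is an illustrative helper, not part of fast_rnnt.

```python
from typing import Sequence


def batch_is_degenerate(label_lens: Sequence[int]) -> bool:
    """True if every utterance in the batch has at most one symbol (S <= 1).

    This mirrors the pattern seen in the dumped bad cases, where all
    sequences in the batch had one symbol and `ranges` was all zeros.
    """
    return len(label_lens) > 0 and max(label_lens) <= 1


# Sketch of how a training step might use it (names are placeholders):
# if batch_is_degenerate(label_lens):
#     loss = simple_loss_scale * simple_loss  # do not compute the pruned loss
# else:
#     loss = simple_loss_scale * simple_loss + scale * pruned_loss
```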
@Butterfly-c If you have trouble uploading your bad cases to GitHub, could you send them to me by email at wkang.pku@gmail.com? I want to use them to test my fixes. Thanks!
Due to data permissions, I can't share the bad cases until I get permission, which is on the way.
OK. I don't think your bad cases contain any text or audio, only floating-point and integer tensors. I hope you get the permission; in the meantime I am testing with randomly generated bad cases. Thanks.
OK, I will contact you as soon as I get the permission.
After updating fast-rnnt to the "fix_s_range" version, the inf problem is fixed. Thanks!
After switching to the fast_rnnt loss in my environment, the training loss always becomes nan or inf. The configuration of my ConformerTransducer environment is as follows:
Overall, 6k hours of training data are used to train the RNN-T model. At the warm-up stage (i.e. pruned_loss_scaled = 0), the loss always becomes nan; when pruned_loss_scaled is set to 0.1, the loss always becomes inf.
Are there any suggestions for solving this problem?