k2-fsa / fast_rnnt

A torch implementation of a recursion which turns out to be useful for RNN-T.

Train loss is nan or inf #10

Closed. Butterfly-c closed this issue 2 years ago.

Butterfly-c commented 2 years ago

After using the fast_rnnt loss in my environment, the training loss always becomes nan or inf. The configuration of my ConformerTransducer environment is as follows:

Finally, 6k hours of training data are used to train the RNN-T model. At the warmup stage (i.e. pruned_loss_scaled = 0), the loss always becomes nan. Also, when pruned_loss_scaled is set to 0.1, the loss always becomes inf.

Are there any suggestions to solve this problem?

pkufool commented 2 years ago

When did it turn into nan or inf, at the beginning of training or in the middle of training? Could you please upload the training log here? Thanks!

Butterfly-c commented 2 years ago

When did it turn into nan or inf, at the beginning of training or in the middle of training? Could you please upload the training log here? Thanks!

At the beginning of training (pruned_loss_scaled = 0) the loss turned into nan. After 10000 num_updates, pruned_loss_scaled was set to 0.1 and the loss turned into inf.

Sorry, something went wrong when I uploaded the log.

pkufool commented 2 years ago

Do you have any sequences where U > T, i.e. where the number of tokens in the transcript is greater than the number of frames?

Butterfly-c commented 2 years ago

Do you have any sequences where U > T, i.e. where the number of tokens in the transcript is greater than the number of frames?

The subsampling rate is 4, from 2 max-pooling layers, so the number of tokens U is unlikely to be greater than T.

I put some logs here:

epoch 3 ; loss inf; num updates 16100 ; lr 0.000704907
epoch 3 ; loss 1.13339; num updates 16200 ; lr 0.000702728
epoch 3 ; loss 1.13215; num updates 16300 ; lr 0.000700569
epoch 3 ; loss inf; num updates 16400 ; lr 0.000698043

danpovey commented 2 years ago

What iteration did the loss become inf on, and what kind of model were you using?

Butterfly-c commented 2 years ago

What iteration did the loss become inf on, and what kind of model were you using?

The loss became inf at epoch 2, where pruned_loss_scaled was set to 0.1.

The ConformerTransducer model is configured as follows:
Encoder: 2 VGG blocks + 12 conformer layers + 1 LSTMP layer + 1 layer norm
Decoder: 2 LSTM layers + dropout
Joiner: configured as in k2

Butterfly-c commented 2 years ago

What iteration did the loss become inf on, and what kind of model were you using?

Other configurations of the joiner are as follows:
lm_only_scale = 0.25
am_only_scale = 0
prune_range = 4
simple_loss_scale = 0.5

pruned_loss_scaled = 0 if num_updates <= 10000
pruned_loss_scaled = 0.1 if 10000 < num_updates <= 20000
pruned_loss_scaled = 1 if num_updates > 20000
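
In code, this schedule is just a step function of num_updates; a minimal sketch (my own paraphrase, not the actual training script, assuming the total loss is simple_loss_scale * simple_loss + pruned_loss_scaled * pruned_loss):

def get_pruned_loss_scaled(num_updates: int) -> float:
    # Warm up on the simple loss only, then gradually bring in the pruned loss.
    if num_updates <= 10000:
        return 0.0
    elif num_updates <= 20000:
        return 0.1
    else:
        return 1.0

# loss = simple_loss_scale * simple_loss + get_pruned_loss_scaled(num_updates) * pruned_loss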

pkufool commented 2 years ago

Can you dump the inputs of the batches that lead to the inf loss, so we can use them to debug this issue? Thanks.

danpovey commented 2 years ago

@pkufool perhaps it was not obvious to him how to do this? Also, @Butterfly-c, are you using fp16 / half-precision for training? It can be tricky to tune a network to perform OK with fp16. One possibility is to detect inf in the loss, e.g. by comparing (loss - loss) to 0, and skip the update and print a warning. If you have any utterances in your training set that have too-long transcripts for the utterance length, those could lead to inf loss. It's possible that the model is training OK, if the individual losses on most batches stay finite, even though the overall loss may be infinite. Cases with too-long transcripts will generate infinite loss but will not generate infinite gradients.
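
A minimal sketch of that check in a generic PyTorch training step; compute_loss, model, batch, and optimizer are hypothetical placeholders, not code from this repo:

import logging

def train_step(model, batch, optimizer, compute_loss):
    loss = compute_loss(model, batch)
    # (loss - loss) != 0 is True exactly when loss is inf or nan.
    if (loss - loss) != 0:
        logging.warning("Non-finite loss %s detected; skipping this update", loss.item())
        return
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()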

pkufool commented 2 years ago

@Butterfly-c Suppose you used pruned loss like this:

simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_smoothed(
      lm=decoder_out,
      am=encoder_out,
      symbols=symbols,
      termination_symbol=blank_id,
      lm_only_scale=lm_scale,
      am_only_scale=am_scale,
      boundary=boundary,
      reduction="sum",
      return_grad=True,
  )

  # ranges : [B, T, prune_range]
  ranges = fast_rnnt.get_rnnt_prune_ranges(
      px_grad=px_grad,
      py_grad=py_grad,
      boundary=boundary,
      s_range=prune_range,
  )

  # am_pruned : [B, T, prune_range, C]
  # lm_pruned : [B, T, prune_range, C]
  am_pruned, lm_pruned = fast_rnnt.do_rnnt_pruning(
      am=encoder_out, lm=decoder_out, ranges=ranges
  )

  # logits : [B, T, prune_range, C]
  logits = joiner(am_pruned, lm_pruned)

  pruned_loss = fast_rnnt.rnnt_loss_pruned(
      logits=logits,
      symbols=symbols,
      ranges=ranges,
      termination_symbol=blank_id,
      boundary=boundary,
      reduction="sum",
  )

You can dump the bad cases as follows:

if simple_loss - simple_loss != 0:
  simple_input = {"encoder_out" : encoder_out, "decoder_out" : decoder_out, "symbols" : symbols, "boundary": boundary}
  torch.save(simple_input, "simple_bad_case.pt")

if pruned_loss - pruned_loss != 0:
  pruned_input = {"logits" : logits, "ranges" : ranges, "symbols" : symbols, "boundary": boundary}
  torch.save(pruned_input, "pruned_bad_case.pt")
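
To reproduce the failure later, you can reload such a dump and re-run the pruned loss in isolation; a sketch (assuming the same blank_id as in training):

import torch
import fast_rnnt

blank_id = 0  # assumption: use the same blank id as in training
bad = torch.load("pruned_bad_case.pt", map_location="cpu")
loss = fast_rnnt.rnnt_loss_pruned(
    logits=bad["logits"],
    symbols=bad["symbols"],
    ranges=bad["ranges"],
    termination_symbol=blank_id,
    boundary=bad["boundary"],
    reduction="sum",
)
print(loss)  # expected to be inf/nan for a dumped bad case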

Butterfly-c commented 2 years ago

@pkufool perhaps it was not obvious to him how to do this? Also, @Butterfly-c, are you using fp16 / half-precision for training? It can be tricky to tune a network to perform OK with fp16. One possibility is to detect inf in the loss, e.g. by comparing (loss - loss) to 0, and skip the update and print a warning. If you have any utterances in your training set that have too-long transcripts for the utterance length, those could lead to inf loss. It's possible that the model is training OK, if the individual losses on most batches stay finite, even though the overall loss may be infinite. Cases with too-long transcripts will generate infinite loss but will not generate infinite gradients.

Thanks for your kind reply! I have decoded one model from epoch 4, and the decoding result is OK. But I'm still confused about the inf loss. max_frame is set to 2500 (i.e. 25 s) in my training environment. I'm curious: how long does a transcript have to be to count as too long?

Butterfly-c commented 2 years ago

@Butterfly-c Suppose you used pruned loss like this:

simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_smoothed(
      lm=decoder_out,
      am=encoder_out,
      symbols=symbols,
      termination_symbol=blank_id,
      lm_only_scale=lm_scale,
      am_only_scale=am_scale,
      boundary=boundary,
      reduction="sum",
      return_grad=True,
  )

  # ranges : [B, T, prune_range]
  ranges = fast_rnnt.get_rnnt_prune_ranges(
      px_grad=px_grad,
      py_grad=py_grad,
      boundary=boundary,
      s_range=prune_range,
  )

  # am_pruned : [B, T, prune_range, C]
  # lm_pruned : [B, T, prune_range, C]
  am_pruned, lm_pruned = fast_rnnt.do_rnnt_pruning(
      am=encoder_out, lm=decoder_out, ranges=ranges
  )

  # logits : [B, T, prune_range, C]
  logits = joiner(am_pruned, lm_pruned)

  pruned_loss = fast_rnnt.rnnt_loss_pruned(
      logits=logits,
      symbols=symbols,
      ranges=ranges,
      termination_symbol=blank_id,
      boundary=boundary,
      reduction="sum",
  )

You can dump the bad cases as follows:

if simple_loss - simple_loss != 0:
  simple_input = {"encoder_out" : encoder_out, "decoder_out" : decoder_out, "symbols" : symbols, "boundary": boundary}
  torch.save(simple_input, "simple_bad_case.pt")

if pruned_loss - pruned_loss != 0:
  pruned_input = {"logits" : logits, "ranges" : ranges, "symbols" : symbols, "boundary": boundary}
  torch.save(pruned_input, "pruned_bad_case.pt")

Thanks for your suggestion. I'm trying to upload pruned_bad_case.pt so you can debug the inf issue; it'll take me some time.

Butterfly-c commented 2 years ago

We have compared two models trained with warp-transducer and fast_rnnt separately, but the GPU usage does not decrease significantly.

Intuitively, the training time of the two models is as follows:

loss              times_per_update
warp-transducer   7m40s
fast-rnnt         6m40s

The models above are both trained with V100-32G-4GPU * 2 (i.e. 8 GPUs). Are there any suggestions to accelerate the training?

csukuangfj commented 2 years ago
  1. What is your vocabulary size?
  2. What is your batch size? And how much data does each batch contain (i.e., what is the total duration)?
  3. Is your GPU usage over 90%? (You can get such information with watch -n 0.5 nvidia-smi.)
  4. What is the value of prune_range?

the GPU usage does not decrease significantly.

What do you want to express?

csukuangfj commented 2 years ago

I'm curious: how long does a transcript have to be to count as too long?

If the sentence is broken into BPE tokens, it is "too long" if the number of BPE tokens is larger than the number of acoustic frames (after subsampling) of this sentence.
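
For example, with the subsampling factor of 4 mentioned earlier in this thread, a check along these lines could be used to filter such utterances (a sketch with hypothetical names, not code from fast_rnnt):

def transcript_too_long(num_bpe_tokens: int, num_feature_frames: int,
                        subsampling_factor: int = 4) -> bool:
    # "Too long": more BPE tokens than acoustic frames after subsampling.
    return num_bpe_tokens > num_feature_frames // subsampling_factor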

Butterfly-c commented 2 years ago
  1. What is your vocabulary size?
  2. What is your batch size? And how much data does each batch contain (i.e., what is the total duration)?
  3. Is your GPU usage over 90%? (You can get such information with watch -n 0.5 nvidia-smi.)
  4. What is the value of prune_range?

the GPU usage does not decrease significantly.

What do you want to express?

Some configurations of my environment are as follows:

  1. The vocabulary size is 8245, which contains 6726 Chinese characters, 1514 BPE subwords, and 5 special symbols.
  2. The batch size is 5000 frames (i.e. 50 s).
  3. When watch -n 0.5 nvidia-smi is run, the peak volatile GPU utilization is over 90%, but most of the time it is between 80% and 90%.
  4. prune_range is 4.

As shown in this paper https://arxiv.org/abs/2206.13236, the peak GPU usage of fast_rnnt is far below that of warp-transducer, and the training time is also greatly reduced. But with fast_rnnt in our environment, the training time is not reduced as expected. With the same batch size (50 s), the training-time statistics are as follows:

loss              times_per_update
warp-transducer   7m40s
fast-rnnt         6m40s

Finally, I have another question about the training time. As shown in the paper, the training time per batch of the optimized transducer is over 4 times that of fast_rnnt, but the training time per epoch of the optimized transducer is only about 2 times that of fast_rnnt.

I really appreciate your reply.

danpovey commented 2 years ago

I think the comparisons in the paper may have just been for the core RNN-T loss. It does not count the neural net forward, which would not be affected by speedups in the loss computation.

Butterfly-c commented 2 years ago

I think the comparisons in the paper may have just been for the core RNN-T loss. It does not count the neural net forward, which would not be affected by speedups in the loss computation.

Thanks for your reply, which cleared up my confusion.

Butterfly-c commented 2 years ago

@Butterfly-c Suppose you used pruned loss like this:

simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_smoothed(
      lm=decoder_out,
      am=encoder_out,
      symbols=symbols,
      termination_symbol=blank_id,
      lm_only_scale=lm_scale,
      am_only_scale=am_scale,
      boundary=boundary,
      reduction="sum",
      return_grad=True,
  )

  # ranges : [B, T, prune_range]
  ranges = fast_rnnt.get_rnnt_prune_ranges(
      px_grad=px_grad,
      py_grad=py_grad,
      boundary=boundary,
      s_range=prune_range,
  )

  # am_pruned : [B, T, prune_range, C]
  # lm_pruned : [B, T, prune_range, C]
  am_pruned, lm_pruned = fast_rnnt.do_rnnt_pruning(
      am=encoder_out, lm=decoder_out, ranges=ranges
  )

  # logits : [B, T, prune_range, C]
  logits = joiner(am_pruned, lm_pruned)

  pruned_loss = fast_rnnt.rnnt_loss_pruned(
      logits=logits,
      symbols=symbols,
      ranges=ranges,
      termination_symbol=blank_id,
      boundary=boundary,
      reduction="sum",
  )

You can dump the bad cases as follows:

if simple_loss - simple_loss != 0:
  simple_input = {"encoder_out" : encoder_out, "decoder_out" : decoder_out, "symbols" : symbols, "boundary": boundary}
  torch.save(simple_input, "simple_bad_case.pt")

if pruned_loss - pruned_loss != 0:
  pruned_input = {"logits" : logits, "ranges" : ranges, "symbols" : symbols, "boundary": boundary}
  torch.save(pruned_input, "pruned_bad_case.pt")

Based on your suggestion, I saved some bad cases. Interestingly, in most of them the 'ranges' tensors are all zeros.

For example, when the training sample is background music, the label is only one symbol. The lm.shape and am.shape are as follows: decoder_out [1, 2, 8245], encoder_out [1, 314, 8245]

Will the training loss become inf when the input and output lengths are very unbalanced (i.e. the output is far shorter than the input)? Can you give some explanation?

Butterfly-c commented 2 years ago

After filtering the training data as follows, the inf problem has decreased:
  1. label_len > 2
  2. feat_len // label_len > 30
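
A sketch of that filter (hypothetical field names; the thresholds are the ones listed above):

def keep_utterance(feat_len: int, label_len: int) -> bool:
    # Drop one-symbol (or near-empty) transcripts and utterances whose
    # transcripts are too dense relative to the number of feature frames.
    return label_len > 2 and feat_len // label_len > 30

# train_utts = [u for u in train_utts if keep_utterance(u["feat_len"], u["label_len"])]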

Butterfly-c commented 2 years ago

Due to network limitations, I will share pruned_bad_case.pt later.

pkufool commented 2 years ago

For example, when the training sample is background music, the label is only one symbol. The lm.shape and am.shape are as follows: decoder_out [1, 2, 8245], encoder_out [1, 314, 8245]

Does only one sequence have only one symbol, or do all the sequences in one batch have only one symbol? Thanks, this is very valuable information for us.

Butterfly-c commented 2 years ago

For example, when the training sample is background music, the label is only one symbol. The lm.shape and am.shape are as follows: decoder_out [1, 2, 8245], encoder_out [1, 314, 8245]

Does only one sequence have only one symbol, or do all the sequences in one batch have only one symbol? Thanks, this is very valuable information for us.

Based on 40 pruned_bad_case.pt files, all of the bad cases are ones where all the sequences in one batch have only one symbol, and the 'ranges' tensors are all zeros.

pkufool commented 2 years ago

OK, thanks! That's it. I think our code does not handle S == 1 properly; I will try to fix it.

pkufool commented 2 years ago

@Butterfly-c If you have problems uploading your bad cases to GitHub, can you send them to me via email at wkang.pku@gmail.com? I want to use them to test my fixes. Thanks!

Butterfly-c commented 2 years ago

@Butterfly-c If you have problems uploading your bad cases to GitHub, can you send them to me via email at wkang.pku@gmail.com? I want to use them to test my fixes. Thanks!

Due to data permissions, I can't share the bad-case information until I get approval. The approval is on the way.

pkufool commented 2 years ago

OK, I think there won't be any characters or waveforms in your bad cases, only float and integer numbers. I hope you can get the permissions; in the meantime I am testing with randomly generated bad cases. Thanks.

Butterfly-c commented 2 years ago

OK, I think there won't be any characters or waveforms in your bad cases, only float and integer numbers. I hope you can get the permissions; in the meantime I am testing with randomly generated bad cases. Thanks.

OK, I will contact you as soon as I get the permission.

Butterfly-c commented 2 years ago

After updating fast_rnnt to the "fix_s_range" version, the "inf" problem has been fixed. Thanks!