Closed: versae closed this issue 2 years ago.
It looks as though there are no valid training labels in the batch (all labels are equal to the padding mask idx and are overridden to -100 in the data collator). The fact that this occurs only for this batch and only on TPU suggests it's a JAX bug! I'll try to reproduce by saving the numpy array to disk and forcing it through a jit/pmap.
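For reference, a minimal sketch of that repro idea, assuming the suspect batch has been dumped to disk as numpy arrays. The file name, the loss function, and the sharding are placeholders, not the actual training code:

```python
import jax
import jax.numpy as jnp
import numpy as np

# Load the batch saved from inside the training loop, e.g. with
# np.savez("bad_batch.npz", logits=..., labels=...)  (hypothetical file name).
batch = np.load("bad_batch.npz")
logits = jnp.asarray(batch["logits"])  # (batch, seq_len, vocab)
labels = jnp.asarray(batch["labels"])  # (batch, seq_len), padding set to -100

def loss_fn(logits, labels):
    # Mirror the collator: positions labelled -100 are masked out of the loss.
    mask = labels != -100
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    safe_labels = jnp.where(mask, labels, 0)
    per_token = -jnp.take_along_axis(log_probs, safe_labels[..., None], axis=-1)[..., 0]
    # If every label is -100, mask.sum() is 0 and an unguarded mean would be nan.
    return jnp.sum(per_token * mask) / jnp.maximum(mask.sum(), 1)

# Compare eager vs compiled execution on the exact same arrays.
print("eager:", loss_fn(logits, labels))
print("jit:  ", jax.jit(loss_fn)(logits, labels))

# pmap needs a leading device axis; this assumes the batch size divides
# evenly across the local devices.
n = jax.local_device_count()
shard = lambda x: x.reshape((n, -1) + x.shape[1:])
print("pmap: ", jax.pmap(loss_fn)(shard(logits), shard(labels)))
```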
I see. Could it then be a tokenization issue? I might've used do_lower_case in this training run, as pointed out in https://github.com/sanchit-gandhi/seq2seq-speech/issues/23.
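If it helps, here is a quick way to check for that kind of casing mismatch. This is just a sketch: the checkpoint name is an example (not necessarily the one used here), and it simulates the mismatch by changing the input casing directly. If the casing disagrees with the vocab, most characters map to the unknown token and the labels effectively carry no signal:

```python
from transformers import Wav2Vec2CTCTokenizer

# Example checkpoint whose character vocab is uppercase.
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("facebook/wav2vec2-base-960h")
text = "the cat sat on the mat"

ids_lower = tokenizer(text.lower()).input_ids
ids_upper = tokenizer(text.upper()).input_ids

unk = tokenizer.unk_token_id
frac = lambda ids: sum(i == unk for i in ids) / len(ids)
print("unk fraction, lowercased input:", frac(ids_lower))  # high: letters not in vocab
print("unk fraction, uppercased input:", frac(ids_upper))  # 0.0: all characters in vocab
```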
For CTC, you can set max_labels_length=1024, which should bypass the error. The error is (likely) occurring because the target sequence is longer than max_labels_length and is thus being truncated.
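A rough sketch of what that preprocessing step amounts to (the column and parameter names here are assumptions, not the exact script arguments): rather than letting long targets get silently truncated, filter them out or raise the cap.

```python
from datasets import Dataset

max_labels_length = 1024  # the cap suggested above; raise it if your targets are longer

# Toy stand-in for a tokenized dataset; in practice the 'labels' column comes
# from the tokenizer during preprocessing.
dataset = Dataset.from_dict({"labels": [[3, 7, 9], list(range(2000)), []]})

def is_label_in_range(example):
    # Drop empty targets and targets longer than the cap rather than truncating
    # them, which is what (likely) leaves a batch with nothing but -100 labels.
    return 0 < len(example["labels"]) <= max_labels_length

filtered = dataset.filter(is_label_in_range)
print(f"{len(dataset)} examples -> {len(filtered)} after filtering")  # 3 -> 1
```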
Let me know if this doesn't work and we can dig into this further.
I'm hitting this error message now and then. It does not seem to be affecting training, but I only see it when training on TPU; the same dataset was used on GPU with no errors. Just posting here in case there is something else going on that I am missing.