sanchit-gandhi / seq2seq-speech

Repository for fine-tuning Transformers 🤗 based seq2seq speech models in JAX/Flax.

Random `TypeError` with NumPy array with all values -100 #75

versae closed this issue 2 years ago

versae commented 2 years ago

I'm hitting this error message now and then. It does not seem to affect training, but I only see it when training on TPU. The same dataset was used on GPU with no errors. Just posting here in case there is something else going on that I am missing.

Step... (75000/759160 | Eval Loss: 0.11205478757619858 | Eval wer: 0.09877239458498131 | Eval cer: 0.02933955305671511 |):  10%|████▏                                     | 4/40 [23:39:27<205:34:14, 20557.06s/it/
data/flax/lib/python3.8/site-packages/transformers/tokenization_utils_base.py:719: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.
  tensor = as_tensor(value)
--- Logging error ---                                                                                                                                                                                              
Traceback (most recent call last):                                                                                                                                                                                 
  File "run_flax_speech_recognition_ctc.py", line 1631, in <module>                                                                                                                                                
    main()                                                                                                                                                                                                         
  File "run_flax_speech_recognition_ctc.py", line 1544, in main                                                                                                                                                    
    state, train_metric = p_train_step(state, batch)                                                                                                                                                               
  File "/data/flax/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback                                                                                           
    return fun(*args, **kwargs)                                                                                                                                                                                    
  File "/data/flax/lib/python3.8/site-packages/jax/_src/api.py", line 2158, in cache_miss                                                                                                                          
    out_tree, out_flat = f_pmapped_(*args, **kwargs)                                                                                                                                                               
  File "/data/flax/lib/python3.8/site-packages/jax/_src/api.py", line 2031, in pmap_f                                                                                                                              
    p = _prepare_pmap(                                                                                                                                                                                             
  File "/data/flax/lib/python3.8/site-packages/jax/_src/api.py", line 1969, in _prepare_pmap                                                                                                                       
    _check_arg(arg)                                                                                                                                                                                                
  File "/data/flax/lib/python3.8/site-packages/jax/_src/api.py", line 2994, in _check_arg                                                                                                                          
    raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid JAX type.")                                                                                                                              
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Argument '[[-100 -100]                                                                                                                                    
 [-100 -100]                                                                                                                                                                                                       
 [-100 -100]                                                                                                                                                                                                       
 [-100 -100]                                                                                                                                                                                                       
 [-100 -100]                                                                                                                                                                                                       
 [-100 -100]                                                                                                                                                                                                       
 [-100 -100]                                                                                                                                                                                                       
 [-100 -100]]' of type <class 'numpy.ndarray'> is not a valid JAX type.                                                                                                                                            

The stack trace below excludes JAX-internal frames.                                                                                                                                                                
The preceding is the original exception that occurred, unmodified.                                                                                                                                                 

--------------------                                                                                                                                                                                

The above exception was the direct cause of the following exception:                                                                                                                                               

Traceback (most recent call last):                                                                                                                                                                                 
  File "run_flax_speech_recognition_ctc.py", line 1544, in main                                                                                                                                                    
    state, train_metric = p_train_step(state, batch) 
TypeError: Argument '[[-100 -100]
 [-100 -100]
 [-100 -100]
 [-100 -100]
 [-100 -100]
 [-100 -100]
 [-100 -100]
 [-100 -100]]' of type <class 'numpy.ndarray'> is not a valid JAX type.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/lib/python3.8/logging/__init__.py", line 1085, in emit
    msg = self.format(record)
  File "/usr/lib/python3.8/logging/__init__.py", line 929, in format
    return fmt.format(record)
  File "/usr/lib/python3.8/logging/__init__.py", line 668, in format
    record.message = record.getMessage()
  File "/usr/lib/python3.8/logging/__init__.py", line 373, in getMessage
    msg = msg % self.args
TypeError: not all arguments converted during string formatting
Call stack:
  File "run_flax_speech_recognition_ctc.py", line 1631, in <module>
    main()
  File "run_flax_speech_recognition_ctc.py", line 1546, in main
    logger.warning("Encountered following error: \n", e)
Message: 'Encountered following error: \n'
Arguments: (TypeError("Argument '[[-100 -100]\n [-100 -100]\n [-100 -100]\n [-100 -100]\n [-100 -100]\n [-100 -100]\n [-100 -100]\n [-100 -100]]' of type <class 'numpy.ndarray'> is not a valid JAX type."),)                                                  
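The second traceback above ("not all arguments converted during string formatting") comes from the `logger.warning` call itself: the exception is passed as an extra argument, but the message has no placeholder for it. A minimal sketch of a call that formats cleanly is shown below. The logger name, the `StringIO` handler, and the stand-in exception are assumptions for illustration and are not taken from the training script.

```python
import io
import logging

# Hypothetical logger wired to an in-memory stream so the output can be inspected.
stream = io.StringIO()
logger = logging.getLogger("logging_sketch")
logger.setLevel(logging.WARNING)
logger.propagate = False
logger.addHandler(logging.StreamHandler(stream))

try:
    # Stand-in for the real failure raised inside the training step.
    raise TypeError("not a valid JAX type")
except TypeError as e:
    # The %s placeholder gives logging's lazy "msg % args" step
    # somewhere to put the exception, so formatting succeeds.
    logger.warning("Encountered following error: \n%s", e)

logged = stream.getvalue()
print(logged)
```

This only fixes how the error is reported. It does not address the `TypeError` raised by `p_train_step`.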

sanchit-gandhi commented 2 years ago

It looks as though there are no valid training labels in the batch: all labels are equal to the padding index and are overridden to -100 in the data collator. The fact that this occurs only for this batch, and only on TPU, suggests it's a JAX bug. I'll try to reproduce it by saving the NumPy array to disk and forcing it through a `jit`/`pmap`.
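A rough sketch of the masking described above, and of a check that could be run on a batch before it reaches the train step. The helper names `mask_padding` and `describe_labels` are hypothetical, and the masking is a simplified stand-in for what the data collator does, not its actual code. The report also includes the dtype, since the deprecation warning in the log mentions ragged sequences. Whether that is related to the error is not established in this thread.

```python
import numpy as np

IGNORE_INDEX = -100  # the value that fills the labels in the traceback


def mask_padding(label_ids, attention_mask):
    """Replace padded label positions with IGNORE_INDEX (simplified collator-style masking)."""
    label_ids = np.asarray(label_ids)
    attention_mask = np.asarray(attention_mask)
    return np.where(attention_mask == 1, label_ids, IGNORE_INDEX)


def describe_labels(labels):
    """Summarise a label batch: dtype, shape, and whether every position is ignored."""
    labels = np.asarray(labels)
    return {
        "dtype": str(labels.dtype),
        "shape": labels.shape,
        "all_ignored": bool(np.all(labels == IGNORE_INDEX)),
    }


# A batch of 8 examples whose two label positions are all padding,
# mirroring the shape of the array in the error message.
label_ids = np.zeros((8, 2), dtype=np.int32)
attention_mask = np.zeros((8, 2), dtype=np.int32)
labels = mask_padding(label_ids, attention_mask)
report = describe_labels(labels)

# For contrast, a batch in which one position holds a real label.
partial_mask = attention_mask.copy()
partial_mask[0, 0] = 1
partial_report = describe_labels(mask_padding(label_ids + 7, partial_mask))
print(report, partial_report)
```

A batch flagged as `all_ignored` could be saved with `np.save` and replayed in isolation, in line with the reproduction plan above.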

versae commented 2 years ago

I see. Could it then be a tokenization issue? I might have used `do_lower_case` in this training run, as pointed out in https://github.com/sanchit-gandhi/seq2seq-speech/issues/23.

sanchit-gandhi commented 2 years ago

For CTC, you can set `max_labels_length=1024`, which should bypass the error. The error is (likely) occurring because the target sequence is longer than `max_labels_length` and is thus being truncated.

Let me know if this doesn't work and we can dig into this further.
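
To gauge whether the label length limit is a plausible cause, one could count how many tokenized targets exceed a given limit before training. The sketch below is illustrative only: `find_overlong` is a hypothetical helper, the example id lists are made up, and the 512 limit is an arbitrary value for comparison, not the script's default.

```python
def find_overlong(label_id_lists, max_labels_length):
    """Return the indices of examples whose tokenized labels exceed the limit."""
    return [
        i for i, ids in enumerate(label_id_lists)
        if len(ids) > max_labels_length
    ]


# Made-up tokenized targets with lengths 3, 600 and 1500.
examples = [[5, 6, 7], list(range(600)), list(range(1500))]

over_512 = find_overlong(examples, 512)    # arbitrary smaller limit for comparison
over_1024 = find_overlong(examples, 1024)  # the limit suggested above
print(over_512, over_1024)
```

If examples still exceed the limit at 1024, raising it alone may not be enough.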