TensorSpeech / TensorFlowASR

⚡ TensorFlowASR: Almost State-of-the-art Automatic Speech Recognition in TensorFlow 2. Supports languages that can use characters or subwords
https://huylenguyen.com/asr
Apache License 2.0

Error when training on multiple GPUs with train_ga_subword_conformer #85

Closed: ghost closed this issue 3 years ago

ghost commented 3 years ago

Environment: Google Cloud instance, Debian 10, TensorFlow 2.3, CUDA 11, 4x Tesla T4
Model: Conformer, trained with train_ga_subword_conformer.py
Config: batch size = 4, ga = 1, everything else as in the provided config; training on LibriTTS train-clean-100, evaluating on LibriTTS test-other
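The script name and the `ga` option suggest gradient accumulation (an assumption; the thread does not define `ga`). A framework-free sketch of what accumulation does: gradients from `ga` consecutive micro-batches are averaged and applied as a single update, so `ga = 1` means one update per batch, i.e. no accumulation. The function name and the plain-SGD update are hypothetical and are not TensorFlowASR's code.

```python
from typing import List


def accumulate_and_apply(param: float, micro_batch_grads: List[float],
                         ga_steps: int, lr: float) -> float:
    """Hypothetical sketch: average gradients over every `ga_steps`
    micro-batches, then take one plain SGD step with the mean gradient.

    Trailing micro-batches that do not fill a complete group are dropped.
    """
    accumulated = 0.0
    count = 0
    for grad in micro_batch_grads:
        accumulated += grad
        count += 1
        if count == ga_steps:
            # One optimizer update per ga_steps micro-batches.
            param -= lr * (accumulated / ga_steps)
            accumulated = 0.0
            count = 0
    return param
```

With `ga_steps=1` every micro-batch triggers its own update; with `ga_steps=2` the same two gradients produce one update using their mean, which behaves like a batch twice as large.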

Hello, I can train on a single GPU, but if I set the devices to more than one GPU, training fails with this traceback:

train_subword_conformer.py", line 135, in main
    conformer_trainer.fit(train_dataset, eval_dataset, train_bs=args.tbs, eval_bs=args.ebs)
  File "/home/jupyter/.local/lib/python3.7/site-packages/tensorflow_asr/runners/base_runners.py", line 312, in fit
    self.run()
  File "/home/jupyter/.local/lib/python3.7/site-packages/tensorflow_asr/runners/base_runners.py", line 192, in run
    self._train_epoch()
  File "/home/jupyter/.local/lib/python3.7/site-packages/tensorflow_asr/runners/base_runners.py", line 213, in _train_epoch
    raise e
  File "/home/jupyter/.local/lib/python3.7/site-packages/tensorflow_asr/runners/base_runners.py", line 207, in _train_epoch
    self._train_function(train_iterator)  # Run train step
  File "/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 780, in __call__
    result = self._call(*args, **kwds)
  File "/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 846, in _call
    return self._concrete_stateful_fn._filtered_call(canon_args, canon_kwds)  # pylint: disable=protected-access
  File "/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1848, in _filtered_call
    cancellation_manager=cancellation_manager)
  File "/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1924, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager))
  File "/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 550, in call
    ctx=ctx)
  File "/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/execute.py", line 60, in quick_execute
    inputs, attrs, num_outputs)
tensorflow.python.framework.errors_impl.InvalidArgumentError: 4 root error(s) found.
  (0) Invalid argument:  Trying to access element 30 in a list with 30 elements.
         [[{{node StatefulPartitionedCall/replica_3/conformer/conformer_prediction/conformer_prediction_lstm_0/while/body/_10413/replica_3/conformer/conformer_prediction/conformer_prediction_lstm_0/while/TensorArrayV2Read_1/TensorListGetItem}}]]
  (1) Invalid argument:  Trying to access element 30 in a list with 30 elements.
         [[{{node StatefulPartitionedCall/replica_3/conformer/conformer_prediction/conformer_prediction_lstm_0/while/body/_10413/replica_3/conformer/conformer_prediction/conformer_prediction_lstm_0/while/TensorArrayV2Read_1/TensorListGetItem}}]]
         [[GroupCrossDeviceControlEdges_0/StatefulPartitionedCall/_2777]]
  (2) Invalid argument:  Trying to access element 30 in a list with 30 elements.
         [[{{node StatefulPartitionedCall/replica_3/conformer/conformer_prediction/conformer_prediction_lstm_0/while/body/_10413/replica_3/conformer/conformer_prediction/conformer_prediction_lstm_0/while/TensorArrayV2Read_1/TensorListGetItem}}]]
         [[GroupCrossDeviceControlEdges_0/StatefulPartitionedCall/Adam/Adam/update_0/Const/_2761]]
  (3) Invalid argument:  Trying to access element 30 in a list with 30 elements.
         [[{{node StatefulPartitionedCall/replica_3/conformer/conformer_prediction/conformer_prediction_lstm_0/while/body/_10413/replica_3/conformer/conformer_prediction/conformer_prediction_lstm_0/while/TensorArrayV2Read_1/TensorListGetItem}}]]
         [[GroupCrossDeviceControlEdges_1/StatefulPartitionedCall/Adam/Adam/update_0/Const/_2749]]
0 successful operations.
0 derived errors ignored. [Op:__inference__train_function_272814]

Function call stack:
_train_function -> _train_function -> _train_function -> _train_function
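To compare the single-GPU run that works with the multi-GPU run that fails, one generic approach, independent of TensorFlowASR's own device option, is to control which GPUs the process can see through the standard `CUDA_VISIBLE_DEVICES` environment variable. It has to be set before TensorFlow initializes CUDA; `tf.distribute.MirroredStrategy()` created without a device list then uses all GPUs that remain visible. The helper name `restrict_visible_gpus` is hypothetical.

```python
import os
from typing import Sequence


def restrict_visible_gpus(device_ids: Sequence[int]) -> str:
    """Hypothetical helper: expose only the given GPU ids to CUDA.

    Must run before TensorFlow is imported or touches the GPU,
    otherwise the setting has no effect on the current process.
    """
    value = ",".join(str(i) for i in device_ids)
    os.environ["CUDA_VISIBLE_DEVICES"] = value
    return value
```

For example, calling `restrict_visible_gpus([0])` at the top of the training script reproduces the working single-GPU setup, while `restrict_visible_gpus([0, 1])` is a smaller multi-GPU case to bisect against.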

It looks like the LSTM is reading out of bounds. Can someone help? Thanks in advance!
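The error message itself describes an off-by-one read: a list of 30 elements has valid indices 0 to 29, so element 30 is one past the end. The toy loop below reproduces the same message when its iteration count exceeds the list length. It only illustrates the shape of the error; it is not the diagnosed cause in this thread, and the function names are hypothetical.

```python
from typing import List


def read_step(steps: List[float], index: int) -> float:
    """Bounds-checked read, mimicking the TensorListGetItem error message."""
    if not 0 <= index < len(steps):
        raise IndexError(
            f"Trying to access element {index} in a list with {len(steps)} elements."
        )
    return steps[index]


def unroll(steps: List[float], num_iterations: int) -> float:
    """Toy recurrent loop that sums its inputs for num_iterations steps.

    If num_iterations is larger than len(steps), for example because the
    loop bound was computed from a different length than the list's,
    the read one past the end raises IndexError.
    """
    state = 0.0
    for t in range(num_iterations):
        state += read_step(steps, t)
    return state
```

`unroll([1.0] * 30, 30)` completes, while `unroll([1.0] * 30, 31)` fails with the same wording as the replica errors in the traceback.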

ghost commented 3 years ago

OK, I found the problem: I was using a GRU instead of an LSTM in the transducer's prediction network. Apparently GRU is not compatible with the RNN-T loss on multiple GPUs :(
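Until the cause is understood, a config-time guard can turn the combination reported here into an immediate, readable error instead of a deep graph failure. This is a minimal sketch based on a single user report, not a documented limitation; the config key `prediction_rnn_type` and the function name are hypothetical and may not match TensorFlowASR's config schema.

```python
from typing import Any, Dict


def check_prediction_rnn(config: Dict[str, Any], num_gpus: int) -> None:
    """Hypothetical guard: reject a GRU prediction network on multiple GPUs.

    The config key name is an assumption made for illustration.
    """
    rnn_type = str(config.get("prediction_rnn_type", "lstm")).lower()
    if rnn_type == "gru" and num_gpus > 1:
        raise ValueError(
            "A GRU prediction network with RNN-T loss was reported to fail "
            "on multiple GPUs (issue #85); use 'lstm' or train on one GPU."
        )
```

Calling it right after the config is loaded, before the model is built, keeps the failure out of the traced training function.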

nglehuy commented 3 years ago

@DongyaoZhu I forgot to rename the RNN layer, so its name is always "lstm" even when it is a GRU, which makes this hard to debug 😆 GRU seems to be falling out of use; new papers almost always use LSTM since it tends to perform better.
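Naming each RNN layer after its actual type avoids the confusion described above: with such names, the failing node in the traceback would have pointed at a `gru` layer rather than `conformer_prediction_lstm_0`. A minimal sketch of a hypothetical helper; the prefix follows the node names in the traceback, but this is not TensorFlowASR's actual code.

```python
from typing import List

# Hypothetical set of RNN types a prediction network might support.
SUPPORTED_RNNS = ("lstm", "gru")


def prediction_rnn_names(prefix: str, rnn_type: str, num_layers: int) -> List[str]:
    """Build layer names that carry the real RNN type instead of a fixed 'lstm'."""
    rnn_type = rnn_type.lower()
    if rnn_type not in SUPPORTED_RNNS:
        raise ValueError(f"Unsupported RNN type: {rnn_type!r}")
    return [f"{prefix}_{rnn_type}_{i}" for i in range(num_layers)]
```

Each string can be passed as the `name=` argument when constructing `tf.keras.layers.LSTM` or `tf.keras.layers.GRU`, so graph node names and error messages reflect the layer that is actually running.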