iamjanvijay / rnnt

An implementation of RNN-Transducer loss in TF-2.0.

Why 'labels - 1' in compute_rnnt_loss_and_grad_helper? #6

Closed alsm168 closed 3 years ago

alsm168 commented 3 years ago

Hello, I have a question: why 'labels - 1' in compute_rnnt_loss_and_grad_helper?

b = tf.reshape(labels - 1, shape=(batch_size, 1, target_max_len - 1, 1))

iamjanvijay commented 3 years ago

Hey @alsm168!

In line 150, I compute the grads of rnnt_loss w.r.t. the truth log-probabilities, so grad_truth is of shape (B, T, U). In line 164, I create scatter_idx to scatter grad_truth onto a grid of shape (B, T, U, V), where V is the number of valid output symbols excluding the blank symbol. In line 167, I scatter grad_truth to create grads_truth_scatter. In line 169, I concatenate the grads w.r.t. blank with grads_truth_scatter to create a (B, T, U, V+1) tensor. Note that index 0 is reserved for the blank symbol.

Since the values in labels lie in the range (1, 2, 3, ..., V), it makes sense to have b = tf.reshape(labels - 1, shape=(batch_size, 1, target_max_len - 1, 1)), which is used in creating scatter_idx.

B -> Batch-size / T -> Input time steps / U -> Output time steps
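For reference, here is a minimal, self-contained sketch of the scatter step described above (not the repository's exact code; all shapes and tensor names are illustrative), showing why the label values are shifted by -1 before scattering:

```python
import tensorflow as tf

batch_size, input_max_len, target_max_len, vocab_size = 2, 3, 4, 5  # B, T, U, V

# Labels take values in 1..V (0 is reserved for the blank symbol).
labels = tf.random.uniform(
    (batch_size, target_max_len - 1), minval=1, maxval=vocab_size + 1, dtype=tf.int32)
grad_truth = tf.random.normal((batch_size, input_max_len, target_max_len - 1))
grad_blank = tf.random.normal((batch_size, input_max_len, target_max_len))

# Build scatter indices (b, t, u, labels - 1): the shift by -1 is needed because
# the non-blank symbols occupy slots 0..V-1 of the V-sized last dimension.
b = tf.reshape(tf.range(batch_size), (batch_size, 1, 1, 1)) * tf.ones(
    (1, input_max_len, target_max_len - 1, 1), dtype=tf.int32)
t = tf.reshape(tf.range(input_max_len), (1, input_max_len, 1, 1)) * tf.ones(
    (batch_size, 1, target_max_len - 1, 1), dtype=tf.int32)
u = tf.reshape(tf.range(target_max_len - 1), (1, 1, target_max_len - 1, 1)) * tf.ones(
    (batch_size, input_max_len, 1, 1), dtype=tf.int32)
v = tf.reshape(labels - 1, (batch_size, 1, target_max_len - 1, 1)) * tf.ones(
    (1, input_max_len, 1, 1), dtype=tf.int32)
scatter_idx = tf.concat([b, t, u, v], axis=3)  # (B, T, U-1, 4)

# Scatter non-blank grads into a (B, T, U-1, V) grid, pad the U axis,
# then place the blank grads at index 0 of the symbol axis -> (B, T, U, V+1).
grads_truth_scatter = tf.scatter_nd(
    scatter_idx, grad_truth,
    (batch_size, input_max_len, target_max_len - 1, vocab_size))
grads_truth_scatter = tf.pad(grads_truth_scatter, [[0, 0], [0, 0], [0, 1], [0, 0]])
grads = tf.concat([tf.expand_dims(grad_blank, axis=3), grads_truth_scatter], axis=3)
print(grads.shape)  # (2, 3, 4, 6), i.e. (B, T, U, V+1)
```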

alsm168 commented 3 years ago

Hi iamjanvijay, thanks for your reply. When training the model, the label sequences in a batch need to be padded to the same length to form batch data. If they are padded with the blank index 0, 'labels - 1' will produce a -1 index. Is this a problem? And when training, how should the TransducerPrediction input and labels be prepared? Thank you!

iamjanvijay commented 3 years ago

You can go ahead with padding with zero. Syntactically, it won't be a problem, because an index of -1 will mean the last element in the concerned dimension. Logically, there won't be a problem either, because I've used zero_masks everywhere based on the length of each sequence in a batch.
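For example, a minimal sketch of zero-padding variable-length label sequences into a batch (names and values here are assumptions for illustration, not taken from this repository):

```python
import tensorflow as tf

label_seqs = [[3, 1, 4], [2, 5]]                            # values in 1..V; 0 = blank
label_lengths = tf.constant([len(s) for s in label_seqs])   # true lengths, used for masking
labels = tf.keras.preprocessing.sequence.pad_sequences(
    label_seqs, padding='post', value=0)                    # pad with the blank index 0
labels = tf.constant(labels, dtype=tf.int32)                # shape (B, U_max)
print(labels.numpy())  # [[3 1 4]
                       #  [2 5 0]]
```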

I'll add a dummy model training code for reference. You can have a look at this issue until then: https://github.com/iamjanvijay/rnnt/issues/3

alsm168 commented 3 years ago

Hi iamjanvijay, I tried padding with zero, and when training I get the error below. Is there a bug?

File "D:\LUOSIMING\2019_work_mycode\Speech2Word\frame2word\runners\base_runners.py", line 322, in fit
    self.run()
File "D:\LUOSIMING\2019_work_mycode\Speech2Word\frame2word\runners\base_runners.py", line 202, in run
    self._train_epoch()
File "D:\LUOSIMING\2019_work_mycode\Speech2Word\frame2word\runners\base_runners.py", line 223, in _train_epoch
    raise e
File "D:\LUOSIMING\2019_work_mycode\Speech2Word\frame2word\runners\base_runners.py", line 217, in _train_epoch
    self._train_function(train_iterator)  # Run train step
File "C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\eager\def_function.py", line 780, in __call__
    result = self._call(*args, **kwds)
File "C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\eager\def_function.py", line 814, in _call
    results = self._stateful_fn(*args, **kwds)
File "C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\eager\function.py", line 2829, in __call__
    return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
File "C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\eager\function.py", line 1843, in _filtered_call
    return self._call_flat(
File "C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\eager\function.py", line 1923, in _call_flat
    return self._build_call_outputs(self._inference_function.call(
File "C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\eager\function.py", line 545, in call
    outputs = execute.execute(
File "C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\eager\execute.py", line 59, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.InvalidArgumentError: indices[1,0,4] = [1, 0, 4, -1] does not index into shape [2,256,6,368]
    [[{{node StatefulPartitionedCall/rnnt_loss/ScatterNd_1}}]] [Op:__inference_train_function_12699]
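For context, the failing node is the scatter step discussed above. A minimal, hedged reproduction (the grid shape is read off the error message; everything else is illustrative):

```python
import tensorflow as tf

# Shapes from the error message: B=2, T=256, U-1=6, V=368.
# A label position padded with 0 becomes -1 after `labels - 1`, and
# tf.scatter_nd treats negative indices as out of bound. Per the
# tf.scatter_nd documentation this raises an error on CPU, while on GPU
# out-of-bound indices are silently ignored, so behaviour can differ
# between machines.
indices = tf.constant([[0, 0, 0, -1]])   # the -1 comes from a padded blank (0) label
updates = tf.constant([1.0])

try:
    tf.scatter_nd(indices, updates, shape=(2, 256, 6, 368))
except tf.errors.InvalidArgumentError as e:
    print(e.message)  # "... does not index into shape [2,256,6,368]"
```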

iamjanvijay commented 3 years ago

I tried passing labels with all values as -1 and it seems to be working fine for me, without throwing any error. Can you please have a look at the sample training loop which I have added: https://github.com/iamjanvijay/rnnt/blob/master/source/sample_train.py

alsm168 commented 3 years ago

OK, I will look at sample_train.py soon. Also, can the blank index be set to vocabulary_size + 1?

iamjanvijay commented 3 years ago

No, 0 is reserved for the blank index. I will have to update the code to support that feature.
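As a hedged illustration of that convention (names and values are assumptions, not from this repository):

```python
import tensorflow as tf

# 0 is the blank index, so non-blank symbols must occupy ids 1..V.
# If a tokenizer produces ids 0..V-1, shift them by +1 (before zero-padding).
tokenizer_ids = tf.constant([[5, 0, 12], [7, 3, 9]])  # raw ids in 0..V-1
labels = tokenizer_ids + 1                            # now in 1..V; 0 stays reserved for blank
```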