1ytic / warp-rnnt

CUDA-Warp RNN-Transducer
MIT License

operating with apex? #3

Open tongjinle123 opened 4 years ago

tongjinle123 commented 4 years ago

I am trying to use this implementation with apex half-precision training, but it fails with an error saying that it needs float rather than half:


File "/data/asr_v3/src/model/transformer_transducer/lightning_model.py", line 41, in training_step
    joint_out, rnnt_loss = self.forward(feature, feature_length, target, target_length, cal_rnnt_loss=True)
File "/opt/conda/lib/python3.7/site-packages/apex/amp/_initialize.py", line 197, in new_fwd
    **applier(kwargs, input_caster))
File "/data/asr_v3/src/model/transformer_transducer/lightning_model.py", line 36, in forward
    joint_out, rnnt_loss = self.transducer.forward(feature, feature_length, target, target_length, cal_rnnt_loss)
File "/data/asr_v3/src/model/transformer_transducer/transformer_transducer.py", line 79, in forward
    rnn_t_loss = self.cal_transducer_loss(joint, ori_token, feature_length, ori_token_length)
File "/data/asr_v3/src/model/transformer_transducer/transformer_transducer.py", line 108, in cal_transducer_loss
    log_probs=log_prob, labels=target.int(), frames_lengths=frame_length.int(), labels_lengths=target_length.int(), reduction='mean')
File "/opt/conda/lib/python3.7/site-packages/warp_rnnt/init.py", line 80, in rnnt_loss
    costs = RNNTLoss.apply(log_probs, labels, frames_lengths, labels_lengths, blank)
File "/opt/conda/lib/python3.7/site-packages/warp_rnnt/init.py", line 16, in forward
    blank=blank,
RuntimeError: xs must be a Float tensor (rnnt_loss at binding.cpp:42)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x47 (0x7fa72c18c687 in /opt/conda/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: rnnt_loss(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, int) + 0xf79 (0x7fa707c87389 in /opt/conda/lib/python3.7/site-packages/warp_rnnt/_C.cpython-37m-x86_64-linux-gnu.so)
frame #2: + 0x22ea7 (0x7fa707c9aea7 in /opt/conda/lib/python3.7/site-packages/warp_rnnt/_C.cpython-37m-x86_64-linux-gnu.so)
frame #3: + 0x232ee (0x7fa707c9b2ee in /opt/conda/lib/python3.7/site-packages/warp_rnnt/_C.cpython-37m-x86_64-linux-gnu.so)
frame #4: + 0x1fd11 (0x7fa707c97d11 in /opt/conda/lib/python3.7/site-packages/warp_rnnt/_C.cpython-37m-x86_64-linux-gnu.so)

frame #10: THPFunction_apply(_object*, _object*) + 0x8d6 (0x7fa7601b9e96 in /opt/conda/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #63: __libc_start_main + 0xf0 (0x7fa76fc35830 in /lib/x86_64-linux-gnu/libc.so.6)

1ytic commented 4 years ago

Yes, the loss function is currently implemented only for float (float32) tensors. I still have to generalize the implementation to other types. For now, you can cast the logits to float explicitly before passing them to the loss.
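
A minimal sketch of that workaround, assuming the inputs are PyTorch tensors and that the loss accepts the keyword arguments seen in the traceback above (`log_probs`, `labels`, `frames_lengths`, `labels_lengths`, `reduction`). The wrapper name `rnnt_loss_fp32` is hypothetical and not part of warp-rnnt; it only upcasts the log-probabilities before forwarding the call.

```python
def rnnt_loss_fp32(loss_fn, log_probs, labels, frames_lengths, labels_lengths, **kwargs):
    """Hypothetical helper: call a float32-only RNN-T loss from half-precision code.

    `loss_fn` is the loss to call (for example `warp_rnnt.rnnt_loss`).
    The tensors are expected to expose `.float()` and `.int()`, as
    `torch.Tensor` does.
    """
    return loss_fn(
        # Tensor.float() casts half precision to float32; for a PyTorch tensor
        # the cast is differentiable, so gradients still reach the model.
        log_probs=log_probs.float(),
        # The integer casts mirror the call shown in the traceback.
        labels=labels.int(),
        frames_lengths=frames_lengths.int(),
        labels_lengths=labels_lengths.int(),
        **kwargs,
    )
```

With warp-rnnt installed, the call from the traceback would then become something like `rnnt_loss_fp32(warp_rnnt.rnnt_loss, log_prob, target, frame_length, target_length, reduction='mean')`. Only the loss input is upcast; the rest of the model can stay in half precision.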