k2-fsa / k2

FSA/FST algorithms, differentiable, with PyTorch compatibility.
https://k2-fsa.github.io/k2
Apache License 2.0

Torchscript export error for torch.finfo function #1284

Open vinitunni opened 2 months ago

vinitunni commented 2 months ago

Hello, I am using k2.__dev_version__ == '1.24.4.dev20240223+cuda11.7.torch2.0.0' for its pruned transducer loss implementation.

I ran into a JIT error while trying to create a TorchScript model from it. The error is triggered by a call to torch.finfo. The trace is below:

2024-04-30 18:29:34,604 | ERROR:Error when creating torchscript model
Traceback (most recent call last):
  File "/mount/user/exp1/projectA2_k2/projectA/projectA_common/projectA/utils/file_utils.py", line 349, in export_to_torchscript
    scripted_model = torch.jit.script(model)
  File "/data/user/miniconda3/envs/XYZ4_feature/lib/python3.10/site-packages/torch/jit/_script.py", line 1284, in script
    return torch.jit._recursive.create_script_module(
  File "/data/user/miniconda3/envs/XYZ4_feature/lib/python3.10/site-packages/torch/jit/_recursive.py", line 480, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "/data/user/miniconda3/envs/XYZ4_feature/lib/python3.10/site-packages/torch/jit/_recursive.py", line 546, in create_script_module_impl
    create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
  File "/data/user/miniconda3/envs/XYZ4_feature/lib/python3.10/site-packages/torch/jit/_recursive.py", line 397, in create_methods_and_properties_from_stubs
    concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
  File "/data/user/miniconda3/envs/XYZ4_feature/lib/python3.10/site-packages/torch/jit/_recursive.py", line 867, in try_compile_fn
    return torch.jit.script(fn, _rcb=rcb)
  File "/data/user/miniconda3/envs/XYZ4_feature/lib/python3.10/site-packages/torch/jit/_script.py", line 1341, in script
    fn = torch._C._jit_script_compile(
  File "/data/user/miniconda3/envs/XYZ4_feature/lib/python3.10/site-packages/torch/jit/_recursive.py", line 867, in try_compile_fn
    return torch.jit.script(fn, _rcb=rcb)
  File "/data/user/miniconda3/envs/XYZ4_feature/lib/python3.10/site-packages/torch/jit/_script.py", line 1341, in script
    fn = torch._C._jit_script_compile(
RuntimeError:
Unknown builtin op: aten::finfo.
Here are some suggestions:
    aten::find

The original call is:
  File "/data/user/miniconda3/envs/XYZ4_feature/lib/python3.10/site-packages/k2/rnnt_loss.py", line 1604
    normalizers = (
        torch.matmul(lm_probs, am_probs.transpose(1, 2))
        + torch.finfo(lm_probs.dtype).tiny
          ~~~~~~~~~~~ <--- HERE
    ).log()

Is there a suggested fix for this? I would really appreciate some help or guidance.
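A minimal sketch of one local workaround, assuming you are able to patch the call site yourself. Per the error above, TorchScript rejects `torch.finfo`, but it can compare dtypes, so the `tiny` constant (the smallest positive normal number) can be picked from hard-coded literals instead. The names `tiny_for_dtype` and `smoothed_normalizers` are hypothetical and not part of k2's API, and this is not an official k2 fix.

```python
import torch


@torch.jit.script
def tiny_for_dtype(dtype: torch.dtype) -> float:
    # Hypothetical helper: a TorchScript-compatible stand-in for
    # torch.finfo(dtype).tiny, using hard-coded IEEE constants.
    if dtype == torch.float64:
        return 2.2250738585072014e-308  # 2**-1022
    if dtype == torch.float16:
        return 6.103515625e-05  # 2**-14
    # float32 and bfloat16 share the same exponent range, so the same value.
    return 1.1754943508222875e-38  # 2**-126


@torch.jit.script
def smoothed_normalizers(lm_probs: torch.Tensor, am_probs: torch.Tensor) -> torch.Tensor:
    # Hypothetical rewrite of the expression flagged in the traceback, with
    # torch.finfo(lm_probs.dtype).tiny replaced by the helper above.
    # lm_probs: (B, S+1, C), am_probs: (B, T, C) -> result: (B, S+1, T)
    return (
        torch.matmul(lm_probs, am_probs.transpose(1, 2))
        + tiny_for_dtype(lm_probs.dtype)
    ).log()
```

The design choice here is to avoid `torch.finfo` entirely rather than hide it from the compiler, so the expression stays compilable. Editing an installed package under `site-packages` is fragile, though: any reinstall or upgrade of k2 would overwrite the patch.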

csukuangfj commented 2 months ago

Could you use our export.py instead of using your own?
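The general idea behind scripting with a dedicated export script can be sketched as follows, assuming the failing `torch.finfo` call is reachable only from the training-time loss computation in the model's `forward()`. The sketch excludes that `forward()` from compilation and scripts only the inference methods. `ToyTransducer` and `encode()` are hypothetical stand-ins, and this is a generic TorchScript pattern, not necessarily what the recipe's export.py does.

```python
import torch
import torch.nn as nn


class ToyTransducer(nn.Module):
    """Hypothetical model: forward() computes a training loss, while
    encode() is what inference needs."""

    def __init__(self) -> None:
        super().__init__()
        self.encoder = nn.Linear(4, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Training-only path. It calls torch.finfo, which the traceback
        # above shows TorchScript rejecting.
        h = self.encoder(x)
        return (h.exp().sum() + torch.finfo(h.dtype).tiny).log()

    @torch.jit.export
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        # Inference path: plain tensor ops that TorchScript can compile.
        return self.encoder(x)


model = ToyTransducer().eval()
# Mark forward() as ignored so torch.jit.script does not compile it.
# Note: this patches the class, so it affects every ToyTransducer instance.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
scripted = torch.jit.script(model)
```

Only `encode()` (and the submodules it uses) is compiled here. If the scripted module also needs to be saved, check the TorchScript documentation on `torch.jit.ignore` versus `torch.jit.unused`, since they differ in what can be exported.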

vinitunni commented 2 months ago

Sure, I will try that out and report back here.