Hello, I am using k2.__dev_version__ == '1.24.4.dev20240223+cuda11.7.torch2.0.0' for its pruned transducer loss implementation.
I ran into a JIT error when exporting my model to TorchScript, caused by the torch.finfo function. The traceback is below:
2024-04-30 18:29:34,604 | ERROR:Error when creating torchscript model
Traceback (most recent call last):
  File "/mount/user/exp1/projectA2_k2/projectA/projectA_common/projectA/utils/file_utils.py", line 349, in export_to_torchscript
    scripted_model = torch.jit.script(model)
  File "/data/user/miniconda3/envs/XYZ4_feature/lib/python3.10/site-packages/torch/jit/_script.py", line 1284, in script
    return torch.jit._recursive.create_script_module(
  File "/data/user/miniconda3/envs/XYZ4_feature/lib/python3.10/site-packages/torch/jit/_recursive.py", line 480, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "/data/user/miniconda3/envs/XYZ4_feature/lib/python3.10/site-packages/torch/jit/_recursive.py", line 546, in create_script_module_impl
    create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
  File "/data/user/miniconda3/envs/XYZ4_feature/lib/python3.10/site-packages/torch/jit/_recursive.py", line 397, in create_methods_and_properties_from_stubs
    concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
  File "/data/user/miniconda3/envs/XYZ4_feature/lib/python3.10/site-packages/torch/jit/_recursive.py", line 867, in try_compile_fn
    return torch.jit.script(fn, _rcb=rcb)
  File "/data/user/miniconda3/envs/XYZ4_feature/lib/python3.10/site-packages/torch/jit/_script.py", line 1341, in script
    fn = torch._C._jit_script_compile(
  File "/data/user/miniconda3/envs/XYZ4_feature/lib/python3.10/site-packages/torch/jit/_recursive.py", line 867, in try_compile_fn
    return torch.jit.script(fn, _rcb=rcb)
  File "/data/user/miniconda3/envs/XYZ4_feature/lib/python3.10/site-packages/torch/jit/_script.py", line 1341, in script
    fn = torch._C._jit_script_compile(
RuntimeError:
Unknown builtin op: aten::finfo.
Here are some suggestions:
    aten::find

The original call is:
  File "/data/user/miniconda3/envs/XYZ4_feature/lib/python3.10/site-packages/k2/rnnt_loss.py", line 1604
    normalizers = (
        torch.matmul(lm_probs, am_probs.transpose(1, 2))
        + torch.finfo(lm_probs.dtype).tiny
          ~~~~~~~~~~~ <--- HERE
    ).log()
Is there a suggested fix for this? I would really appreciate any help or guidance.
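One possible workaround, sketched here under assumptions rather than as the k2 project's recommended fix, is to keep the training loss out of the compiled graph, since an exported inference model usually does not need it. The sketch marks the training-only method with torch.jit.unused. When torch.jit.script compiles the module, it replaces that method with a stub that raises instead of compiling its body. (torch.jit.ignore would also skip compilation, but the PyTorch docs note that models with ignored functions cannot be exported.) TinyTransducerModel and compute_loss are hypothetical names. compute_loss only mimics the k2 line from the traceback; in a real model it would be the method that calls k2's pruned RNN-T loss. The sketch also assumes the loss is reached only from a branch that inference never takes.

```python
import torch
from torch import nn


class TinyTransducerModel(nn.Module):
    """Hypothetical stand-in for a model whose training path calls a
    TorchScript-incompatible loss (here: code that uses torch.finfo)."""

    def __init__(self, dim: int = 4) -> None:
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    @torch.jit.unused
    def compute_loss(self, lm_probs: torch.Tensor, am_probs: torch.Tensor) -> torch.Tensor:
        # Mimics the failing k2 expression. Because of @torch.jit.unused,
        # torch.jit.script never compiles this body; it inserts a raising stub.
        normalizers = (
            torch.matmul(lm_probs, am_probs.transpose(1, 2))
            + torch.finfo(lm_probs.dtype).tiny
        ).log()
        return normalizers.sum()

    def forward(self, lm_probs: torch.Tensor, am_probs: torch.Tensor) -> torch.Tensor:
        if self.training:
            # Training-only path: runs normally in eager mode,
            # raises if reached in the scripted model.
            return self.compute_loss(lm_probs, am_probs)
        # Inference path: this is what the exported model actually runs.
        return self.proj(am_probs)
```

After export, call eval() on the scripted model before running it. In training mode the scripted model raises, because the loss method was stubbed out. If the loss itself must live inside the scripted graph, the torch.finfo call in k2's rnnt_loss.py would need a scriptable replacement (for example, a precomputed constant per dtype), and it is not certain that this is the only incompatible line in that file.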