Closed justinchuby closed 2 days ago
Need to figure out how to call torch.jit.trace correctly. The current call fails with:
Traceback (most recent call last):
  File "/home/justinchu/dev/torch-onnx/src/torch_onnx/_core.py", line 868, in export
    jit_model = torch.jit.trace(
                ^^^^^^^^^^^^^^^^
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 601, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/jit/_trace.py", line 1005, in trace
    traced_func = _trace_impl(
                  ^^^^^^^^^^^^
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/jit/_trace.py", line 700, in _trace_impl
    return trace_module(
           ^^^^^^^^^^^^^
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/jit/_trace.py", line 1280, in trace_module
    module._c._create_method_from_trace(
RuntimeError: Tracer cannot infer type of CausalLMOutputWithPast(loss=None, logits=tensor([[[-1.1278e+01, -7.2935e+00, 2.1226e+00, ..., -9.0684e+00,
:Dictionary inputs to traced functions must have consistent type. Found Tensor and Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor]]
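The error suggests the tracer is choking on the model's dict-like return value, whose fields mix a Tensor with nested tuples of Tensors. One possible workaround, sketched here under assumptions and not taken from torch-onnx, is to wrap the model so that forward() returns a flat tuple of tensors. `flatten_outputs`, `make_traceable`, and `FlatOutputWrapper` are hypothetical names, and the sketch assumes the output object behaves like a mapping whose leaves are tensors or None.

```python
from collections.abc import Mapping


def flatten_outputs(output):
    """Flatten nested mapping/tuple/list outputs into a flat tuple of leaves.

    None entries (such as loss=None) are dropped, since the tracer cannot
    represent them as tensors.
    """
    if output is None:
        return ()
    if isinstance(output, Mapping):
        # Keep only the values, in the mapping's own order.
        output = tuple(output.values())
    if isinstance(output, (tuple, list)):
        flat = []
        for item in output:
            flat.extend(flatten_outputs(item))
        return tuple(flat)
    # Anything else is treated as a leaf (assumed to be a tensor).
    return (output,)


def make_traceable(model):
    """Wrap `model` (hypothetical helper) so forward() returns a flat tuple."""
    import torch  # imported lazily so flatten_outputs works without torch

    class FlatOutputWrapper(torch.nn.Module):
        def __init__(self, inner):
            super().__init__()
            self.inner = inner

        def forward(self, *args):
            # Positional tensor inputs only; keyword inputs are not handled.
            return flatten_outputs(self.inner(*args))

    return FlatOutputWrapper(model)
```

With this in place the call would look something like `torch.jit.trace(make_traceable(model), example_args)`, where `example_args` is a tuple of example tensors. The original output structure is lost, so the caller has to remember which position holds which field if that matters downstream.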
Fall back to JIT when torch.export fails.
TODO: Improve success rate.
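For reference, the fallback order described above could be sketched roughly as follows. This is only an illustration, not the code at torch_onnx/_core.py: `capture_with_fallback`, `capture_with_torch_export`, `capture_with_jit_trace`, and `DEFAULT_STRATEGIES` are hypothetical names, and it assumes the model takes positional tensor arguments only.

```python
def capture_with_fallback(model, args, strategies):
    """Try each (name, capture_fn) pair in order and return the first success.

    Returns a (name, result) pair so the caller knows which strategy worked.
    Raises RuntimeError listing every failure if none succeeds.
    """
    errors = []
    for name, capture in strategies:
        try:
            return name, capture(model, args)
        except Exception as exc:  # broad on purpose: capture errors vary widely
            errors.append((name, exc))
    summary = "; ".join(f"{name}: {exc!r}" for name, exc in errors)
    raise RuntimeError(f"all capture strategies failed: {summary}")


def capture_with_torch_export(model, args):
    import torch  # lazy import keeps the helper above usable without torch

    return torch.export.export(model, args)


def capture_with_jit_trace(model, args):
    import torch

    return torch.jit.trace(model, example_inputs=args)


# Order matters: torch.export first, JIT tracing only as the fallback.
DEFAULT_STRATEGIES = (
    ("torch.export", capture_with_torch_export),
    ("jit", capture_with_jit_trace),
)
```

Something like `name, program = capture_with_fallback(model, example_args, DEFAULT_STRATEGIES)` would then report which path succeeded. Keeping the errors from every attempt makes it easier to see why torch.export failed even when the JIT path works.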