justinchuby / torch-onnx

Prototype of the next torch exporter
MIT License

JIT #73

justinchuby closed this 2 days ago

justinchuby commented 1 week ago

Fall back to JIT tracing (`torch.jit.trace`) when `torch.export` fails.

TODO: Improve success rate.
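A minimal sketch of the fallback, assuming the exporter can be expressed as an ordered list of strategies that are tried until one succeeds. `export_with_fallback` and `Strategy` are hypothetical names, not the torch-onnx API. The `torch.export.export` / `torch.jit.trace` pairing in the docstring is only an example of how the strategies might be filled in.

```python
from typing import Any, Callable, List, Sequence, Tuple

# Hypothetical: a strategy is a (name, callable) pair. The callable takes
# (model, args) and returns an exported artifact, or raises on failure.
Strategy = Tuple[str, Callable[[Any, tuple], Any]]


def export_with_fallback(model: Any, args: tuple, strategies: Sequence[Strategy]):
    """Try each strategy in order and return (name, result, errors).

    With torch installed, strategies might look like:
        [("torch.export", lambda m, a: torch.export.export(m, a)),
         ("jit.trace", lambda m, a: torch.jit.trace(m, a))]
    """
    errors: List[Tuple[str, Exception]] = []
    for name, export_fn in strategies:
        try:
            return name, export_fn(model, args), errors
        except Exception as exc:
            # Record why this strategy failed, then fall through to the next one.
            errors.append((name, exc))
    # Every strategy failed: report all errors together for debugging.
    summary = "; ".join(f"{name}: {type(exc).__name__}: {exc}" for name, exc in errors)
    raise RuntimeError(f"All export strategies failed: {summary}")
```

Returning the recorded errors even when a later strategy succeeds lets the caller log why `torch.export` failed, which is the data needed to work on the success-rate TODO.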

justinchuby commented 3 days ago

Need to figure out how to call `torch.jit.trace` correctly. Tracing a model that returns `CausalLMOutputWithPast` fails with:

```
Traceback (most recent call last):
  File "/home/justinchu/dev/torch-onnx/src/torch_onnx/_core.py", line 868, in export
    jit_model = torch.jit.trace(
                ^^^^^^^^^^^^^^^^
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 601, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/jit/_trace.py", line 1005, in trace
    traced_func = _trace_impl(
                  ^^^^^^^^^^^^
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/jit/_trace.py", line 700, in _trace_impl
    return trace_module(
           ^^^^^^^^^^^^^
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/jit/_trace.py", line 1280, in trace_module
    module._c._create_method_from_trace(
RuntimeError: Tracer cannot infer type of CausalLMOutputWithPast(loss=None, logits=tensor([[[-1.1278e+01, -7.2935e+00,  2.1226e+00,  ..., -9.0684e+00,
:Dictionary inputs to traced functions must have consistent type. Found Tensor and Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor]]
```
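The error says the traced output is a dictionary whose values mix a `Tensor` (`logits`) with nested tuples of tensors (`past_key_values`), and the tracer requires dictionary values to share one type. One possible workaround, sketched here under the assumption that callers can consume the outputs positionally, is to wrap the model so it returns a tuple, since nested tuples of tensors trace without issue. `TupleOutputWrapper` and `DictOutputModel` are hypothetical names; `to_tuple()` is the method Hugging Face `ModelOutput` objects provide. Many Hugging Face models also accept `return_dict=False`, which avoids the dict-like output at the source.

```python
import torch


class TupleOutputWrapper(torch.nn.Module):
    """Hypothetical wrapper that flattens dict-like outputs into tuples."""

    def __init__(self, model: torch.nn.Module) -> None:
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor):
        out = self.model(x)
        # Hugging Face ModelOutput objects provide to_tuple(), which returns
        # only the fields that are not None, in field order.
        if hasattr(out, "to_tuple"):
            return out.to_tuple()
        # Plain dicts: keep insertion order and drop None values, which the
        # tracer cannot represent as tensor outputs.
        if isinstance(out, dict):
            return tuple(v for v in out.values() if v is not None)
        return out


class DictOutputModel(torch.nn.Module):
    """Hypothetical stand-in for a causal LM that returns a dict-like output."""

    def __init__(self) -> None:
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)

    def forward(self, x: torch.Tensor):
        h = self.proj(x)
        # Mixed value types, like logits plus past_key_values in the traceback.
        return {"loss": None, "logits": h, "past_key_values": ((h * 2, h * 3),)}


if __name__ == "__main__":
    wrapped = TupleOutputWrapper(DictOutputModel()).eval()
    example = torch.randn(2, 4)
    traced = torch.jit.trace(wrapped, (example,))
    logits, past_key_values = traced(example)
    print(logits.shape, len(past_key_values))
```

The trade-off is that output names are lost: downstream code (or the ONNX output names) must rely on position, so the wrapper should be applied only on the tracing fallback path.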
justinchuby commented 3 days ago

Ref https://github.com/justinchuby/pytorch/blob/main/torch/onnx/utils.py#L954