justinchuby / torch-onnx

Prototype of the next torch exporter
MIT License

test_tensor_index_advanced_indexing_consecutive #60

Closed: justinchuby closed this issue 1 week ago

justinchuby commented 1 week ago

Obtain model graph for MyModel() with torch.export.export... ✅

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, _lifted_tensor_constant0: "i64[2]", _lifted_tensor_constant1: "i64[2, 2]", arg0_1: "f32[3, 4, 5, 6, 7]"):
            # File: /Users/justinc/Documents/GitHub/torch-onnx/tests/torch_tests/torch_onnx_test.py:2259 in forward, code: :, torch.tensor([0, 2]), torch.tensor([[1, 3], [4, 0]]), None
            lift_fresh_copy: "i64[2]" = torch.ops.aten.lift_fresh_copy.default(_lifted_tensor_constant0);  _lifted_tensor_constant0 = None
            lift_fresh_copy_1: "i64[2, 2]" = torch.ops.aten.lift_fresh_copy.default(_lifted_tensor_constant1);  _lifted_tensor_constant1 = None

            # File: /Users/justinc/Documents/GitHub/torch-onnx/tests/torch_tests/torch_onnx_test.py:2258 in forward, code: return input[
            slice_1: "f32[3, 4, 5, 6, 7]" = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, 9223372036854775807);  arg0_1 = None
            unsqueeze: "f32[3, 4, 5, 1, 6, 7]" = torch.ops.aten.unsqueeze.default(slice_1, 3);  slice_1 = None
            index: "f32[3, 2, 2, 1, 6, 7]" = torch.ops.aten.index.Tensor(unsqueeze, [None, lift_fresh_copy, lift_fresh_copy_1]);  unsqueeze = lift_fresh_copy = lift_fresh_copy_1 = None
            return (index,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.CONSTANT_TENSOR: 4>, arg=TensorArgument(name='_lifted_tensor_constant0'), target='_lifted_tensor_constant0', persistent=None), InputSpec(kind=<InputKind.CONSTANT_TENSOR: 4>, arg=TensorArgument(name='_lifted_tensor_constant1'), target='_lifted_tensor_constant1', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='index'), target=None)])
Range constraints: {}
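The output shape `f32[3, 2, 2, 1, 6, 7]` follows from PyTorch's advanced-indexing rule. The index tensors `[0, 2]` (shape `[2]`) and `[[1, 3], [4, 0]]` (shape `[2, 2]`) broadcast to `[2, 2]`. Because they index consecutive dimensions (1 and 2), that broadcast shape replaces those dimensions in place. The trailing `None` has already become the size-1 dimension added by `unsqueeze` at dim 3. Below is a minimal pure-Python sketch of that shape rule, so it can be checked without torch. `bcast_shapes`, `unsqueeze_shape`, and `index_shape` are hypothetical helpers written for this sketch and are not torch or torch-onnx APIs. A `None` entry in the index list means the dimension is not indexed, as in the `aten.index.Tensor` call in the graph.

```python
from typing import Optional, Sequence, Tuple

Shape = Tuple[int, ...]


def bcast_shapes(*shapes: Shape) -> Shape:
    """Broadcast shapes right-aligned (NumPy/PyTorch broadcasting rules)."""
    ndim = max(len(s) for s in shapes)
    out = []
    for i in range(1, ndim + 1):
        # Sizes seen at this position from the right; size 1 stretches to fit.
        sizes = {s[-i] for s in shapes if len(s) >= i} - {1}
        if len(sizes) > 1:
            raise ValueError(f"shapes {shapes} are not broadcastable")
        out.append(sizes.pop() if sizes else 1)
    return tuple(reversed(out))


def unsqueeze_shape(shape: Shape, dim: int) -> Shape:
    """Shape after aten.unsqueeze(x, dim) for 0 <= dim <= len(shape)."""
    return shape[:dim] + (1,) + shape[dim:]


def index_shape(shape: Shape, indices: Sequence[Optional[Shape]]) -> Shape:
    """Shape after aten.index.Tensor(x, indices).

    Each entry is None (dimension kept as-is) or the shape of an integer
    index tensor applied to that dimension. Dimensions past the end of
    `indices` are kept.
    """
    pos = [i for i, idx in enumerate(indices) if idx is not None]
    if not pos:
        return shape
    b = bcast_shapes(*(indices[i] for i in pos))
    if pos == list(range(pos[0], pos[-1] + 1)):
        # Consecutive indexed dims: the broadcast shape replaces them in place.
        return shape[: pos[0]] + b + shape[pos[-1] + 1 :]
    # Non-consecutive: broadcast dims go first, then every non-indexed dim.
    return b + tuple(d for i, d in enumerate(shape) if i not in pos)


x = (3, 4, 5, 6, 7)                                # arg0_1
u = unsqueeze_shape(x, 3)                          # the trailing `None`
print(index_shape(u, [None, (2,), (2, 2)]))        # (3, 2, 2, 1, 6, 7)
print(index_shape((3, 4, 5), [(2,), None, (2,)]))  # (2, 4)
```

The second call shows the other branch of the rule. When the tensor indices are separated by a non-indexed dimension, the broadcast dimensions move to the front instead of staying in place.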
justinchuby commented 1 week ago

Handle constant tensor spec: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.CONSTANT_TENSOR: 4>, arg=TensorArgument(name='_lifted_tensor_constant0'), target='_lifted_tensor_constant0', persistent=None), InputSpec(kind=<InputKind.CONSTANT_TENSOR: 4>, arg=TensorArgument(name='_lifted_tensor_constant1'), target='_lifted_tensor_constant1', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='index'), target=None)])
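This log does not show how the exporter consumes the signature. The sketch below is a hedged illustration that assumes lifted constants (and parameters/buffers) end up as ONNX initializers while `USER_INPUT` entries stay graph inputs. Under that assumption, the input specs can be partitioned by kind. `Spec` and `partition_inputs` are hypothetical stand-ins written for illustration. The real objects are `InputSpec` entries carrying an `InputKind` enum from `torch.export.graph_signature`, which this sketch does not import. The kind strings mirror the `InputKind` names printed above.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass(frozen=True)
class Spec:
    """Hypothetical stand-in for a torch.export InputSpec."""
    kind: str              # e.g. "CONSTANT_TENSOR" or "USER_INPUT"
    name: str              # graph placeholder name, e.g. "arg0_1"
    target: Optional[str]  # key of the constant/parameter, None for user inputs


def partition_inputs(specs: List[Spec]) -> Dict[str, List[str]]:
    """Split placeholder names into would-be initializers and graph inputs.

    Assumption: constant tensors, parameters and buffers become ONNX
    initializers; user inputs remain ONNX graph inputs.
    """
    initializer_kinds = {"CONSTANT_TENSOR", "PARAMETER", "BUFFER"}
    result: Dict[str, List[str]] = {"initializers": [], "graph_inputs": []}
    for spec in specs:
        if spec.kind in initializer_kinds:
            result["initializers"].append(spec.name)
        elif spec.kind == "USER_INPUT":
            result["graph_inputs"].append(spec.name)
        else:
            # Other kinds (e.g. custom objects) are out of scope for this sketch.
            raise NotImplementedError(f"unhandled input kind: {spec.kind}")
    return result


# The three input specs from the signature in this issue.
specs = [
    Spec("CONSTANT_TENSOR", "_lifted_tensor_constant0", "_lifted_tensor_constant0"),
    Spec("CONSTANT_TENSOR", "_lifted_tensor_constant1", "_lifted_tensor_constant1"),
    Spec("USER_INPUT", "arg0_1", None),
]
print(partition_inputs(specs))
```

Raising on unknown kinds, rather than silently dropping them, keeps an unsupported signature from producing a graph with missing inputs.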