pytorch / executorch

On-device AI across mobile, embedded and edge for PyTorch
https://pytorch.org/executorch/

Dynamic Shapes issue #3636

Closed · ismaeelbashir03 closed this 1 month ago

ismaeelbashir03 commented 1 month ago

I have been playing around with ExecuTorch and can't seem to get the dynamic shapes feature working on some toy examples. Whenever I add dynamic inputs to my model, I cannot make a forward pass in my Android app. Originally I was using XNNPACK delegation, but after reading that it does not support dynamic shapes, I removed the XNNPACK partitioner. If anyone can help, I would greatly appreciate it.

Code for compiling to .pte:

class Add(torch.nn.Module):
    def __init__(self):
        super(Add, self).__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        return x + y

    def get_eager_model(self) -> torch.nn.Module:
        return self

    def get_example_inputs(self):
        return (torch.randn(1, 3), torch.randn(1, 3))

    def get_dynamic_shapes(self):
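        # dim 1 of x and y share the same symbolic size; MIN_DIM / MAX_DIM are module-level bounds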
        dim1_x = Dim("Add_dim1_x", min=MIN_DIM, max=MAX_DIM)
        return {"x": {1: dim1_x}, "y": {1: dim1_x}}

def export_to_edge(
    model: Union[torch.fx.GraphModule, torch.nn.Module],
    example_inputs: Tuple[Value, ...],
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    edge_constant_methods: Optional[Dict[str, Any]] = None,
    edge_compile_config=_EDGE_COMPILE_CONFIG,
    verbose=True,
) -> EdgeProgramManager:
    core_aten_ep = _to_core_aten(model, example_inputs, dynamic_shapes, verbose=verbose)
    return _core_aten_to_edge(
        core_aten_ep, edge_constant_methods, edge_compile_config, verbose=verbose
    )

def _to_core_aten(
    model: Union[torch.fx.GraphModule, torch.nn.Module],
    example_inputs: Tuple[Value, ...],
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    verbose=True,
) -> ExportedProgram:
    # post autograd export. eventually this will become .to_core_aten
    if not isinstance(model, torch.fx.GraphModule) and not isinstance(
        model, torch.nn.Module
    ):
        raise ValueError(
            f"Expected passed in model to be an instance of fx.GraphModule, got {type(model)}"
        )
    core_aten_ep = export(model, example_inputs, dynamic_shapes=dynamic_shapes)
    if verbose:
        logging.info(f"Core ATen graph:\n{core_aten_ep.graph}")
    return core_aten_ep

model = model.eval()
model = torch._export.capture_pre_autograd_graph(model, example_inputs)

edge = export_to_edge(
        model,
        example_inputs,
        dynamic_shapes=dynamic_shapes,
        edge_compile_config=EdgeCompileConfig(
            _check_ir_validity=False if args.quantize else True,
        ),
)

# edge = edge.to_backend(XnnpackPartitioner())
# logging.info(f"Lowered graph:\n{edge.exported_program().graph}")

exec_prog = edge.to_executorch(
        config=ExecutorchBackendConfig(extract_constant_segment=False)
    )

save_pte_program(exec_prog, model_name, args.output_dir)   

I am doing the forward pass on Android like this:


mModule = Module.load(MainActivity.assetFilePath(getApplicationContext(), "add_xnnpack_fp32.pte"));
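// Note: the model was captured with example inputs of shape (1, 3); the inputs below are shape (1, 4)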

Tensor inputTensor1 = Tensor.fromBlob(new float[] {1, 2, 3, 4}, new long[] {1, 4});
Tensor inputTensor2 = Tensor.fromBlob(new float[] {2, 3, 4, 4}, new long[] {1, 4});

final long startTime = SystemClock.elapsedRealtime();
Tensor outputTensor = mModule.forward(EValue.from(inputTensor1), EValue.from(inputTensor2))[0].toTensor();
final long inferenceTime = SystemClock.elapsedRealtime() - startTime;
Log.d("AddSegmentation", "inference time (ms): " + inferenceTime);

I get the following error in logcat:


2024-05-16 18:09:56.930 27561-27561 ziparchive              com.example.executorchdemo           W  Unable to open '/data/app/~~UpZdupDxYhjygPZA-_97gA==/com.example.executorchdemo-oxdRFDY3QbmShbj-NZRPMw==/base.dm': No such file or directory
2024-05-16 18:09:56.930 27561-27561 ziparchive              com.example.executorchdemo           W  Unable to open '/data/app/~~UpZdupDxYhjygPZA-_97gA==/com.example.executorchdemo-oxdRFDY3QbmShbj-NZRPMw==/base.dm': No such file or directory
2024-05-16 18:09:57.351 27561-27561 nativeloader            com.example.executorchdemo           D  Configuring clns-6 for other apk /data/app/~~UpZdupDxYhjygPZA-_97gA==/com.example.executorchdemo-oxdRFDY3QbmShbj-NZRPMw==/base.apk. target_sdk_version=33, uses_libraries=, library_path=/data/app/~~UpZdupDxYhjygPZA-_97gA==/com.example.executorchdemo-oxdRFDY3QbmShbj-NZRPMw==/lib/arm64:/data/app/~~UpZdupDxYhjygPZA-_97gA==/com.example.executorchdemo-oxdRFDY3QbmShbj-NZRPMw==/base.apk!/lib/arm64-v8a, permitted_path=/data:/mnt/expand:/data/user/0/com.example.executorchdemo
2024-05-16 18:09:57.360 27561-27561 GraphicsEnvironment     com.example.executorchdemo           V  Currently set values for:
2024-05-16 18:09:57.361 27561-27561 GraphicsEnvironment     com.example.executorchdemo           V    angle_gl_driver_selection_pkgs=[]
2024-05-16 18:09:57.361 27561-27561 GraphicsEnvironment     com.example.executorchdemo           V    angle_gl_driver_selection_values=[]
2024-05-16 18:09:57.361 27561-27561 GraphicsEnvironment     com.example.executorchdemo           V  ANGLE GameManagerService for com.example.executorchdemo: false
2024-05-16 18:09:57.367 27561-27561 GraphicsEnvironment     com.example.executorchdemo           V  com.example.executorchdemo is not listed in per-application setting
2024-05-16 18:09:57.367 27561-27561 GraphicsEnvironment     com.example.executorchdemo           V  Neither updatable production driver nor prerelease driver is supported.
2024-05-16 18:09:57.417 27561-27578 libEGL                  com.example.executorchdemo           D  loaded /vendor/lib64/egl/libEGL_emulation.so
2024-05-16 18:09:57.420 27561-27561 Compatibil...geReporter com.example.executorchdemo           D  Compat change id reported: 210923482; UID 10191; state: ENABLED
2024-05-16 18:09:57.425 27561-27578 libEGL                  com.example.executorchdemo           D  loaded /vendor/lib64/egl/libGLESv1_CM_emulation.so
2024-05-16 18:09:57.428 27561-27578 libEGL                  com.example.executorchdemo           D  loaded /vendor/lib64/egl/libGLESv2_emulation.so
2024-05-16 18:09:57.476 27561-27561 Compatibil...geReporter com.example.executorchdemo           D  Compat change id reported: 237531167; UID 10191; state: DISABLED
2024-05-16 18:09:57.477 27561-27561 OpenGLRenderer          com.example.executorchdemo           W  Unknown dataspace 0
2024-05-16 18:09:57.523 27561-27576 OpenGLRenderer          com.example.executorchdemo           W  Failed to choose config with EGL_SWAP_BEHAVIOR_PRESERVED, retrying without...
2024-05-16 18:09:57.523 27561-27576 OpenGLRenderer          com.example.executorchdemo           W  Failed to initialize 101010-2 format, error = EGL_SUCCESS
2024-05-16 18:09:57.537 27561-27576 Gralloc4                com.example.executorchdemo           I  mapper 4.x is not supported
2024-05-16 18:09:57.542 27561-27576 OpenGLRenderer          com.example.executorchdemo           E  Unable to match the desired swap behavior.
2024-05-16 18:09:58.709 27561-27588 libc                    com.example.executorchdemo           W  Access denied finding property "ro.hardware.chipname"
2024-05-16 18:09:58.772 27561-27576 EGL_emulation           com.example.executorchdemo           D  app_time_stats: avg=295.09ms min=41.73ms max=878.91ms count=4
2024-05-16 18:09:58.780 27561-27588 ExecuTorch              com.example.executorchdemo           E  ETensor rank is immutable old: 1 new: 2
2024-05-16 18:09:58.780 27561-27588 ExecuTorch              com.example.executorchdemo           E  Error setting input 0: 0x10
2024-05-16 18:09:58.780 27561-27588 ExecuTorch              com.example.executorchdemo           A  In function execute_method(), assert failed (result.ok()): Execution of method forward failed with status 0x12
2024-05-16 18:09:58.781 27561-27588 libc                    com.example.executorchdemo           A  Fatal signal 6 (SIGABRT), code -1 (SI_QUEUE) in tid 27588 (Thread-2), pid 27561 (.executorchdemo)

I have also tried reducing the size by doing:


Tensor inputTensor1 = Tensor.fromBlob(new float[] {1, 2, 3, 4}, new long[] {4}); 
Tensor inputTensor2 = Tensor.fromBlob(new float[] {2, 3, 4, 4}, new long[] {4});

but this gives another error in logcat:


2024-05-16 18:11:17.423 27727-27727 ziparchive              com.example.executorchdemo           W  Unable to open '/data/data/com.example.executorchdemo/code_cache/.overlay/base.apk/classes4.dm': No such file or directory
2024-05-16 18:11:17.424 27727-27727 ziparchive              com.example.executorchdemo           W  Unable to open '/data/app/~~UpZdupDxYhjygPZA-_97gA==/com.example.executorchdemo-oxdRFDY3QbmShbj-NZRPMw==/base.dm': No such file or directory
2024-05-16 18:11:17.424 27727-27727 ziparchive              com.example.executorchdemo           W  Unable to open '/data/app/~~UpZdupDxYhjygPZA-_97gA==/com.example.executorchdemo-oxdRFDY3QbmShbj-NZRPMw==/base.dm': No such file or directory
2024-05-16 18:11:17.657 27727-27727 nativeloader            com.example.executorchdemo           D  Configuring clns-6 for other apk /data/app/~~UpZdupDxYhjygPZA-_97gA==/com.example.executorchdemo-oxdRFDY3QbmShbj-NZRPMw==/base.apk. target_sdk_version=33, uses_libraries=, library_path=/data/app/~~UpZdupDxYhjygPZA-_97gA==/com.example.executorchdemo-oxdRFDY3QbmShbj-NZRPMw==/lib/arm64:/data/app/~~UpZdupDxYhjygPZA-_97gA==/com.example.executorchdemo-oxdRFDY3QbmShbj-NZRPMw==/base.apk!/lib/arm64-v8a, permitted_path=/data:/mnt/expand:/data/user/0/com.example.executorchdemo
2024-05-16 18:11:17.673 27727-27727 GraphicsEnvironment     com.example.executorchdemo           V  Currently set values for:
2024-05-16 18:11:17.674 27727-27727 GraphicsEnvironment     com.example.executorchdemo           V    angle_gl_driver_selection_pkgs=[]
2024-05-16 18:11:17.674 27727-27727 GraphicsEnvironment     com.example.executorchdemo           V    angle_gl_driver_selection_values=[]
2024-05-16 18:11:17.674 27727-27727 GraphicsEnvironment     com.example.executorchdemo           V  ANGLE GameManagerService for com.example.executorchdemo: false
2024-05-16 18:11:17.674 27727-27727 GraphicsEnvironment     com.example.executorchdemo           V  com.example.executorchdemo is not listed in per-application setting
2024-05-16 18:11:17.675 27727-27727 GraphicsEnvironment     com.example.executorchdemo           V  Neither updatable production driver nor prerelease driver is supported.
2024-05-16 18:11:17.734 27727-27727 Compatibil...geReporter com.example.executorchdemo           D  Compat change id reported: 210923482; UID 10191; state: ENABLED
2024-05-16 18:11:17.739 27727-27745 libEGL                  com.example.executorchdemo           D  loaded /vendor/lib64/egl/libEGL_emulation.so
2024-05-16 18:11:17.742 27727-27745 libEGL                  com.example.executorchdemo           D  loaded /vendor/lib64/egl/libGLESv1_CM_emulation.so
2024-05-16 18:11:17.756 27727-27745 libEGL                  com.example.executorchdemo           D  loaded /vendor/lib64/egl/libGLESv2_emulation.so
2024-05-16 18:11:17.778 27727-27727 Compatibil...geReporter com.example.executorchdemo           D  Compat change id reported: 237531167; UID 10191; state: DISABLED
2024-05-16 18:11:17.784 27727-27727 OpenGLRenderer          com.example.executorchdemo           W  Unknown dataspace 0
2024-05-16 18:11:17.824 27727-27743 OpenGLRenderer          com.example.executorchdemo           W  Failed to choose config with EGL_SWAP_BEHAVIOR_PRESERVED, retrying without...
2024-05-16 18:11:17.825 27727-27743 OpenGLRenderer          com.example.executorchdemo           W  Failed to initialize 101010-2 format, error = EGL_SUCCESS
2024-05-16 18:11:17.839 27727-27743 Gralloc4                com.example.executorchdemo           I  mapper 4.x is not supported
2024-05-16 18:11:17.850 27727-27743 OpenGLRenderer          com.example.executorchdemo           E  Unable to match the desired swap behavior.
2024-05-16 18:11:22.871 27727-27743 EGL_emulation           com.example.executorchdemo           D  app_time_stats: avg=2491.37ms min=154.90ms max=4827.83ms count=2
2024-05-16 18:11:22.933 27727-27754 libc                    com.example.executorchdemo           W  Access denied finding property "ro.hardware.chipname"
2024-05-16 18:11:22.937 27727-27754 ExecuTorch              com.example.executorchdemo           E  Attempted to resize a static tensor to a new shape at dimension 0 old_size: 1 new_size: 4
2024-05-16 18:11:22.939 27727-27754 ExecuTorch              com.example.executorchdemo           E  Error setting input 0: 0x10
2024-05-16 18:11:22.940 27727-27754 ExecuTorch              com.example.executorchdemo           A  In function execute_method(), assert failed (result.ok()): Execution of method forward failed with status 0x12
2024-05-16 18:11:22.940 27727-27754 libc                    com.example.executorchdemo           A  Fatal signal 6 (SIGABRT), code -1 (SI_QUEUE) in tid 27754 (Thread-2), pid 27727 (.executorchdemo)
mcr229 commented 1 month ago

Hi,

Dynamism of tensor rank isn't supported, so changing the tensor rank will require recompiling the entire model. However, it looks like in one of these cases you are specifying dynamism on the 1st dimension, so you should be resizing from 1x3 -> 1x4 (1x3 is the shape you captured the model with, and 1x4 is the shape of the input you are giving at inference).

The error messages look flipped: the 1x3 -> 1x4 case is giving the "rank is immutable" error, while the 1x3 -> 4 case (your second example) is giving the other one.

Do you mind sharing the graph of the model you are exporting? You can add print(edge.exported_program()) right before the edge.to_executorch call in your Python script above.
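For reference, a minimal sketch of where that print could go, reusing the variable names from your script above:

edge = export_to_edge(
    model,
    example_inputs,
    dynamic_shapes=dynamic_shapes,
    edge_compile_config=EdgeCompileConfig(_check_ir_validity=True),
)

# inspect the exported graph and its symbolic shapes before lowering
print(edge.exported_program())

exec_prog = edge.to_executorch(
    config=ExecutorchBackendConfig(extract_constant_segment=False)
)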

ismaeelbashir03 commented 1 month ago

Hi, from what I understood, the rank of the tensor inputs given (i.e. (1, 3)) is 2, so (1, 4) should also be rank 2, right? Excuse me if this is wrong, I am still learning.

The graph output was:


ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[1, s0]", arg1_1: "f32[1, s0]"):
            # File: /Users/ismaeelbashir/Documents/University/y4/Dissertation/Training-On-Edge-Mobile-Devices/model_generator/models/operations/model.py:28 in forward, code: return x + y
            aten_add_tensor: "f32[1, s0]" = executorch_exir_dialects_edge__ops_aten_add_Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
            return (aten_add_tensor,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='aten_add_tensor'), target=None)])
Range constraints: {s0: ValueRanges(lower=2, upper=100, is_bool=False)}
mcr229 commented 1 month ago

Yup, that's right.

I believe in one of the examples you are using a (4,) tensor, which has rank one; I believe that is why one of the error messages is: ETensor rank is immutable old: 1 new: 2

As for the second case, where you are using (1, 4), the error message is: Attempted to resize a static tensor to a new shape at dimension 0 old_size: 1 new_size: 4

which looks strange, because you are resizing dimension 1 to a new_size of 4. Can you try using the following ExecutorchBackendConfig when exporting the model:

ExecutorchBackendConfig(
    sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
)
JacobSzwejbka commented 1 month ago

Attempted to resize a static tensor to a new shape at dimension 0 old_size: 1 new_size: 4

This suggests that you are passing a size [4] tensor to the runtime, not [1, 4].

You also probably need to add sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass() to your ExecutorchBackendConfig so that it uses the dynamic_shapes information for the upper bound in memory planning, instead of the example inputs you passed in. I'm working on making this the default; there are just some legacy cases that are making it difficult.

mcr229 commented 1 month ago

lol was just about to tag you @JacobSzwejbka

ismaeelbashir03 commented 1 month ago

Thanks for the responses. I tried adding sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass() to the ExecutorchBackendConfig, like so:

exec_prog = edge.to_executorch(
        config=ExecutorchBackendConfig(
            extract_constant_segment=False,
            sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
        )
    )

but I still get the following error when supplying a tensor of shape (1, 4):

2024-05-16 19:24:29.706 28862-28890 ExecuTorch              com.example.executorchdemo           E  ETensor rank is immutable old: 1 new: 2
2024-05-16 19:24:29.707 28862-28890 ExecuTorch              com.example.executorchdemo           E  Error setting input 0: 0x10
2024-05-16 19:24:29.707 28862-28890 ExecuTorch              com.example.executorchdemo           A  In function execute_method(), assert failed (result.ok()): Execution of method forward failed with status 0x12
2024-05-16 19:24:29.707 28862-28890 libc                    com.example.executorchdemo           A  Fatal signal 6 (SIGABRT), code -1 (SI_QUEUE) in tid 28890 (Thread-2), pid 28862 (.executorchdemo)
mcr229 commented 1 month ago

Hmm, the error logs still seem to suggest that we are changing the tensor rank, even though the graph we are exporting has rank-two inputs. @JacobSzwejbka any idea here?

JacobSzwejbka commented 1 month ago

but I still get the following error when supplying a tensor of shape (1, 4):

@ismaeelbashir03 are you using the code linked above to export, or do you have other local changes now?

JacobSzwejbka commented 1 month ago

Oh, one issue: you need to pass the dynamic shape info to torch._export.capture_pre_autograd_graph(model, example_inputs) as well.
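For example, a minimal sketch (same variable names as the script above):

model = torch._export.capture_pre_autograd_graph(
    model, example_inputs, dynamic_shapes=dynamic_shapes
)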

ismaeelbashir03 commented 1 month ago

The exact code I'm using to export is below; the helper functions are from the examples/xnnpack directory:

import argparse
import copy
import logging

import torch
from torch.export import export, ExportedProgram, Dim
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig, ExecutorchProgramManager
from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
from executorch.sdk import generate_etrecord

from models import MODEL_NAME_TO_MODEL
from models.model_factory import EagerModelFactory
from portable.utils import export_to_edge, save_pte_program
from __init__ import MODEL_NAME_TO_OPTIONS
from quantization.utils import quantize

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)

if __name__ == "__main__":

    # === Argument Parsing === #
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--model_name",
        required=True,
        help=f"Model name. Valid ones: {list(MODEL_NAME_TO_OPTIONS.keys())}",
    )
    parser.add_argument(
        "-q",
        "--quantize",
        action="store_true",
        required=False,
        default=False,
        help="Produce an 8-bit quantized model",
    )
    parser.add_argument(
        "-d",
        "--delegate",
        action="store_true",
        required=False,
        default=False,
        help="Produce an XNNPACK delegated model",
    )
    parser.add_argument(
        "-r",
        "--etrecord",
        required=False,
        help="Generate and save an ETRecord to the given file location",
    )
    parser.add_argument("-o", "--output_dir", default="output", help="output directory")

    parser.add_argument(
        "-di",
        "--dynamic_input",
        action="store_true",
        required=False,
        default=False,
        help="Produce model with dynamic input shape",
    )

    args = parser.parse_args()

    if args.model_name not in MODEL_NAME_TO_OPTIONS and args.quantize:
        raise RuntimeError(
            f"Model {args.model_name} is not a valid name. or not quantizable right now, "
            "please contact executorch team if you want to learn why or how to support "
            "quantization for the requested model"
            f"Available models are {list(MODEL_NAME_TO_OPTIONS.keys())}."
        )

    if not args.delegate and args.quantize:
        raise RuntimeError(
            "Quantization is only supported with delegate flag. Please enable delegate flag."
        )

    model, example_inputs, dynamic_shapes = EagerModelFactory.create_model(
        *MODEL_NAME_TO_MODEL[args.model_name]
    )

    model = model.eval()
    model = torch._export.capture_pre_autograd_graph(model, example_inputs)

    if args.quantize:
        logging.info("Quantizing Model...")
        model = quantize(model, example_inputs)

    edge = None
    if args.dynamic_input:
        logging.info("Exporting model with dynamic input shape ...")
        edge = export_to_edge(
            model,
            example_inputs,
            dynamic_shapes=dynamic_shapes,
            edge_compile_config=EdgeCompileConfig(
                _check_ir_validity=False if args.quantize else True,
            ),
        )
    else:
        edge = export_to_edge(
            model,
            example_inputs,
            edge_compile_config=EdgeCompileConfig(
                _check_ir_validity=False if args.quantize else True,
            ),
        )

    edge_copy = copy.deepcopy(edge)

    if args.delegate:
        edge = edge.to_backend(XnnpackPartitioner())
    print("===================================")
    print(edge.exported_program())
    print("===================================")

    exec_prog = edge.to_executorch(
        config=ExecutorchBackendConfig(
            extract_constant_segment=False,
            sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
        )
    )

    if args.etrecord is not None:
        generate_etrecord(args.etrecord, edge_copy, exec_prog)
        logging.info(f"Saved ETRecord to {args.etrecord}")

    quant_tag = "q8" if args.quantize else "fp32"
    model_name = f"{args.model_name}_xnnpack_{quant_tag}"
    save_pte_program(exec_prog, model_name, args.output_dir)
ismaeelbashir03 commented 1 month ago

Oh one issue is you need to pass dynamic shape info to torch._export.capture_pre_autograd_graph(model, example_inputs)

I changed it to this and still have the issue; am I passing it in correctly? model = torch._export.capture_pre_autograd_graph(model, example_inputs, dynamic_shapes=dynamic_shapes)

JacobSzwejbka commented 1 month ago

Tried this out locally

    # imports needed to run this snippet standalone (pybindings import path may vary by build)
    import torch
    from torch.export import export, Dim
    from executorch.exir import to_edge, ExecutorchBackendConfig
    from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
    from executorch.extension.pybindings.portable_lib import _load_for_executorch_from_buffer

    class Add(torch.nn.Module):
        def __init__(self):
            super(Add, self).__init__()

        def forward(self, x: torch.Tensor, y: torch.Tensor):
            return x + y

        def get_eager_model(self) -> torch.nn.Module:
            return self

        def get_example_inputs(self):
            return (torch.randn(1, 3), torch.randn(1, 3))

        def get_dynamic_shapes(self):
            dim1_x = Dim("Add_dim1_x", min=1, max=10)
            return {"x": {1: dim1_x}, "y": {1: dim1_x}}

    model = Add()
    model = model.eval()
    pre_autograd = torch._export.capture_pre_autograd_graph(
        model, model.get_example_inputs(), dynamic_shapes=model.get_dynamic_shapes()
    )
    ep = export(
        pre_autograd,
        model.get_example_inputs(),
        dynamic_shapes=model.get_dynamic_shapes(),
    )
    edge = to_edge(ep)
    exec_prog = edge.to_executorch(
        config=ExecutorchBackendConfig(
            extract_constant_segment=False,
            sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
        ),
    )
    pybound_et = _load_for_executorch_from_buffer(exec_prog.buffer)
    print(pybound_et((torch.ones(1, 5), torch.ones(1, 5)))) # [tensor([[2., 2., 2., 2., 2.]])]

and it worked. So it might just be shapes getting messed up somewhere in your flow.

ismaeelbashir03 commented 1 month ago

I see. Could this maybe be an issue with how the model is run in Java?

mcr229 commented 1 month ago

I wonder if Tensor.fromBlob is working correctly. Perhaps we need to check what the shapes are after creating the tensors? Maybe there is a bug there?
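A quick way to check, assuming the ExecuTorch Tensor class exposes shape() the same way the PyTorch Android API does, would be something like:

Tensor inputTensor1 = Tensor.fromBlob(new float[] {1, 2, 3, 4}, new long[] {1, 4});
// should log [1, 4] (rank 2) if the blob was created as expected
Log.d("AddSegmentation", "input1 shape: " + java.util.Arrays.toString(inputTensor1.shape()));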

ismaeelbashir03 commented 1 month ago

I managed to fix the issue, but it was really weird. I changed the name of the output file to something else and it ran fine; I then changed it back to the old name and it broke again. It might be something to do with Android Studio caching the old file, since I gave it the same name.

Thanks for the help though, I really appreciate it 👍.