Closed ismaeelbashir03 closed 1 month ago
Hi,
Dynamism of tensor rank isn't supported, so changing tensor rank would require recompiling the entire model. However, it looks like in one of these cases you are specifying dynamism on the 1st dimension, and should be resizing from 1x3 -> 1x4 (1x3 is the shape you captured the model with, and 1x4 is the shape of the input you are giving at inference).
The error messages look flipped: the 1x3 -> 1x4 case is giving the "rank is immutable" error log, while the 1x3 -> 4 case (your second example) is giving another error.
Do you mind sharing the graph of the model you are exporting? You can add print(edge.exported_program()) right before the edge.to_executorch call in your Python script above.
Hi, from what I understood, the rank of the tensor inputs given (i.e. (1, 3)) is 2, so (1, 4) should also be rank 2, right? Excuse me if this is wrong, I am still learning.
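For reference, a tensor's rank is just the number of dimensions, i.e. the length of its shape tuple. A plain-Python illustration (the same holds for torch.Tensor.dim()):

```python
# Rank = number of dimensions = length of the shape tuple.
def rank(shape):
    return len(shape)

print(rank((1, 3)))  # 2
print(rank((1, 4)))  # 2 -- same rank, only dimension 1 changed
print(rank((4,)))    # 1 -- dropping the leading dimension changes the rank
```

So (1, 3) -> (1, 4) keeps rank 2, while (1, 3) -> (4,) changes the rank from 2 to 1.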
The graph output was:
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[1, s0]", arg1_1: "f32[1, s0]"):
            # File: /Users/ismaeelbashir/Documents/University/y4/Dissertation/Training-On-Edge-Mobile-Devices/model_generator/models/operations/model.py:28 in forward, code: return x + y
            aten_add_tensor: "f32[1, s0]" = executorch_exir_dialects_edge__ops_aten_add_Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
            return (aten_add_tensor,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='aten_add_tensor'), target=None)])
Range constraints: {s0: ValueRanges(lower=2, upper=100, is_bool=False)}
Yup, that's right. I believe in one of the examples you are using a (4,) tensor, which has rank one; I believe that is why one of the error messages is:
ETensor rank is immutable old: 1 new: 2
As for the second case where you are using (1,4), but the error message is:
Attempted to resize a static tensor to a new shape at dimension 0 old_size: 1 new_size: 4
which looks strange because you are resizing dimension 1 to a new_size of 4. Can you try using the following ExecutorchBackendConfig when exporting the model:
ExecutorchBackendConfig(
    sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
)
Attempted to resize a static tensor to a new shape at dimension 0 old_size: 1 new_size: 4
This suggests that you are passing a size [4] tensor to the runtime not [1,4].
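To make the two failure modes concrete, here is a rough sketch of the rule the runtime appears to be enforcing, based on the error messages above (illustrative only, not ExecuTorch's actual implementation): rank must match exactly, and only dimensions exported as dynamic may change size.

```python
# Illustrative sketch of the runtime's resize checks; not the real ExecuTorch code.
def can_resize(old_shape, new_shape, dynamic_dims):
    # Rank is immutable: a (4,) input can never replace a (1, 3) input.
    if len(old_shape) != len(new_shape):
        return f"rank is immutable old: {len(old_shape)} new: {len(new_shape)}"
    # Static dimensions cannot change; only dims marked dynamic at export may.
    for i, (old, new) in enumerate(zip(old_shape, new_shape)):
        if old != new and i not in dynamic_dims:
            return f"attempted to resize a static tensor at dimension {i}"
    return "ok"

print(can_resize((1, 3), (4,), {1}))      # rank is immutable old: 2 new: 1
print(can_resize((1, 3), (1, 4), {1}))    # ok -- dim 1 was exported as dynamic
print(can_resize((1, 3), (1, 4), set()))  # attempted to resize a static tensor at dimension 1
```

Under this rule, a size [4] input against a (1, 3) graph trips the rank check, while a [1, 4] input against a graph whose dim 1 was never marked dynamic trips the static-dimension check.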
You also probably need to add sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass() to your ExecutorchBackendConfig to have it use the dynamic_shapes information for the upper bound in memory planning instead of the inputs you passed in. I'm working on making this the default; there are just some legacy cases that are making it difficult.
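A back-of-the-envelope illustration of why this matters (assumed numbers; float32 tensors, and the s0 upper bound of 100 from the range constraints printed earlier): planning memory from the example input leaves no room for larger dynamic inputs, while planning from the symbolic upper bound does.

```python
import math

def planned_bytes(shape, itemsize=4):  # float32 = 4 bytes per element
    return math.prod(shape) * itemsize

# Planning from the example input (1, 3) leaves no room for a (1, 4) input:
print(planned_bytes((1, 3)))    # 12
# Planning from the upper bound s0 = 100 accommodates any (1, s0) input:
print(planned_bytes((1, 100)))  # 400
```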
lol was just about to tag you @JacobSzwejbka
Thanks for the responses. I tried adding sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass() to the ExecutorchBackendConfig, like so:
exec_prog = edge.to_executorch(
    config=ExecutorchBackendConfig(
        extract_constant_segment=False,
        sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
    )
)
but I still get the following error when supplying a tensor of shape (1, 4):
2024-05-16 19:24:29.706 28862-28890 ExecuTorch com.example.executorchdemo E ETensor rank is immutable old: 1 new: 2
2024-05-16 19:24:29.707 28862-28890 ExecuTorch com.example.executorchdemo E Error setting input 0: 0x10
2024-05-16 19:24:29.707 28862-28890 ExecuTorch com.example.executorchdemo A In function execute_method(), assert failed (result.ok()): Execution of method forward failed with status 0x12
2024-05-16 19:24:29.707 28862-28890 libc com.example.executorchdemo A Fatal signal 6 (SIGABRT), code -1 (SI_QUEUE) in tid 28890 (Thread-2), pid 28862 (.executorchdemo)
Hmm, the error logs still seem to suggest that we are changing tensor rank, even though the graph we are exporting has input rank two. @JacobSzwejbka any idea here?
but I still get the following error when supplying a tensor of shape (1, 4):
@ismaeelbashir03 are you using the code linked above to export or do you have other local changes now?
Oh one issue is you need to pass dynamic shape info to torch._export.capture_pre_autograd_graph(model, example_inputs)
The exact code I'm using to export is below; the helper functions are from the examples/xnnpack directory:
import argparse
import copy
import logging

import torch
from torch.export import export, ExportedProgram, Dim

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig, ExecutorchProgramManager
from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
from executorch.sdk import generate_etrecord

from models import MODEL_NAME_TO_MODEL
from models.model_factory import EagerModelFactory
from portable.utils import export_to_edge, save_pte_program
from __init__ import MODEL_NAME_TO_OPTIONS
from quantization.utils import quantize

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)

if __name__ == "__main__":
    # === Argument Parsing === #
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--model_name",
        required=True,
        help=f"Model name. Valid ones: {list(MODEL_NAME_TO_OPTIONS.keys())}",
    )
    parser.add_argument(
        "-q",
        "--quantize",
        action="store_true",
        required=False,
        default=False,
        help="Produce an 8-bit quantized model",
    )
    parser.add_argument(
        "-d",
        "--delegate",
        action="store_true",
        required=False,
        default=False,
        help="Produce an XNNPACK delegated model",
    )
    parser.add_argument(
        "-r",
        "--etrecord",
        required=False,
        help="Generate and save an ETRecord to the given file location",
    )
    parser.add_argument("-o", "--output_dir", default="output", help="output directory")
    parser.add_argument(
        "-di",
        "--dynamic_input",
        action="store_true",
        required=False,
        default=False,
        help="Produce model with dynamic input shape",
    )
    args = parser.parse_args()

    if args.model_name not in MODEL_NAME_TO_OPTIONS and args.quantize:
        raise RuntimeError(
            f"Model {args.model_name} is not a valid name. or not quantizable right now, "
            "please contact executorch team if you want to learn why or how to support "
            "quantization for the requested model"
            f"Available models are {list(MODEL_NAME_TO_OPTIONS.keys())}."
        )

    if not args.delegate and args.quantize:
        raise RuntimeError(
            "Quantization is only supported with delegate flag. Please enable delegate flag."
        )

    model, example_inputs, dynamic_shapes = EagerModelFactory.create_model(
        *MODEL_NAME_TO_MODEL[args.model_name]
    )
    model = model.eval()
    model = torch._export.capture_pre_autograd_graph(model, example_inputs)

    if args.quantize:
        logging.info("Quantizing Model...")
        model = quantize(model, example_inputs)

    edge = None
    if args.dynamic_input:
        logging.info("Exporting model with dynamic input shape ...")
        edge = export_to_edge(
            model,
            example_inputs,
            dynamic_shapes=dynamic_shapes,
            edge_compile_config=EdgeCompileConfig(
                _check_ir_validity=False if args.quantize else True,
            ),
        )
    else:
        edge = export_to_edge(
            model,
            example_inputs,
            edge_compile_config=EdgeCompileConfig(
                _check_ir_validity=False if args.quantize else True,
            ),
        )

    edge_copy = copy.deepcopy(edge)

    if args.delegate:
        edge = edge.to_backend(XnnpackPartitioner())

    print("===================================")
    print(edge.exported_program())
    print("===================================")

    exec_prog = edge.to_executorch(
        config=ExecutorchBackendConfig(
            extract_constant_segment=False,
            sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
        )
    )

    if args.etrecord is not None:
        generate_etrecord(args.etrecord, edge_copy, exec_prog)
        logging.info(f"Saved ETRecord to {args.etrecord}")

    quant_tag = "q8" if args.quantize else "fp32"
    model_name = f"{args.model_name}_xnnpack_{quant_tag}"
    save_pte_program(exec_prog, model_name, args.output_dir)
Oh one issue is you need to pass dynamic shape info to torch._export.capture_pre_autograd_graph(model, example_inputs)
I changed it to this and still have the issue. Am I passing it in correctly?
model = torch._export.capture_pre_autograd_graph(model, example_inputs, dynamic_shapes=dynamic_shapes)
Tried this out locally
# Imports added for completeness; module paths may vary by ExecuTorch version.
import torch
from torch.export import export, Dim
from executorch.exir import to_edge, ExecutorchBackendConfig
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
from executorch.extension.pybindings.portable_lib import _load_for_executorch_from_buffer

class Add(torch.nn.Module):
    def __init__(self):
        super(Add, self).__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        return x + y

    def get_eager_model(self) -> torch.nn.Module:
        return self

    def get_example_inputs(self):
        return (torch.randn(1, 3), torch.randn(1, 3))

    def get_dynamic_shapes(self):
        dim1_x = Dim("Add_dim1_x", min=1, max=10)
        return {"x": {1: dim1_x}, "y": {1: dim1_x}}

model = Add()
model = model.eval()
pre_autograd = torch._export.capture_pre_autograd_graph(
    model, model.get_example_inputs(), dynamic_shapes=model.get_dynamic_shapes()
)
ep = export(
    pre_autograd,
    model.get_example_inputs(),
    dynamic_shapes=model.get_dynamic_shapes(),
)
edge = to_edge(ep)
exec_prog = edge.to_executorch(
    config=ExecutorchBackendConfig(
        extract_constant_segment=False,
        sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
    ),
)
pybound_et = _load_for_executorch_from_buffer(exec_prog.buffer)
print(pybound_et((torch.ones(1, 5), torch.ones(1, 5))))  # [tensor([[2., 2., 2., 2., 2.]])]
and it worked. So it might just be shapes getting messed up somewhere in your flow.
I see, could this maybe be an issue with how the model is run in Java?
I wonder if Tensor.fromBlob is working correctly. Perhaps we need to check what the shapes are after creating the tensor? Maybe there is a bug there?
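One thing worth noting (illustrative Python; on the Java side the shape is a separate argument to the blob constructor) is that a flat buffer of 4 floats is equally consistent with shapes [4] and [1, 4], so a wrong shape argument would not be caught by a simple length check:

```python
import math

# A flat buffer's length only constrains the *product* of the shape,
# so shapes (4,) and (1, 4) are indistinguishable from the data alone.
def shape_matches_blob(data_len, shape):
    return data_len == math.prod(shape)

print(shape_matches_blob(4, (4,)))    # True
print(shape_matches_blob(4, (1, 4)))  # True -- same data, different rank
print(shape_matches_blob(4, (1, 3)))  # False
```

This is why printing the created tensor's shape on the Java side is a useful sanity check before calling forward.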
I managed to fix the issue, but it was really weird. I changed the name of the output file to something else and it ran fine. I then changed it back to the old name and it broke again. It might be something to do with Android Studio caching the old file, since I gave it the same name?
thanks for the help though, I really appreciate it 👍.
I have been playing around with ExecuTorch and can't seem to get the dynamic shapes feature working on some toy examples. Whenever I add dynamic inputs to my model, I cannot make a forward pass in my Android app. Originally I was using the XNNPACK delegation, but after reading that it does not support dynamic shapes, I removed the XNN partitioner. If anyone can help, I would greatly appreciate it.
compiling to .pte code:
I am doing the forward pass on android like this:
I get the following error in logcat:
I have also tried reducing the size by doing:
but this gives another error in logcat: