leigao97 opened this issue 3 weeks ago
@cccclai For QNN lowering, what does the graph look like for this toy model?
The graph looks like this:
```python
def forward(self, b_state, x):
    aten_full_default = executorch_exir_dialects_edge__ops_aten_full_default([1, 1], 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    aten_add_tensor = executorch_exir_dialects_edge__ops_aten_add_Tensor(b_state, aten_full_default);  b_state = aten_full_default = None
    aten_linear_default = executorch_exir_dialects_edge__ops_aten_linear_default(x, aten_add_tensor);  x = None
    return (aten_add_tensor, aten_linear_default)
```
Is it after `torch.export` or `to_edge`? Mind sharing the repro script?
It looks like the edge dialect graph (the ops are `executorch_exir_dialects_edge__ops_*`, so this is after `to_edge`). Here is my script:
```python
import torch
from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner
from executorch.backends.qualcomm.utils.utils import (
    capture_program,
    generate_htp_compiler_spec,
    generate_qnn_executorch_compiler_spec,
)
from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import (
    QcomChipset,
)
from executorch.exir.backend.backend_api import to_backend


class MutableStateModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("state", torch.zeros(1, 1))

    def forward(self, x):
        self.state.add_(torch.ones(1, 1))
        return x @ self.state.T


model = MutableStateModule()
inputs = (torch.zeros(1, 1),)
edge_prog = capture_program(model, inputs)
qnn_partitioner = QnnPartitioner(
    generate_qnn_executorch_compiler_spec(
        soc_model=QcomChipset.SM8650,
        backend_options=generate_htp_compiler_spec(use_fp16=True),
        debug=False,
        saver=False,
        shared_buffer=False,
    ),
)
edge_prog.exported_program = to_backend(edge_prog.exported_program, qnn_partitioner)
```
From my side, I need to set environment variables like `$EXECUTORCH_ROOT` and `PYTHONPATH` before running the script.
Here is the reference:
https://pytorch.org/executorch/stable/build-run-qualcomm-ai-engine-direct-backend.html#setting-up-your-developer-environment
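As a minimal sketch of that setup step (the paths below are placeholders; substitute your actual ExecuTorch checkout per the guide above):

```shell
# Placeholder path -- point this at your own ExecuTorch checkout.
export EXECUTORCH_ROOT="$HOME/executorch"
# Make the executorch package importable from the parent directory.
export PYTHONPATH="$EXECUTORCH_ROOT/..:$PYTHONPATH"
```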
Are there any updates on this issue? Thanks!
Not yet - I think it's similar to this issue: https://github.com/pytorch/executorch/issues/4042
The graph looks like this:

```python
def forward(self, b_state, x):
    aten_full_default = executorch_exir_dialects_edge__ops_aten_full_default([1, 1], 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    aten_add_tensor = executorch_exir_dialects_edge__ops_aten_add_Tensor(b_state, aten_full_default);  b_state = aten_full_default = None
    aten_linear_default = executorch_exir_dialects_edge__ops_aten_linear_default(x, aten_add_tensor);  x = None
    return (aten_add_tensor, aten_linear_default)
```
As shown in the graph, the linear op takes the output from the add op as its weight tensor.
In the code below, `weight_node` is `aten_add_tensor`, and the parameter lookup for it returns `None`.
https://github.com/pytorch/executorch/blob/5584b9e3c865edca239ec5df6346f1d1aabb0276/backends/qualcomm/builders/op_linear.py#L51
The `weight_tensor_wrapper` then needs the concrete value of the weight tensor, which is what triggers the error.
I replaced the `get_parameter` call with `get_tensor` and it seems to work:

```python
weight_tensor = self.get_tensor(weight_node, node)
```

Is this an okay bypass? Thank you.
Hmm, does it work at runtime? I sort of doubt it...
@leigao97 Hey, just wanted to follow up on this: are you still blocked on the issue?
Yes, the modification above was not correct. The reason I ran into this issue is that I wanted to run an int8 weight-only quantized model on the QNN backend by following this procedure:
I found that the root cause is this: if we perform any operation on the buffer, the graph contains an operator node, and the output of that operator has no static parameter value, so the linear op builder fails. In the quantized linear forward function above, the weight buffer is cast to floating point, which triggers the same issue.
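To illustrate that failure mode without any ExecuTorch dependency, here is a minimal sketch (all names are illustrative, not the real QNN builder API): a lookup that only knows about static parameters and buffers returns nothing for a value produced by another operator, such as the add on the buffer or a dtype cast of the weight.

```python
# Illustrative sketch only -- not the real QNN builder code.
# Static parameters/buffers known at export time, keyed by node name.
static_params = {"state": [[0.0]]}

def get_parameter(node_name):
    """Mimics a static-parameter lookup: returns None for any value
    that is computed by an operator at graph-execution time."""
    return static_params.get(node_name)

# A weight read directly from a buffer resolves to a concrete value...
print(get_parameter("state"))            # [[0.0]]
# ...but the output of an op on that buffer (e.g. aten_add_tensor)
# has no static value, so a builder that requires one fails here.
print(get_parameter("aten_add_tensor"))  # None
```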
For now, I am using the XNN backend instead.
When I lower this toy example to the QNN backend, the linear operator causes an error. Here is the message:

If I replace the linear operation `x @ self.state.T` with the addition operation `x + self.state.T`, it works.