pytorch / executorch

On-device AI across mobile, embedded and edge for PyTorch
https://pytorch.org/executorch/

torch.matmul causes conversion to executorch to fail #6286

Open virginia-cangelosi opened 4 weeks ago

virginia-cangelosi commented 4 weeks ago

🐛 Describe the bug

When converting a model to ExecuTorch with `aten_dialect = export(model, (dummy_input_text, dummy_input_label, dummy_input_midi, dummy_input_duration_phn, dummy_input_duration_ruled_phn, dummy_input_duration_syb, dummy_input_slur), dynamic_shapes=dynamic_shapes)`, the following code raises `torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression`:

```python
import math

import torch


# Method of the attention module (excerpt)
def forward(self, query, key, value, pos_emb, mask):
    """Compute 'Scaled Dot Product Attention' with rel. positional encoding.

    Args:
        query (torch.Tensor): Query tensor (#batch, time1, size).
        key (torch.Tensor): Key tensor (#batch, time2, size).
        value (torch.Tensor): Value tensor (#batch, time2, size).
        pos_emb (torch.Tensor): Positional embedding tensor
            (#batch, 2*time1-1, size).
        mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
            (#batch, time1, time2).

    Returns:
        torch.Tensor: Output tensor (#batch, time1, d_model).

    """
    q, k, v = self.forward_qkv(query, key, value)
    q = q.transpose(1, 2)  # (batch, time1, head, d_k)

    n_batch_pos = pos_emb.size(0)
    torch._check(pos_emb.size(0) != -1)
    torch._check(pos_emb.size(2) != -1)
    torch._check(pos_emb.size(1) != -1)
    embed = self.linear_pos(pos_emb)
    p = embed.reshape(n_batch_pos, pos_emb.size(1), self.h, self.d_k)
    p = p.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)

    # (batch, head, time1, d_k)
    q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
    # (batch, head, time1, d_k)
    q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

    # compute attention score
    # first compute matrix a and matrix c
    # as described in https://arxiv.org/abs/1901.02860 Section 3.3
    # (batch, head, time1, time2)
    k = k.transpose(-2, -1)
    matrix_ac = torch.matmul(q_with_bias_u, k)  # <-- error raised here
    # compute matrix b and matrix d
    # (batch, head, time1, 2*time1-1)
    matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
    matrix_bd = self.rel_shift(matrix_bd)

    scores = (matrix_ac + matrix_bd) / math.sqrt(
        self.d_k
    )  # (batch, head, time1, time2)

    return self.forward_attention(v, scores, mask)
```

The full error message is:


Traceback (most recent call last):
  File "/Users/******/venv10/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1903, in run_node
    return node.target(*args, **kwargs)
  File "/Users/******/venv10/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 264, in _fn
    result = fn(*args, is_out=(out is not None), **kwargs)
  File "/Users/******/venv10/lib/python3.10/site-packages/torch/_decomp/decompositions.py", line 4261, in matmul
    tensor1_expanded = tensor1.expand(tensor1_expand_size).reshape(
  File "/Users/*****/venv10/lib/python3.10/site-packages/torch/fx/experimental/sym_node.py", line 414, in guard_bool
    r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
  File "/Users/*****/venv10/lib/python3.10/site-packages/torch/fx/experimental/recording.py", line 245, in wrapper
    return fn(*args, **kwargs)
  File "/Users/*****/venv10/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 5205, in evaluate_expr
    raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Ne(96, 192*Max(0, u10) + 192*Max(0, u11) + 192*Max(0, u12) + 192*Max(0, u13) + 192*Max(0, u14) + 192*Max(0, u15) + 192*Max(0, u16) + 192*Max(0, u17) + 192*Max(0, u18) + 192*Max(0, u19) + 192*Max(0, u2) + 192*Max(0, u20) + 192*Max(0, u21) + 192*Max(0, u22) + 192*Max(0, u23) + 192*Max(0, u24) + 192*Max(0, u25) + 192*Max(0, u26) + 192*Max(0, u27) + 192*Max(0, u28) + 192*Max(0, u29) + 192*Max(0, u3) + 192*Max(0, u30) + 192*Max(0, u31) + 192*Max(0, u32) + 192*Max(0, u33) + 192*Max(0, u34) + 192*Max(0, u35) + 192*Max(0, u36) + 192*Max(0, u37) + 192*Max(0, u38) + 192*Max(0, u39) + 192*Max(0, u4) + 192*Max(0, u40) + 192*Max(0, u41) + 192*Max(0, u42) + 192*Max(0, u43) + 192*Max(0, u44) + 192*Max(0, u45) + 192*Max(0, u46) + 192*Max(0, u47) + 192*Max(0, u48) + 192*Max(0, u49) + 192*Max(0, u5) + 192*Max(0, u50) + 192*Max(0, u51) + 192*Max(0, u52) + 192*Max(0, u53) + 192*Max(0, u54) + 192*Max(0, u55) + 192*Max(0, u56) + 192*Max(0, u57) + 192*Max(0, u58) + 192*Max(0, u59) + 192*Max(0, u6) + 192*Max(0, u60) + 192*Max(0, u61) + 192*Max(0, u62) + 192*Max(0, u63) + 192*Max(0, u64) + 192*Max(0, u65) + 192*Max(0, u66) + 192*Max(0, u67) + 192*Max(0, u68) + 192*Max(0, u69) + 192*Max(0, u7) + 192*Max(0, u70) + 192*Max(0, u71) + 192*Max(0, u72) + 192*Max(0, u73) + 192*Max(0, u74) + 192*Max(0, u75) + 192*Max(0, u8) + 192*Max(0, u9)) (unhinted: Ne(96, 192*Max(0, u10) + 192*Max(0, u11) + 192*Max(0, u12) + 192*Max(0, u13) + 192*Max(0, u14) + 192*Max(0, u15) + 192*Max(0, u16) + 192*Max(0, u17) + 192*Max(0, u18) + 192*Max(0, u19) + 192*Max(0, u2) + 192*Max(0, u20) + 192*Max(0, u21) + 192*Max(0, u22) + 192*Max(0, u23) + 192*Max(0, u24) + 192*Max(0, u25) + 192*Max(0, u26) + 192*Max(0, u27) + 192*Max(0, u28) + 192*Max(0, u29) + 192*Max(0, u3) + 192*Max(0, u30) + 192*Max(0, u31) + 192*Max(0, u32) + 192*Max(0, u33) + 192*Max(0, u34) + 192*Max(0, u35) + 192*Max(0, u36) + 192*Max(0, u37) + 
192*Max(0, u38) + 192*Max(0, u39) + 192*Max(0, u4) + 192*Max(0, u40) + 192*Max(0, u41) + 192*Max(0, u42) + 192*Max(0, u43) + 192*Max(0, u44) + 192*Max(0, u45) + 192*Max(0, u46) + 192*Max(0, u47) + 192*Max(0, u48) + 192*Max(0, u49) + 192*Max(0, u5) + 192*Max(0, u50) + 192*Max(0, u51) + 192*Max(0, u52) + 192*Max(0, u53) + 192*Max(0, u54) + 192*Max(0, u55) + 192*Max(0, u56) + 192*Max(0, u57) + 192*Max(0, u58) + 192*Max(0, u59) + 192*Max(0, u6) + 192*Max(0, u60) + 192*Max(0, u61) + 192*Max(0, u62) + 192*Max(0, u63) + 192*Max(0, u64) + 192*Max(0, u65) + 192*Max(0, u66) + 192*Max(0, u67) + 192*Max(0, u68) + 192*Max(0, u69) + 192*Max(0, u7) + 192*Max(0, u70) + 192*Max(0, u71) + 192*Max(0, u72) + 192*Max(0, u73) + 192*Max(0, u74) + 192*Max(0, u75) + 192*Max(0, u8) + 192*Max(0, u9))).  (Size-like symbols: u32, u9, u51, u2, u7, u34, u27, u70, u16, u3, u12, u69, u72, u21, u58, u22, u56, u63, u24, u17, u64, u28, u44, u52, u73, u71, u59, u55, u38, u29, u49, u31, u67, u39, u10, u37, u40, u68, u25, u53, u48, u74, u75, u30, u60, u62, u14, u23, u26, u66, u20, u46, u6, u5, u54, u42, u47, u4, u41, u36, u33, u18, u15, u61, u65, u11, u57, u35, u45, u50, u13, u8, u19, u43)

ATTENTION: guard_size_oblivious would fix the error, evaluating expression to True.
Maybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.

Potential framework code culprit (scroll up for full backtrace):
  File "/Users/*****/venv10/lib/python3.10/site-packages/torch/_decomp/decompositions.py", line 4261, in matmul
    tensor1_expanded = tensor1.expand(tensor1_expand_size).reshape(

For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u32,u9,u51,u2,u7,u34,u27,u70,u16,u3,u12,u69,u72,u21,u58,u22,u56,u63,u24,u17,u64,u28,u44,u52,u73,u71,u59,u55,u38,u29,u49,u31,u67,u39,u10,u37,u40,u68,u25,u53,u48,u74,u75,u30,u60,u62,u14,u23,u26,u66,u20,u46,u6,u5,u54,u42,u47,u4,u41,u36,u33,u18,u15,u61,u65,u11,u57,u35,u45,u50,u13,u8,u19,u43"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing
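The guard in this message has the form `Ne(96, 192*T)`, where `T` is the sum of `Max(0, u_i)` over the 74 unbacked symbols `u2` to `u75`. The pure-Python sketch below is only a model of the reasoning, under two assumptions: that the tracer first tries to decide the guard from value ranges, and that size-oblivious reasoning treats size-like symbols as being at least 2. Under those assumptions it shows why the guard cannot be settled without a concrete value, and why the hint says `guard_size_oblivious` would evaluate it to True. The helper names (`rhs_range`, `can_decide_ne`, `ne_holds`) are hypothetical and are not PyTorch APIs.

```python
import math


def rhs_range(num_symbols: int, assumed_min: int) -> tuple[float, float]:
    """Value range of 192 * sum(max(0, u_i)) when every u_i >= assumed_min.

    The upper bound is unbounded because unbacked symbols have no upper limit.
    """
    low = 192 * num_symbols * max(0, assumed_min)
    return (low, math.inf)


def can_decide_ne(lhs: int, rng: tuple[float, float]) -> bool:
    """Ne(lhs, rhs) is provable from ranges alone only if lhs lies outside rng."""
    low, high = rng
    return not (low <= lhs <= high)


def ne_holds(us: list[int]) -> bool:
    """Concrete evaluation of Ne(96, 192 * sum(max(0, u))) for given values."""
    return 96 != 192 * sum(max(0, u) for u in us)


# Plain reasoning: size-like u_i are only known to be >= 0, so the range
# [0, inf) contains 96 and the guard cannot be decided without a value.
plain = rhs_range(num_symbols=74, assumed_min=0)
# Size-oblivious reasoning (as assumed here) treats size-like symbols as >= 2.
oblivious = rhs_range(num_symbols=74, assumed_min=2)

print(can_decide_ne(96, plain))      # False
print(can_decide_ne(96, oblivious))  # True

# 192 * T is always a multiple of 192, so it never equals 96 for integer T.
print(all(ne_holds([a, b]) for a in range(-5, 20) for b in range(-5, 20)))  # True
```

Range reasoning alone cannot use the fact that `192*T` is a multiple of 192, which is why an expression that is true for every integer `T` can still fail to be proven at trace time. Whether a particular PyTorch version applies size-oblivious reasoning at this point in the `matmul` decomposition is not something this sketch can tell.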

I have tried adding `torch._check` and `torch._check_is_size`, but it has made no difference. If it helps, `q_with_bias_u` has shape (1, 2, <data-dependent value>, 96), and `k` (as passed to `torch.matmul`) has shape (1, 2, 96, <the same data-dependent value>).
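From the shapes given above, the module appears to use `h = 2` heads with `d_k = 96`, so `h * d_k = 192`. Assuming `q + self.pos_bias_u` is contiguous in `(batch, time1, head, d_k)` layout before `.transpose(1, 2)`, the strides of `q_with_bias_u` can be worked out by hand. The pure-Python sketch below (all helper names are hypothetical) suggests that the two numbers in the guard, `96` and `192*T`, match the head and batch strides of this transposed, non-contiguous tensor, which is the operand that `matmul` expands and reshapes in the traceback.

```python
def contiguous_strides(shape):
    """Row-major (contiguous) strides, in elements, for a given shape."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return tuple(strides)


def transpose(shape, strides, dim0, dim1):
    """Mimic Tensor.transpose: swap sizes and strides without copying data."""
    shape, strides = list(shape), list(strides)
    shape[dim0], shape[dim1] = shape[dim1], shape[dim0]
    strides[dim0], strides[dim1] = strides[dim1], strides[dim0]
    return tuple(shape), tuple(strides)


def q_with_bias_layout(T, batch=1, heads=2, d_k=96):
    """Shape and strides of (q + pos_bias_u).transpose(1, 2), assuming the sum
    is contiguous with shape (batch, time1, head, d_k)."""
    shape = (batch, T, heads, d_k)
    strides = contiguous_strides(shape)
    # .transpose(1, 2) -> (batch, head, time1, d_k)
    return transpose(shape, strides, 1, 2)


for T in (1, 7, 50):
    shape, strides = q_with_bias_layout(T)
    # strides[0] == 192 * T and strides[1] == 96, the two sides of the guard
    print(shape, strides)
```

For `T = 7` this prints `(1, 2, 7, 96) (1344, 96, 192, 1)`. The correspondence between these strides and the guard is an inference from the numbers, not a confirmed reading of the decomposition code.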

Versions

Collecting environment information...
PyTorch version: 2.4.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.6.1 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.0.40.1)
CMake version: version 3.30.3
Libc version: N/A

Python version: 3.10.15 (main, Sep 7 2024, 00:20:06) [Clang 15.0.0 (clang-1500.3.9.4)] (64-bit runtime)
Python platform: macOS-14.6.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU: Apple M2 Pro

Versions of relevant libraries:
[pip3] audiolm-pytorch==1.1.4
[pip3] ema-pytorch==0.5.1
[pip3] executorch==0.3.0a0+7d77d78
[pip3] lion-pytorch==0.2.2
[pip3] numpy==1.23.5
[pip3] onnxruntime==1.18.1
[pip3] optree==0.12.1
[pip3] pytorch-wpe==0.0.1
[pip3] torch==2.4.0
[pip3] torch-complex==0.4.4
[pip3] torchaudio==2.4.0
[pip3] torchsr==1.0.4
[pip3] torchtext==0.18.0
[pip3] torchvision==0.19.0
[pip3] vector-quantize-pytorch==1.14.26
[conda] Could not collect

virginia-cangelosi commented 2 days ago

Is there any help for this?