fastmachinelearning / qonnx

QONNX: Arbitrary-Precision Quantized Neural Networks in ONNX
https://qonnx.readthedocs.io/
Apache License 2.0

inference_cost_matmul: Confusion or Bug regarding the MAC-count of Scaled Dot-Product Attention #60

Open iksnagreb opened 1 year ago

iksnagreb commented 1 year ago


Quick summary

While characterizing the transformer dataflow, we found discrepancies when validating against the QONNX inference_cost estimates for the MatMul operators within the attention mechanism. We are not sure whether this is a bug on the QONNX side or a misunderstanding on ours, so we would like to start a discussion to understand the issue.

Details

Multi-head scaled dot-product attention involves two consecutive MatMul operations in which both inputs depend dynamically on the model inputs. The heads are independent of each other and are typically treated like a batch dimension. Our cost model assumes HxTxTxd MAC operations for each of the two MatMuls: H heads each produce a TxT attention matrix (T is the sequence length), and each element is the result of a d-dimensional dot product (d is the per-head embedding dimension). However, the QONNX analysis function inference_cost_matmul seems to be off by an additional factor of H (i.e. HxHxTxTxd), indicating that the heads are not treated like a batch dimension.

My suspicion is further raised by the following lines from the QONNX inference_cost_matmul function:

# exclude common dim (last axis) from one side to avoid duplication
n_macs = np.prod(i_shape[:-1]) * np.prod(w_shape)

Is this actually always the case? At least for the model graph I have attached it seems like the last axis is not the common dimension.
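To make the concern concrete, the quoted formula can be evaluated by hand on the shapes of the first attention MatMul. The shapes below are assumptions derived from the PyTorch example in this report (H = 4, T = 64, d = 32) rather than read from the exported graph, and the variable names are only illustrative.

```python
import numpy as np

# Assumed operand shapes of the first attention MatMul (q @ k^T):
#   heads x sequence x per-head dim, and heads x per-head dim x sequence
i_shape = (4, 64, 32)
w_shape = (4, 32, 64)

# The formula quoted above: drops only the last axis of the first operand
n_macs_quoted = int(np.prod(i_shape[:-1]) * np.prod(w_shape))

# Our cost model: heads act as a batch dimension, so H x T x T x d
H, T, d = i_shape
n_macs_batched = H * T * T * d

# Ratio between the two counts
ratio = n_macs_quoted // n_macs_batched

print(n_macs_quoted, n_macs_batched, ratio)
```

Under these assumed shapes the head axis appears in both np.prod terms, so it is multiplied in twice, which would explain the extra factor of H.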

Below I provide a minimal working example of scaled dot-product attention in isolation, written in PyTorch and exported to an ONNX graph. I have also attached the preprocessed graph, which in particular already includes the InferShapes transform. Note that running the qonnx.util.inference_cost script on the PyTorch ONNX export breaks at the FoldConstants transform with an IndexError. This is probably unrelated and should be investigated separately (I have "fixed" it by removing that transformation step for now).

Steps to Reproduce

The following code produces a minimal example of scaled dot-product attention and exports to ONNX.

import torch

# Minimal working example of the Scaled Dot-Product Attention mechanism
class ScaleDotProductAttention(torch.nn.Module):
    # Initializes the module parameters
    def __init__(self, num_heads):
        # Initialize the PyTorch base Module
        super().__init__()
        # Set the number of attention heads
        self.num_heads = num_heads

    # Forward pass computing scaled dot-product attention between q, k and v
    def forward(self, q, k, v):
        # Assume the simplest case of q, k and v all having the same
        # dimensions
        assert q.shape == k.shape == v.shape, \
            "Q, K and V must have the same shape"
        # Embedding dimension must be divisible by number of heads
        assert q.shape[-1] % self.num_heads == 0, \
            f"Dimensions must be divisible by heads ({self.num_heads})"

        # Assume sequence first layout and get the sizes per axis
        s, b, d = q.shape
        # Number of heads and dimension per head
        n_head, d_head = self.num_heads, d // self.num_heads

        # Reshape tensors to treat the heads like batch dimensions
        q = q.reshape(s, b, n_head, d_head).reshape(s, b * n_head, d_head)
        k = k.reshape(s, b, n_head, d_head).reshape(s, b * n_head, d_head)
        v = v.reshape(s, b, n_head, d_head).reshape(s, b * n_head, d_head)
        # Compute the not-yet-normalized attention matrices for each head.
        #   Note: permute brings batch x heads to front and transposes k
        a = torch.matmul(q.permute(1, 0, 2), k.permute(1, 2, 0))
        # Scale and normalize the attention matrix
        a = torch.softmax(a * (d_head ** -0.5), dim=-1)
        # Apply the attention matrices to the value projection
        #   Note: Second permute brings sequence dimension back to front
        o = torch.matmul(a, v.permute(1, 0, 2)).permute(1, 0, 2)
        # Reshape heads into feature dimension
        o = o.reshape(s, b, n_head, d_head).reshape(s, b, n_head * d_head)

        # Return the scaled dot-product attention output
        return o

# Script entrypoint
if __name__ == '__main__':
    # Instantiate a scaled dot-product attention with 4 attention heads
    sdp = ScaleDotProductAttention(num_heads=4)
    # Generate random query, key and value tensors
    #   Note: Sequence of length 64, single instance batch, 128 dim embeddings
    q, k, v = torch.randn(3, 64, 1, 128)
    # Export the attention module to ONNX
    torch.onnx.export(sdp, args=(q, k, v), f='sdp.onnx')

Get the MAC operation counts by running:

python -m qonnx.util.inference_cost sdp.onnx

Outputs something like

{'op_mac_FLOAT32_FLOAT32': 4194304.0, 'mem_w_FLOAT32': 0.0, 'mem_o_FLOAT32': 24576.0, 'unsupported': "{'Softmax', 'Pow', 'Constant'}", 'discount_sparsity': True, 'total_bops': 4294967296.0, 'total_mem_w_bits': 0.0, 'total_mem_o_bits': 786432.0}

Expected behavior

According to our cost model, the MAC count should be 2x HxTxTxd, which for the given example model (H = 4, T = 64, d = 32 per head) is 2x 4x64x64x32 = 1048576.

Actual behavior

The MAC count is reported as 4194304, which is 4x (i.e. Hx) our expectation, indicating a cost function of 2x HxHxTxTxd.
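As a cross-check of the expected number, here is a minimal sketch of a MAC count that treats all leading axes as broadcast batch dimensions, following NumPy matmul semantics. `batched_matmul_macs` is a hypothetical helper written for this report. It is not part of QONNX and makes no claim about how QONNX counts MACs internally, and the operand shapes are assumed from the PyTorch example rather than read from the graph.

```python
import numpy as np

def batched_matmul_macs(a_shape, b_shape):
    """MACs of a (possibly batched) matmul. Hypothetical helper, not QONNX API."""
    # Only handle operands with at least two axes to keep the sketch simple
    assert len(a_shape) >= 2 and len(b_shape) >= 2
    *a_batch, m, k = a_shape
    *b_batch, k_b, n = b_shape
    # The contracted axis must agree on both sides
    assert k == k_b, "inner dimensions must match"
    # Leading axes broadcast against each other and are counted only once
    batch = np.broadcast_shapes(tuple(a_batch), tuple(b_batch))
    return int(np.prod(batch, dtype=np.int64)) * m * k * n

# Assumed shapes for H = 4, T = 64, d = 32
qk_macs = batched_matmul_macs((4, 64, 32), (4, 32, 64))  # q @ k^T
av_macs = batched_matmul_macs((4, 64, 64), (4, 64, 32))  # a @ v
total_macs = qk_macs + av_macs

print(qk_macs, av_macs, total_macs)
```

For an ordinary 2-D weight matrix, e.g. shapes (64, 128) and (128, 256), this count agrees with the quoted formula, so the two only diverge when the second operand carries batch axes of its own.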

Attached ONNX Graph

sdp.costs.onnx.zip sdp onnx

Harsh9650 commented 10 months ago

Hello Christoph, thanks for highlighting the issue. We've addressed it in https://github.com/fastmachinelearning/qonnx/pull/90. Please try it and let us know if you encounter any further issues.