pytorch / ao

PyTorch native quantization and sparsity for training and inference
BSD 3-Clause "New" or "Revised" License

[llama] Use horizontal fusion trick from Attention for FeedForward #606

Open. cpuhrsch opened this issue 1 month ago

cpuhrsch commented 1 month ago

In the Attention module, the wq, wk and wv weights can be concatenated so that the input projection runs as one GEMM instead of three. This gives a speedup because all three GEMMs are applied to the same input.

https://github.com/pytorch/ao/blob/22d6f97d8584dd14ac5d0c5bc4ddad9bf33553fe/torchao/_models/llama/model.py#L220-L225 and https://github.com/pytorch/ao/blob/22d6f97d8584dd14ac5d0c5bc4ddad9bf33553fe/torchao/_models/llama/model.py#L230-L231
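Below is a minimal sketch of the Attention-side trick. It assumes bias-free projections and plain multi-head attention, where k and v are as wide as q, so the llama model's grouped-query sizes are ignored. The width is illustrative. The three weights are concatenated along the output dimension into one `wqkv` layer, and its output is split back into q, k and v.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim = 64  # illustrative model width; k/v assumed as wide as q (no GQA)

# Three separate projections, all applied to the same input x
wq = nn.Linear(dim, dim, bias=False)
wk = nn.Linear(dim, dim, bias=False)
wv = nn.Linear(dim, dim, bias=False)

# One fused projection: stack the weights along the output dimension (dim 0)
wqkv = nn.Linear(dim, 3 * dim, bias=False)
with torch.no_grad():
    wqkv.weight.copy_(torch.cat([wq.weight, wk.weight, wv.weight], dim=0))

x = torch.randn(2, 5, dim)  # (batch, seq_len, dim)

# One GEMM instead of three, then split the result along the last dimension
q, k, v = wqkv(x).split([dim, dim, dim], dim=-1)

with torch.no_grad():
    max_err = max(
        (q - wq(x)).abs().max().item(),
        (k - wk(x)).abs().max().item(),
        (v - wv(x)).abs().max().item(),
    )
print(f"max abs difference vs. separate projections: {max_err:.2e}")
```

The weight copy stands in for however the real model builds `wqkv`. It is an illustration, not torchao's loading code.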

I suspect we can do the exact same thing for FeedForward.

https://github.com/pytorch/ao/blob/22d6f97d8584dd14ac5d0c5bc4ddad9bf33553fe/torchao/_models/llama/model.py#L262-L263

Task: Implement the above trick and rerun the benchmarks to show the gains. If you don't have access to an A100, another (ideally similar) GPU is fine as a proxy. If you can, also confirm via a profiler trace that the two GEMMs have indeed been turned into one.
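Here is a minimal sketch of such a check, assuming a CPU-only run. `torch.profiler` records every `aten::linear` call, so counting them before and after the change should show one fewer call for the fused FeedForward. The helper name `count_op_calls` is hypothetical. On an A100 you would add `ProfilerActivity.CUDA` and inspect the GEMM kernels in the trace (for example via `prof.export_chrome_trace(...)`), because exact kernel names depend on the backend.

```python
import torch
import torch.nn as nn
from torch.profiler import profile, ProfilerActivity


def count_op_calls(module: nn.Module, x: torch.Tensor, op_name: str = "aten::linear") -> int:
    """Run module(x) once under the PyTorch profiler and count calls to op_name."""
    with torch.no_grad(), profile(activities=[ProfilerActivity.CPU]) as prof:
        module(x)
    # key_averages() groups events by op name; .count is the number of calls
    return sum(evt.count for evt in prof.key_averages() if evt.key == op_name)


# Toy example: two nn.Linear layers show up as two aten::linear calls
toy = nn.Sequential(nn.Linear(16, 32, bias=False), nn.Linear(32, 16, bias=False))
n_linear = count_op_calls(toy, torch.randn(4, 16))
print(n_linear)
```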

sanchitintel commented 1 month ago

Is the LLaMA3 benchmark currently working on your end with weight-only quantization and torch.compile? Thanks in advance!

sayakpaul commented 1 week ago

> I suspect we can do the exact same thing for FeedForward

How would you account for silu here?

```python
self.w2(F.silu(self.w1(x)) * self.w3(x))
```

Something like this could work:

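```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FusedOperation(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.split_sizes = [input_dim, input_dim, output_dim]
        # One parameter holding [w1 | w3 | w2^T]; every block has hidden_dim rows
        self.fused_weight = nn.Parameter(torch.randn(hidden_dim, 2 * input_dim + output_dim))
        self.bias = nn.Parameter(torch.zeros(output_dim))

    def forward(self, x):
        # Split the fused weight into w1, w3 (hidden_dim x input_dim each)
        # and w2^T (hidden_dim x output_dim)
        w1, w3, w2_t = self.fused_weight.split(self.split_sizes, dim=1)

        # Same math as self.w2(F.silu(self.w1(x)) * self.w3(x)) plus a bias,
        # but still three separate matmuls
        hidden = F.silu(x @ w1.t()) * (x @ w3.t())
        return hidden @ w2_t + self.bias
```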

Or do you have a simpler alternative in mind?

cpuhrsch commented 1 week ago

@sayakpaul Oh, I mean

```python
x1, x3 = self.w13(x).split([...])
```

As in, just fuse w1 and w3, not all of w1, w2 and w3. This is similar to how wqkv fuses wq, wk and wv but leaves the output projection (wo) alone.

So more specifically

```python
x1, x3 = self.w13(x).split([...])
return self.w2(F.silu(x1) * x3)
```
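Below is a minimal sketch of that change. It assumes bias-free layers and the `w1`/`w2`/`w3` naming of the llama FeedForward. `FusedFeedForward`, `from_unfused` and the sizes are hypothetical names chosen for illustration. The split sizes are filled in with the hidden size on the assumption that w1 and w3 have the same shape. Converting a checkpoint is also assumed to work by concatenating the w1 and w3 weights, as `from_unfused` does.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FeedForward(nn.Module):
    """Unfused baseline: three nn.Linear calls per forward."""

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class FusedFeedForward(nn.Module):
    """Hypothetical fused variant: w1 and w3 share one GEMM, w2 stays separate."""

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.w13 = nn.Linear(dim, 2 * hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)

    @classmethod
    def from_unfused(cls, ff: FeedForward) -> "FusedFeedForward":
        fused = cls(ff.w1.in_features, ff.w1.out_features)
        with torch.no_grad():
            # Stack w1 on top of w3 along the output dimension (dim 0)
            fused.w13.weight.copy_(torch.cat([ff.w1.weight, ff.w3.weight], dim=0))
            fused.w2.weight.copy_(ff.w2.weight)
        return fused

    def forward(self, x):
        # One GEMM for both projections, then split the result into two views
        x1, x3 = self.w13(x).split([self.hidden_dim, self.hidden_dim], dim=-1)
        return self.w2(F.silu(x1) * x3)


torch.manual_seed(0)
ff = FeedForward(dim=64, hidden_dim=172)
fused = FusedFeedForward.from_unfused(ff)
x = torch.randn(2, 5, 64)
with torch.no_grad():
    max_err = (fused(x) - ff(x)).abs().max().item()
print(f"max abs difference: {max_err:.2e}")
```

Keeping `w2` as its own `nn.Linear` mirrors how `wo` stays separate from `wqkv`. `w2` consumes the gated hidden state rather than `x`, so it cannot share a GEMM with `w1`/`w3`.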

Right now

```python
self.w2(F.silu(self.w1(x)) * self.w3(x))
```

makes three nn.Linear calls. With the change above it's two, and F.silu(x1) * x3 can also become an epilogue of the w13(x) GEMM when using torch.compile.

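Essentially, you stack w1 on top of w3 (along the output dimension, i.e. dim 0 of the nn.Linear weights), like

```
[w1,
 w3] @ x
```

instead of computing

```
[w1 @ x,
 w3 @ x]
```

as two separate GEMMs. You can then split the result of the former, and the split does not cause a copy, because the two halves are just strided views into the same output tensor.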
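Here is a small sketch of the no-copy claim. It uses plain tensors instead of nn.Linear, with illustrative sizes. w1 and w3 are concatenated along dim 0 (nn.Linear's output dimension), and one matmul produces both results side by side. `split` then returns views whose row stride skips over the other half.

```python
import torch

torch.manual_seed(0)
hidden_dim, dim = 6, 4  # illustrative sizes

w1 = torch.randn(hidden_dim, dim)
w3 = torch.randn(hidden_dim, dim)
w13 = torch.cat([w1, w3], dim=0)  # w1 stacked on top of w3: (2 * hidden_dim, dim)

x = torch.randn(3, dim)  # 3 tokens, row-vector convention as in nn.Linear
out = x @ w13.t()        # one GEMM: (3, 2 * hidden_dim), contiguous
x1, x3 = out.split([hidden_dim, hidden_dim], dim=-1)

# Both halves are views into `out`'s storage; nothing is copied
same_start = x1.data_ptr() == out.data_ptr()
x3_offset_ok = x3.data_ptr() == out.data_ptr() + hidden_dim * out.element_size()
print(same_start, x3_offset_ok)          # True True
print(x1.stride(), x1.is_contiguous())   # (12, 1) False

# The views hold the same values as the two separate GEMMs
matches = torch.allclose(x1, x @ w1.t(), atol=1e-5) and torch.allclose(x3, x @ w3.t(), atol=1e-5)
print(matches)
```

`F.silu(x1) * x3` is elementwise, so it works on these non-contiguous views as-is.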