cpuhrsch opened this issue 1 month ago
Is the LLaMA3 benchmark currently working on your end with weight-only quantization using torch.compile? Thanks in advance!
> I suspect we can do the exact same thing for FeedForward

How would you account for the `silu` here?

```python
self.w2(F.silu(self.w1(x)) * self.w3(x))
```
Something like this could work:
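```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FusedOperation(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        # w1 and w3 (each hidden_dim x input_dim) and w2^T (hidden_dim x output_dim),
        # concatenated along dim 1 into a single parameter
        self.fused_weight = nn.Parameter(torch.randn(hidden_dim, 2 * input_dim + output_dim))
        self.bias = nn.Parameter(torch.zeros(output_dim))

    def forward(self, x):
        # Split the fused weight back into its three parts (views, no copy)
        w1, w3, w2_t = self.fused_weight.split(
            [self.input_dim, self.input_dim, self.output_dim], dim=1)
        # Compute the fused operation: silu(x @ w1^T) * (x @ w3^T), then project with w2
        hidden = F.silu(x @ w1.t()) * (x @ w3.t())
        return hidden @ w2_t + self.bias
```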
Or do you have a simpler alternative in mind?
@sayakpaul Oh, I mean
```python
x1, x3 = self.w13(x).split([...])
```
As in, just fuse w1 and w3, not all of w1, w2 and w3. This is similar to how wqkv fuses wq, wk and wv but leaves the output projection (wo) alone.
So, more specifically:

```python
x1, x3 = self.w13(x).split([...])
return self.w2(F.silu(x1) * x3)
```
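A self-contained sketch of that FeedForward change follows. It assumes the existing FeedForward has bias-free `w1`, `w2`, `w3` `nn.Linear` layers as in the snippet above; `FusedFeedForward`, `from_unfused` and the sizes are hypothetical names and values chosen for illustration, not the torchao code. The fused `w13` weight is `w1.weight` and `w3.weight` concatenated along dim 0 (the output dimension), so the fused module should reproduce the original outputs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FeedForward(nn.Module):
    """Unfused SwiGLU feed-forward: three separate projections."""

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class FusedFeedForward(nn.Module):
    """Hypothetical fused variant: w1 and w3 share one nn.Linear (w13)."""

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.w13 = nn.Linear(dim, 2 * hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)

    @classmethod
    def from_unfused(cls, ff: FeedForward) -> "FusedFeedForward":
        dim, hidden_dim = ff.w1.in_features, ff.w1.out_features
        fused = cls(dim, hidden_dim)
        with torch.no_grad():
            # Rows [0, hidden_dim) come from w1, rows [hidden_dim, 2*hidden_dim) from w3
            fused.w13.weight.copy_(torch.cat([ff.w1.weight, ff.w3.weight], dim=0))
            fused.w2.weight.copy_(ff.w2.weight)
        return fused

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # One GEMM for both projections, then split the output into two views
        x1, x3 = self.w13(x).split([self.hidden_dim, self.hidden_dim], dim=-1)
        return self.w2(F.silu(x1) * x3)


torch.manual_seed(0)
ff = FeedForward(dim=64, hidden_dim=128)
fused_ff = FusedFeedForward.from_unfused(ff)
x = torch.randn(2, 8, 64)  # (batch, seq_len, dim)
print(torch.allclose(ff(x), fused_ff(x), atol=1e-5))  # True
```

Converting from an existing module (rather than training a fused one from scratch) keeps checkpoints loadable; the tolerance in `allclose` covers small floating-point differences between one larger GEMM and two smaller ones.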
Right now, `self.w2(F.silu(self.w1(x)) * self.w3(x))` makes 3 calls to an `nn.Linear`, but with the above change it's 2 calls, and `F.silu(x1) * x3` can also become an epilogue of `w13(x)` when using torch.compile.
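A quick eager-mode check of the call count, sketched with standalone bias-free layers of illustrative (assumed) sizes: attach a forward hook to every `nn.Linear` and count how often it fires for each variant. This only counts module calls; whether the epilogue actually gets fused still has to be confirmed under torch.compile.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

dim, hidden = 64, 128  # illustrative sizes
w1 = nn.Linear(dim, hidden, bias=False)
w3 = nn.Linear(dim, hidden, bias=False)
w2 = nn.Linear(hidden, dim, bias=False)
w13 = nn.Linear(dim, 2 * hidden, bias=False)  # w1 and w3 fused

calls = [0]  # mutable counter shared with the hook


def count_call(module, inputs, output):
    # Forward hook: runs once per nn.Linear forward call
    calls[0] += 1


for linear in (w1, w2, w3, w13):
    linear.register_forward_hook(count_call)

x = torch.randn(4, dim)

calls[0] = 0
w2(F.silu(w1(x)) * w3(x))  # current, unfused FeedForward
unfused_calls = calls[0]

calls[0] = 0
x1, x3 = w13(x).split([hidden, hidden], dim=-1)
w2(F.silu(x1) * x3)  # fused w13 version
fused_calls = calls[0]

print(unfused_calls, fused_calls)  # 3 2
```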
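Essentially, you stack w1 and w3 along their output dimension and compute `[w1; w3] @ x` (one GEMM) instead of `[w1 @ x; w3 @ x]` (two GEMMs), but you can still split the result of the former into its two halves, and do so without causing a copy, because the halves are just strided views.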
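The following sketch checks both halves of that claim using random tensors of assumed sizes and PyTorch's `x @ W.T` convention (inputs as rows): the single stacked GEMM gives the same values as the two separate ones, and `split` returns views that share memory with the GEMM output.

```python
import torch

torch.manual_seed(0)
dim, hidden = 64, 128  # illustrative sizes
w1 = torch.randn(hidden, dim)
w3 = torch.randn(hidden, dim)
x = torch.randn(8, dim)  # batch of 8 row vectors

# Stack w1 and w3 along the output dimension: shape (2 * hidden, dim)
w13 = torch.cat([w1, w3], dim=0)

# One GEMM instead of two; result has shape (8, 2 * hidden)
out = x @ w13.t()

# Split along the last dimension; both halves point into out's memory
x1, x3 = out.split([hidden, hidden], dim=-1)

print(x1.data_ptr() == out.data_ptr(),
      x3.data_ptr() == out.data_ptr() + hidden * out.element_size())  # True True
print(x1.stride(), x1.is_contiguous())  # (256, 1) False
print(torch.allclose(x1, x @ w1.t(), atol=1e-4),
      torch.allclose(x3, x @ w3.t(), atol=1e-4))  # True True
```

The row stride of 256 is what lets each half skip over the other half's columns without any copy; downstream ops that require contiguous input would copy at that point instead.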
For the Attention module, we concatenate the weights and run one GEMM instead of three on the input, which gives a speedup because all three GEMMs are applied to the same input.
https://github.com/pytorch/ao/blob/22d6f97d8584dd14ac5d0c5bc4ddad9bf33553fe/torchao/_models/llama/model.py#L220-L225 and https://github.com/pytorch/ao/blob/22d6f97d8584dd14ac5d0c5bc4ddad9bf33553fe/torchao/_models/llama/model.py#L230-L231
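For reference, a standalone sketch of the same idea for the attention input projections. It follows the spirit of the linked `wqkv` code but is not copied from it; the sizes, including a grouped-query setup with fewer key/value heads than query heads, are illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical sizes: 8 query heads, 2 key/value heads
dim, n_head, n_kv_head = 256, 8, 2
head_dim = dim // n_head
q_size, kv_size = n_head * head_dim, n_kv_head * head_dim

wq = nn.Linear(dim, q_size, bias=False)
wk = nn.Linear(dim, kv_size, bias=False)
wv = nn.Linear(dim, kv_size, bias=False)

# Fuse the three input projections into one Linear; wo would stay separate
wqkv = nn.Linear(dim, q_size + 2 * kv_size, bias=False)
with torch.no_grad():
    wqkv.weight.copy_(torch.cat([wq.weight, wk.weight, wv.weight], dim=0))

x = torch.randn(2, 16, dim)  # (batch, seq_len, dim)
with torch.no_grad():
    # One GEMM, then split the output into q, k, v views
    q, k, v = wqkv(x).split([q_size, kv_size, kv_size], dim=-1)
    print(q.shape, k.shape, v.shape)
    print(torch.allclose(q, wq(x), atol=1e-5),
          torch.allclose(k, wk(x), atol=1e-5),
          torch.allclose(v, wv(x), atol=1e-5))  # True True True
```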
I suspect we can do the exact same thing for FeedForward:
https://github.com/pytorch/ao/blob/22d6f97d8584dd14ac5d0c5bc4ddad9bf33553fe/torchao/_models/llama/model.py#L262-L263
Task: Implement the above trick and rerun the benchmarks to show the gains. If you don't have access to an A100, another (ideally similar) GPU is fine as a proxy. Also, if you can, try to confirm via a trace that two GEMMs have indeed been turned into one.
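One way to start on the trace check, sketched under assumptions: run each variant under `torch.profiler` on CPU and count the recorded `aten::mm`/`aten::addmm` ops. The layer sizes and the helper `count_gemms` are hypothetical; on an A100 you would add `ProfilerActivity.CUDA` and inspect the GEMM kernels in the trace instead (for example via `prof.export_chrome_trace(...)`).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.profiler import ProfilerActivity, profile

dim, hidden = 64, 128  # illustrative sizes
w1 = nn.Linear(dim, hidden, bias=False)
w3 = nn.Linear(dim, hidden, bias=False)
w2 = nn.Linear(hidden, dim, bias=False)
w13 = nn.Linear(dim, 2 * hidden, bias=False)


def unfused(x):
    return w2(F.silu(w1(x)) * w3(x))


def fused(x):
    x1, x3 = w13(x).split([hidden, hidden], dim=-1)
    return w2(F.silu(x1) * x3)


def count_gemms(fn, x):
    """Profile fn(x) on CPU and count recorded aten::mm / aten::addmm ops."""
    with torch.no_grad(), profile(activities=[ProfilerActivity.CPU]) as prof:
        fn(x)
    return sum(1 for e in prof.events() if e.name in ("aten::mm", "aten::addmm"))


x = torch.randn(4, dim)
print(count_gemms(unfused, x), count_gemms(fused, x))
```

The fused variant should record one fewer GEMM. Both op names are counted because, in eager mode, a bias-free `nn.Linear` on 2D input typically lowers to `aten::mm`, while one with a bias lowers to `aten::addmm`.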