sonos / tract

Bad performance in LirMatmulUnary with maybe unexpected loop #830

Open tgolsson opened 1 year ago

tgolsson commented 1 year ago

Hello again!

As usual, this is something I'm still investigating but want to open up for input. I've been debugging a text model over the last few days, trying to figure out why Tract is so much slower than ORT for it (6x with no batching, 2x at batch size 6). This is the same funky model that has the batch dimension in the second position - ripe for issues. I've narrowed it down to LirMatMulUnary being the likely culprit (as in, the LirMatMulUnary is much slower than any other op).

What I've discovered is that (most of) the MatMuls in this network end up dispatching over the feature dimension of size 77. I don't think that's the right thing to do here, and I'm not even sure why it happens. Looking at the graph, the exploding matmul (I think) has input=[N, 512] and weights=[512, 1536]. I'm wondering whether there's an assumption that the batch dimension comes first, which causes a misreading of the model where N=6 != 77?

On the other hand, I can't even find where the 1536 disappears, which is equally confusing. I might have understood it if 1536 / 77 divided evenly, but that's a fraction.

Maybe related, this model sometimes crashes in Softmax, so I'm wondering if there's a bug here causing it to sometimes read out of bounds...

Edit 1: So A = [77, 6, 512], if I trace the graph correctly. Tract then takes the 1536 and splits it into three calls of 512 each. So it does 3 MatMuls, each of which is further split into 77 MatMuls.

Edit 2: OK, ignoring the crash for now: I think this is expected behavior - this is a batched matrix multiply (BMM), so 77 instances of [6, 512] @ [512, 1536], optimized to account for the later slice. So the question is why it's so slow, especially in the N=1 case.
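
For reference, a minimal numpy sketch of the shapes as I read them (the three-way split of the 1536 output columns is my interpretation of the later slice):

```python
import numpy as np

# Shapes as I read them from the graph (assumption: seq=77, batch=6).
A = np.random.rand(77, 6, 512).astype(np.float32)   # [seq, batch, hidden]
W = np.random.rand(512, 1536).astype(np.float32)    # [hidden, 3 * hidden]

# The full batched matmul: 77 independent [6, 512] @ [512, 1536] products.
full = A @ W                                         # -> [77, 6, 1536]

# What tract appears to do: split the 1536 columns into three 512-wide
# matmuls (to account for the later slice), each dispatched 77 times.
parts = [A @ W[:, i * 512:(i + 1) * 512] for i in range(3)]  # 3 x [77, 6, 512]

assert np.allclose(full, np.concatenate(parts, axis=-1), atol=1e-5)
```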

tgolsson commented 1 year ago

In this code here:

https://github.com/sonos/tract/blob/dd622bb15dc2f22dd22307963a084767a72e037c/core/src/ops/matmul/lir_unary.rs#L229-L239

I wonder if the whole loop is unnecessary there when it runs over the outermost dims. With C-contiguous tensors, the iteration variable sits at the largest stride... which is exactly the layout we'd get by just merging the two outer dims (a reshape/vstack). A simple heuristic would be to check whether both c_m_axis and c_n_axis are 0 or 1, which implies that the leftmost dims are contiguous.
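
A quick numpy sanity check of what I mean (hypothetical shapes, not tract code): looping over the outermost dim and merging the two outer dims compute the same thing:

```python
import numpy as np

A = np.random.rand(77, 6, 512).astype(np.float32)
W = np.random.rand(512, 512).astype(np.float32)

# What the loop over the outermost dim computes: one matmul per outer index.
looped = np.stack([A[i] @ W for i in range(A.shape[0])])       # [77, 6, 512]

# Merging the two outer (contiguous, largest-stride) dims into a single m
# axis and doing one matmul is the same computation, in one kernel call.
merged = (A.reshape(77 * 6, 512) @ W).reshape(77, 6, 512)

assert np.allclose(looped, merged, atol=1e-5)
```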

tgolsson commented 1 year ago

Doing this transform in Python checks out in tests and offers some decent performance boosts:

```text
# Linux
With bs==1
4.81 ms ± 334 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
552 µs ± 27.3 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

With bs==3
3.86 ms ± 156 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
1.15 ms ± 35.7 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

With bs==6
3.96 ms ± 407 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
1.96 ms ± 91.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

# Windows
With bs==1
22.9 ms ± 106 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
753 µs ± 2.54 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

With bs==3
29.8 ms ± 147 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
3.61 ms ± 23.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

With bs==6
31.1 ms ± 107 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
6.35 ms ± 36.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```

This is the same machine, same script, same minor version of numpy... so the impact is much bigger on Windows, but there are good savings on both. This is with the dims I'm seeing in this specific network, obviously.
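
Not my exact script, but roughly the shape of the comparison (a hypothetical reconstruction; names and dims are just for illustration):

```python
import timeit
import numpy as np

bs = 1  # also ran with bs = 3 and bs = 6
A = np.random.rand(77, bs, 512).astype(np.float32)
W = np.random.rand(512, 1536).astype(np.float32)

def looped():
    # one matmul per outer index, mirroring the per-dim dispatch
    return np.stack([A[i] @ W for i in range(77)])

def merged():
    # single matmul over the merged outer dims
    return (A.reshape(77 * bs, 512) @ W).reshape(77, bs, 1536)

for name, fn in [("looped", looped), ("merged", merged)]:
    secs = timeit.timeit(fn, number=100) / 100
    print(f"{name}: {secs * 1e3:.3f} ms per call")
```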

kali commented 1 year ago

Hey @tgolsson, could you give me this model? I have a couple of ideas about what is at play here. We are also doing a relatively heavy refactoring of the MatMul ops to bring them closer to the EinSum operator. The idea is to lift some constraints on the axis ordering, which should ultimately unlock more optimal execution plans. For instance, assigning n to some contiguous axis in B and iterating over the rest, not necessarily the innermost (or next-to-innermost) one as before.
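
To sketch the idea in numpy terms (not tract code): once the op is expressed as an einsum, we are free to choose which axes play m and n and which ones we iterate over:

```python
import numpy as np

A = np.random.rand(77, 6, 512).astype(np.float32)   # axes: s, b, k
B = np.random.rand(512, 1536).astype(np.float32)    # axes: k, n

# The matmul above written as an einsum: contract over k, keep s and b as
# free axes on the A side and n on the B side.
out = np.einsum("sbk,kn->sbn", A, B)                  # [77, 6, 1536]

# Nothing forces m to be a particular axis of A: we can treat the merged
# (s*b) axis as m with n as the contiguous axis of B, or iterate over s
# only, etc. The einsum form leaves that choice to the planner.
assert np.allclose(out, (A.reshape(-1, 512) @ B).reshape(77, 6, 1536), atol=1e-5)
```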

tgolsson commented 1 year ago

This is the same model as in the other issue: text.{onnx,nnef.tar}. That sounds like a good refactor; do you have a timeline for it? I realize my track record for timely PRs isn't great (I'll get around to it someday, I swear), but simplifying this "outer loop" seems easy enough here.

kali commented 1 year ago

Merging axes to get a more favorable n (or m) count is indeed possible. I'm just a bit concerned about the possible interaction with operator fusing. This code is not... super clean, even by tract standards. Basically, linalg will not be able to run fused operators that depend on geometry (like AddRow/AddCol) if we merge axes to get a bigger M (or, respectively, N).
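
To illustrate the concern with a numpy sketch (hypothetical dims, with an AddRow-style per-row addition standing in for the fused op): the fused constant is sized against the original m, so it no longer matches the output geometry once another axis is merged into m, unless we also tile it:

```python
import numpy as np

s, m, k, n = 77, 6, 512, 128                           # hypothetical dims
A = np.random.rand(s, m, k).astype(np.float32)
W = np.random.rand(k, n).astype(np.float32)
row_bias = np.random.rand(m).astype(np.float32)        # one value per row of the m axis

# Unmerged: an AddRow-style fused op lines up with the m axis directly.
unmerged = A @ W + row_bias[None, :, None]              # [s, m, n]

# Merged: m becomes s*m, so the same length-m constant no longer matches
# the output geometry; keeping the fusion would mean tiling it s times.
merged = (A.reshape(s * m, k) @ W + np.tile(row_bias, s)[:, None]).reshape(s, m, n)

assert np.allclose(unmerged, merged, atol=1e-5)
```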

We have several ways to go here... do we merge axes eagerly (during MirUnaryMatMul codegen), later on (LirMatMulUnary declutter?), or sneakily in the eval?

Impairing the fusing capabilities may be worth it, as we typically get a 3x or 4x speedup with a big n versus a vector. But in the quantized case, we need to make sure the requantisation pipeline will still be fusable...