robertknight / rten

ONNX neural network inference engine

Apple AMX support #18


robertknight commented 10 months ago

The AMX unit on Apple Silicon devices provides a much faster way to do matrix multiplication than the standard NEON SIMD units.

AFAIK the only supported way to use this coprocessor is via the Accelerate library. However, that limits us to the APIs Accelerate exposes and doesn't provide the same flexibility for fusing other operations into the GEMM (e.g. im2col, activations). If I'm honest, using the hardware directly is also more fun than wrapping a library. Nevertheless, going through Accelerate might be necessary in a project that intends to be delivered via the App Store, since use of undocumented instructions could be considered use of a private API.
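
For reference, calling Accelerate's CBLAS GEMM from Rust is straightforward. A minimal hand-rolled sketch (a bindings crate would normally supply these declarations; the constants are the standard CBLAS enum values):

```rust
// Minimal FFI sketch for Accelerate's single-precision GEMM.
#[link(name = "Accelerate", kind = "framework")]
extern "C" {
    fn cblas_sgemm(
        order: i32, trans_a: i32, trans_b: i32,
        m: i32, n: i32, k: i32,
        alpha: f32, a: *const f32, lda: i32,
        b: *const f32, ldb: i32,
        beta: f32, c: *mut f32, ldc: i32,
    );
}

const CBLAS_ROW_MAJOR: i32 = 101; // CblasRowMajor
const CBLAS_NO_TRANS: i32 = 111; // CblasNoTrans

/// C = A * B for row-major A (m x k), B (k x n), C (m x n).
fn sgemm(m: usize, n: usize, k: usize, a: &[f32], b: &[f32], c: &mut [f32]) {
    assert_eq!(a.len(), m * k);
    assert_eq!(b.len(), k * n);
    assert_eq!(c.len(), m * n);
    unsafe {
        cblas_sgemm(
            CBLAS_ROW_MAJOR, CBLAS_NO_TRANS, CBLAS_NO_TRANS,
            m as i32, n as i32, k as i32,
            1.0, a.as_ptr(), k as i32,
            b.as_ptr(), n as i32,
            0.0, c.as_mut_ptr(), n as i32,
        );
    }
}
```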

See https://github.com/corsix/amx/blob/main/References.md for links to a Rust project where this is already working.

robertknight commented 9 months ago

Baseline results on an M1 (tested on an AWS mac2.metal instance) using the ARM NEON kernel, which is not fully optimized yet.

```
Testing kernel arm-neon
m 512 n 512 k 512 iters 1000. Duration 930.885ms (0.930885ms/iter). GFLOPS 288.36588
m 1024 n 1024 k 1024 iters 125. Duration 795.782ms (6.3662558ms/iter). GFLOPS 337.32285
m 128 n 2048 k 512 iters 1000. Duration 1076.425ms (1.0764251ms/iter). GFLOPS 249.37685
m 2048 n 128 k 512 iters 1000. Duration 928.18ms (0.92818ms/iter). GFLOPS 289.20624
test gemm::tests::bench_gemm ... ok
```
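
The GFLOPS figure is just the standard 2·m·n·k operation count for a GEMM divided by the per-iteration duration. A quick sanity check against the 1024 row above:

```rust
// A GEMM performs 2*m*n*k floating point ops (a multiply and an add
// per term of each dot product), so GFLOPS = 2*m*n*k*iters / secs / 1e9.
fn gflops(m: u64, n: u64, k: u64, iters: u64, secs: f64) -> f64 {
    (2 * m * n * k * iters) as f64 / secs / 1e9
}

fn main() {
    // m = n = k = 1024, 125 iters in 795.782ms => ~337.32 GFLOPS.
    println!("{:.5}", gflops(1024, 1024, 1024, 125, 0.795782));
}
```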

For comparison, here are the peak results I could reach with Accelerate, which should be using AMX:

```
$ cargo install gemm-benchmark --features accelerate
$ gemm-benchmark -d 1024 -t 2
Threads: 2
Iterations per thread: 1000
Matrix shape: 1024 x 1024
GFLOPS: 1152.00
```

Note that the thread count doesn't make a difference here, presumably because the AMX unit is a coprocessor shared between the cores in a cluster rather than a per-core resource.

Meanwhile with BLIS, which AFAIK uses ARM NEON, I get:

```
$ gemm-benchmark -d 1024 -t 8
Threads: 8
Iterations per thread: 1000
Matrix shape: 1024 x 1024
GFLOPS: 477.90
```

With OpenBLAS (using OPENBLAS_NUM_THREADS=1; I tried without this, but the results were worse):

```
$ gemm-benchmark -d 1024 -t 8
Threads: 8
Iterations per thread: 1000
Matrix shape: 1024 x 1024
GFLOPS: 377.73
```

robertknight commented 5 months ago

Apparently the M4 chip supports Arm's standard SME instructions as a replacement for AMX. This is good news: hopefully it means the undocumented AMX instruction set used by the M1–M3 is "frozen", and, assuming we use SME instead when available, I don't need to worry about backwards-incompatible changes to it in future.
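
If we go that route, runtime dispatch could check for SME via sysctl. A sketch, assuming the `hw.optional.arm.FEAT_SME` key (Apple exposes `hw.optional.arm.*` keys for CPU feature detection, but the exact key name should be verified against their documentation):

```rust
use std::ffi::CString;

// Sketch of runtime SME detection on macOS. The sysctl key name
// "hw.optional.arm.FEAT_SME" is an assumption and should be checked
// against Apple's feature-detection documentation.
fn has_sme() -> bool {
    extern "C" {
        fn sysctlbyname(
            name: *const std::os::raw::c_char,
            oldp: *mut std::os::raw::c_void,
            oldlenp: *mut usize,
            newp: *mut std::os::raw::c_void,
            newlen: usize,
        ) -> i32;
    }
    let name = CString::new("hw.optional.arm.FEAT_SME").unwrap();
    let mut value: i32 = 0;
    let mut size = std::mem::size_of::<i32>();
    let ret = unsafe {
        sysctlbyname(
            name.as_ptr(),
            &mut value as *mut i32 as *mut _,
            &mut size,
            std::ptr::null_mut(),
            0,
        )
    };
    // sysctlbyname returns -1 if the key doesn't exist (older HW/OS).
    ret == 0 && value == 1
}

fn main() {
    println!("SME available: {}", has_sme());
}
```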