Closed: msaroufim closed this issue 1 month ago
What does the output code look like on a Raspberry Pi? Just the normal cpp codegen?
Yeah, just regular cpp codegen. For context, as far as optimal performance on ARM goes, there are a few parallel efforts:
torchao/experimental/
So I'm still just poking around here, trying to get basic hello worlds for each. My suspicion about what might be missing from the cpp codegen path is calls into highly optimized ARM matmuls.
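Here's the full generated module. For reference, a minimal sketch of the kind of repro that produces output like this (the module below is hypothetical, an int8 weight-only linear matching the shapes in the dump, not the exact test file):

```python
# Hypothetical repro sketch: int8 weight-only quantized linear matching the
# (2000, 1000) int8 weight and (1, 1000) fp32 activation in the dump below.
import torch

class Int8WeightLinear(torch.nn.Module):
    def __init__(self, in_features=1000, out_features=2000):
        super().__init__()
        self.register_buffer(
            "weight",
            torch.randint(-128, 128, (out_features, in_features), dtype=torch.int8),
        )
        self.register_buffer("scale", torch.randn(out_features))
        self.register_buffer("bias", torch.randn(out_features))

    def forward(self, x):
        # dequantize-on-the-fly matmul, then per-output-channel scale + bias
        return (x @ self.weight.to(torch.float32).t()) * self.scale + self.bias

model = torch.compile(Int8WeightLinear())
with torch.no_grad():
    model(torch.randn(1, 1000))
```

Running with `TORCH_LOGS="output_code"` (or `TORCH_COMPILE_DEBUG=1`) prints the generated wrapper: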
```python
# AOT ID: ['0_forward']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()

cpp_fused_add_mm_mul_0 = async_compile.cpp_pybinding(['const float*', 'const int8_t*', 'const float*', 'const float*', 'float*', 'float*'], '''
#include "/tmp/torchinductor_marksaroufim/bj/cbjw7ceat7dc2vi3uyeg6jzl5efohgngn7rvf7ckrkhewlgxvsvc.h"
extern "C" void kernel(const float* in_ptr0,
                       const int8_t* in_ptr1,
                       const float* in_ptr2,
                       const float* in_ptr3,
                       float* out_ptr0,
                       float* out_ptr1)
{
    #pragma omp parallel num_threads(4)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for
            for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(2000L); x0+=static_cast<int64_t>(1L))
            {
                {
                    float tmp_acc0 = 0;
                    at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
                    for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(1000L); x1+=static_cast<int64_t>(8L))
                    {
                        auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x1), static_cast<int64_t>(8));
                        auto tmp1 = at::vec::Vectorized<int8_t>::loadu(in_ptr1 + static_cast<int64_t>(x1 + (1000L*x0)), static_cast<int64_t>(8));
                        auto tmp2 = at::vec::convert<float>(tmp1);
                        auto tmp3 = tmp0 * tmp2;
                        tmp_acc0_vec = tmp_acc0_vec + tmp3;
                    }
                    tmp_acc0 = tmp_acc0 + at::vec::vec_reduce_all<float, 1>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>& y) { return x + y; }, tmp_acc0_vec);
                    out_ptr0[static_cast<int64_t>(x0)] = static_cast<float>(tmp_acc0);
                }
            }
        }
        #pragma omp single
        {
            {
                for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(2000L); x0+=static_cast<int64_t>(8L))
                {
                    auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(8));
                    auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<int64_t>(x0), static_cast<int64_t>(8));
                    auto tmp3 = at::vec::Vectorized<float>::loadu(in_ptr3 + static_cast<int64_t>(x0), static_cast<int64_t>(8));
                    auto tmp2 = tmp0 * tmp1;
                    auto tmp4 = tmp2 + tmp3;
                    tmp4.store(out_ptr1 + static_cast<int64_t>(x0));
                }
            }
        }
    }
}
''')

async_compile.wait(globals())
del async_compile

def call(args):
    primals_1, primals_2, primals_3, primals_4, primals_5 = args
    args.clear()
    assert_size_stride(primals_1, (2000, 1000), (1000, 1))
    assert_size_stride(primals_2, (2000, ), (1, ))
    assert_size_stride(primals_3, (2000, ), (1, ))
    assert_size_stride(primals_4, (2000, ), (1, ))
    assert_size_stride(primals_5, (1, 1000), (1000, 1))
    buf0 = empty_strided_cpu((1, 2000), (2000, 1), torch.float32)
    buf1 = empty_strided_cpu((1, 2000), (2000, 1), torch.float32)
    cpp_fused_add_mm_mul_0(primals_5, primals_1, primals_2, primals_4, buf0, buf1)
    del buf0
    del primals_1
    del primals_2
    del primals_4
    del primals_5
    return (buf1, )

def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    primals_1 = rand_strided((2000, 1000), (1000, 1), device='cpu', dtype=torch.int8)
    primals_2 = rand_strided((2000, ), (1, ), device='cpu', dtype=torch.float32)
    primals_3 = rand_strided((2000, ), (1, ), device='cpu', dtype=torch.int64)
    primals_4 = rand_strided((2000, ), (1, ), device='cpu', dtype=torch.float32)
    primals_5 = rand_strided((1, 1000), (1000, 1), device='cpu', dtype=torch.float32)
    fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5])
    return print_performance(fn, times=times, repeat=repeat)

if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
```
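Reading the output: all the work lands in the single fused kernel `cpp_fused_add_mm_mul_0`, which is a plain OpenMP + `at::vec` reduction loop. Note that `extern_kernels` is imported but never called, so the matvec never dispatches into an optimized GEMM, which supports the suspicion above. Functionally, per the wiring in `call()`, the kernel is equivalent to this sketch:

```python
# Reference semantics of cpp_fused_add_mm_mul_0 (a sketch, not generated code).
# Wiring from call(): in_ptr0=primals_5 (activation), in_ptr1=primals_1
# (int8 weight), in_ptr2=primals_2 (scale), in_ptr3=primals_4 (bias).
import torch

def fused_add_mm_mul_ref(primals_5, primals_1, primals_2, primals_4):
    buf0 = primals_5 @ primals_1.to(torch.float32).t()  # reduction loop -> out_ptr0
    buf1 = buf0 * primals_2 + primals_4                 # epilogue loop  -> out_ptr1
    return buf1
```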
Problem
We don't publish aarch64 Linux binaries, so right now installing on a Pi still ends up with ao==0.1. But torchao actually works on a Raspberry Pi.
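A quick sanity check of what pip actually resolved to (a sketch; assumes `torchao.__version__` is populated):

```python
# On 64-bit Raspberry Pi OS, confirm the platform and the installed torchao
# version -- with no aarch64 wheels published, this lands on the old release.
import platform
import torchao

print(platform.machine())   # 'aarch64'
print(torchao.__version__)  # e.g. '0.1' unless built from source
```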
Environment details
Test file I ran
Logs
Test file
torch.compile works
bf16 works
Full test suite