pytorch / ao

PyTorch native quantization and sparsity for training and inference
BSD 3-Clause "New" or "Revised" License
1.63k stars 181 forks source link

torchao already works on Raspberry Pi #1076

Closed msaroufim closed 1 month ago

msaroufim commented 1 month ago

Problem

We don't publish aarch64 Linux binaries, so right now `pip install torchao` on aarch64 still installs torchao 0.1:

(myvenv) marksaroufim@rpi5:~/Dev/ao $ pip install torchao
Looking in indexes: https://pypi.org/simple, https://www.piwheels.org/simple
Collecting torchao
  Downloading torchao-0.1-py3-none-any.whl (54 kB)

But torchao actually works on a Raspberry Pi.

Environment details

(myvenv) marksaroufim@rpi5:~/Dev/ao $ uname -a
Linux rpi5 6.6.51+rpt-rpi-2712 #1 SMP PREEMPT Debian 1:6.6.51-1+rpt3 (2024-10-08) aarch64 GNU/Linux

Test file I ran

Logs

(myvenv) marksaroufim@rpi5:~/Downloads $ python test.py 
/home/marksaroufim/Dev/myvenv/lib/python3.11/site-packages/torch/_subclasses/functional_tensor.py:294: UserWarning: Failed to initialize NumPy: No module named 'numpy' (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:82.)
  cpu = _conversion_method_template(device=torch.device("cpu"))
Model size on disk: 7821.85 KB
Model size on disk: 1988.05 KB

Test file

import torch
import torch.nn as nn
import os

class ToyModel(nn.Module):
    """Minimal test model: a single ``nn.Linear`` layer.

    Args:
        input_size: Number of input features of ``layer1``.
        hidden_size: Number of output features of ``layer1``.
        output_size: Currently unused; accepted so existing call sites that
            pass it keep working.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int):
        super().__init__()
        self.layer1 = nn.Linear(input_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``layer1`` to ``x`` and return the result."""
        return self.layer1(x)

# Create an instance of the model
input_size = 1000
hidden_size = 2000
output_size = 5
model = ToyModel(input_size, hidden_size, output_size)

# Save the float32 state dict as the size baseline
torch.save(model.state_dict(), "toy_model.pth")

# Get the size of the saved model file
file_size = os.path.getsize("toy_model.pth")

print(f"Model size on disk: {file_size / 1024:.2f} KB")

import torchao
from torchao.quantization.quant_api import (quantize_, int8_weight_only, int4_weight_only)

# Quantize the Linear weights to int8 (the model is modified in place, it is not
# reassigned), then save the quantized state dict.
# state_dict must be *called*: passing the bound method `model.state_dict` makes
# torch.save pickle the method object, and through it the whole module, instead of
# the tensor dict. The two printed sizes would then not be comparable.
quantize_(model, int8_weight_only())
torch.save(model.state_dict(), "quantized_toy_model.pth")

file_size = os.path.getsize("quantized_toy_model.pth")
print(f"Model size on disk: {file_size / 1024:.2f} KB")

torch.compile works

import torch
import torch.nn as nn
import os

class ToyModel(nn.Module):
    """Minimal test model: a single ``nn.Linear`` layer.

    Args:
        input_size: Number of input features of ``layer1``.
        hidden_size: Number of output features of ``layer1``.
        output_size: Currently unused; accepted so existing call sites that
            pass it keep working.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int):
        super().__init__()
        self.layer1 = nn.Linear(input_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``layer1`` to ``x`` and return the result."""
        return self.layer1(x)

# Create an instance of the model
input_size = 1000
hidden_size = 2000
output_size = 5
model = ToyModel(input_size, hidden_size, output_size)

# Run one eager forward pass on a (1, input_size) batch
sample_input = torch.randn(1, input_size)
model(sample_input)

# Save the float32 state dict as the size baseline
torch.save(model.state_dict(), "toy_model.pth")

# Get the size of the saved model file
file_size = os.path.getsize("toy_model.pth")

print(f"Model size on disk: {file_size / 1024:.2f} KB")

import torchao
from torchao.quantization.quant_api import (quantize_, int8_weight_only, int4_weight_only)

# Quantize to int8 weight-only, wrap the model with torch.compile, and run it once.
# torch.compile is lazy, so this call is what triggers compilation.
quantize_(model, int8_weight_only())
model = torch.compile(model)
model(sample_input)
# state_dict must be *called*: passing the bound method would pickle the module
# object itself rather than the tensor dict.
# NOTE(review): `model` is now the torch.compile wrapper, so state-dict keys may
# carry a wrapper prefix; confirm before loading this file into an uncompiled model.
torch.save(model.state_dict(), "quantized_toy_model.pth")

file_size = os.path.getsize("quantized_toy_model.pth")
print(f"Model size on disk: {file_size / 1024:.2f} KB")

bf16 works

import torch

# Draw two random 1000-element bfloat16 vectors (tensor1 is drawn first, then tensor2)
tensor1, tensor2 = (torch.randn(1000, dtype=torch.bfloat16) for _ in range(2))

# bf16 dot product, via the Tensor.dot method form
result = tensor1.dot(tensor2)

# Report the scalar result
print("Dot product result:", result)

Full test suite

(myvenv) marksaroufim@rpi5:~/Dev/ao $ rm test/prototype/test_spinquant.py 
(myvenv) marksaroufim@rpi5:~/Dev/ao $ pytest test
============================= test session starts ==============================
platform linux -- Python 3.11.2, pytest-7.4.0, pluggy-1.5.0
rootdir: /home/marksaroufim/Dev/ao
plugins: hypothesis-6.115.2
collected 1921 items / 10 skipped                                              

test/test_ao_models.py ....                                              [  0%]
test/test_ops.py sssssssssssssssssssssssssssssssssssssssssssssssssssssss [  3%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [  6%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 10%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 14%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 18%]
sssssssssssssssss                                                        [ 18%]
test/test_utils.py ..                                                    [ 19%]
test/dtypes/test_affine_quantized.py sssssssssssssss                     [ 19%]
test/dtypes/test_affine_quantized_float.py ssssssssssssssssssssssss..sss [ 21%]
s.                                                                       [ 21%]
test/dtypes/test_affine_quantized_tensor_parallel.py sssssssss           [ 21%]
test/dtypes/test_bitnet.py .........sss                                  [ 22%]
test/dtypes/test_bitpacking.py .....................ssssssssssssssssssss [ 24%]
sssssssssssssssssssssss.                                                 [ 25%]
test/dtypes/test_floatx.py ssss.....ss..                                 [ 26%]
test/dtypes/test_nf4.py ...ssssssssssssssssssssssssssssssss...sss...ssss [ 29%]
ss.s.....s.....sss...............ss.......s                              [ 31%]
test/dtypes/test_uint2.py ........                                       [ 31%]
test/dtypes/test_uint4.py sss                                            [ 31%]
test/dtypes/test_uintx.py ssssssssssssssssssssssssssssssssssssssssssssss [ 34%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 38%]
ssssssss                                                                 [ 38%]
test/float8/test_base.py ............sssssssssssssssssssssssssssssssssss [ 40%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 44%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 48%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss.sssssss [ 52%]
sssssssssss......                                                        [ 53%]
test/float8/test_compile.py ssssssssssssssssssssssssssssssssssssssssssss [ 55%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 59%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssss                 [ 61%]
test/float8/test_numerics_integration.py sssssssssssssssssssssssssssss   [ 63%]
test/hqq/test_hqq_affine.py sssssss                                      [ 63%]
test/integration/test_integration.py ....s...sss.s.s.sssssssssssssssssss [ 65%]
ssssssssssssssssssss.sssss.sss...sss...sssssssssss.sssss.sssssssssssssss [ 69%]
ssssssssssss...sss...sss..ssssssssssssssssssss.sss...ssss.s.ssssssssssss [ 73%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 76%]
ssssssssssssssssss.s                                                     [ 77%]
test/kernel/test_autotuner.py sssss.s.                                   [ 78%]
test/profiler/test_performance_counter.py sssssssss..ssssssss            [ 79%]
test/prototype/test_awq.py sssss                                         [ 79%]
test/prototype/test_bitpacking_gen.py ......                             [ 79%]
test/prototype/test_low_bit_optim.py ....ssssssss............s           [ 81%]
test/prototype/test_quantized_training.py ssssssssss...............ss    [ 82%]
test/prototype/test_splitk.py ss                                         [ 82%]
test/prototype/mx_formats/test_custom_cast.py ss...............ss...s.s. [ 84%]
s.s.s                                                                    [ 84%]
test/prototype/mx_formats/test_mx_linear.py ssssssssssssssssssssssssssss [ 85%]
ssssssssssssssssssssssssssssssssssssssssssssssss.                        [ 88%]
test/prototype/mx_formats/test_mx_tensor.py ssssssssssssssssssssssssssss [ 89%]
ssssssssssssssssssssss.....ssssssssssssssssssss                          [ 92%]
test/quantization/test_mixed_precision.py .                              [ 92%]
test/quantization/test_observer.py ......                                [ 92%]
test/quantization/test_qat.py ....sss........                            [ 93%]
test/quantization/test_quant_api.py s.s.sssssssss..sssssssss             [ 94%]
test/quantization/test_quant_primitives.py .........s.............       [ 95%]
test/sparsity/test_fast_sparse_training.py ss                            [ 95%]
test/sparsity/test_marlin.py sss                                         [ 96%]
test/sparsity/test_parametrization.py ....                               [ 96%]
test/sparsity/test_scheduler.py ......                                   [ 96%]
test/sparsity/test_sparse_api.py sssssssss                               [ 97%]
test/sparsity/test_sparsifier.py ....................                    [ 98%]
test/sparsity/test_sparsity_utils.py ........                            [ 98%]
test/sparsity/test_structured_sparsifier.py ......................       [ 99%]
test/sparsity/test_wanda.py .....                                        [100%]

=============================== warnings summary ===============================
test/dtypes/test_nf4.py::TestNF4Linear::test_to_copy_bfloat16
test/dtypes/test_nf4.py::TestNF4Linear::test_to_copy_float16
test/dtypes/test_nf4.py::TestNF4Linear::test_to_copy_float32
  /home/marksaroufim/Dev/ao/test/dtypes/test_nf4.py:205: FutureWarning: `torch.testing.assert_allclose()` is deprecated since 1.12 and will be removed in a future release. Please use `torch.testing.assert_close()` instead. You can find detailed upgrade instructions in https://github.com/pytorch/pytorch/issues/61844.
    torch.testing.assert_allclose(input_tensor, nf4_to_dtype, atol=0.13, rtol=0.13)

test/integration/test_integration.py::TestSaveLoadMeta::test_save_load_int4woqtensors_2_cpu
test/integration/test_integration.py::TestSaveLoadMeta::test_save_load_int8woqtensors_0_cpu
test/integration/test_integration.py::TestSaveLoadMeta::test_save_load_int8woqtensors_1_cpu
test/integration/test_integration.py::TestSaveLoadMeta::test_save_load_int8woqtensors_2_cpu
  /home/marksaroufim/Dev/ao/test/integration/test_integration.py:1038: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
    state_dict = torch.load("test.pth", mmap=True)

test/kernel/test_autotuner.py::TestQuantFlow::test_int_scaled_mm_1_cpu
test/kernel/test_autotuner.py::TestQuantFlow::test_int_scaled_mm_3_cpu
  /home/marksaroufim/Dev/ao/test/kernel/test_autotuner.py:97: FutureWarning: `torch.testing.assert_allclose()` is deprecated since 1.12 and will be removed in a future release. Please use `torch.testing.assert_close()` instead. You can find detailed upgrade instructions in https://github.com/pytorch/pytorch/issues/61844.
    torch.testing.assert_allclose(out32_1, out32_2)

test/profiler/test_performance_counter.py::test_performance_stats[no_device]
  /home/marksaroufim/Dev/myvenv/lib/python3.11/site-packages/torchao/profiler/performance_counter.py:381: UserWarning: Device bandwidth is not specified. Please specify the device bandwidth to enable bandwidth utilization calculation
    warnings.warn(

test/profiler/test_performance_counter.py::test_performance_stats[no_device]
  /home/marksaroufim/Dev/myvenv/lib/python3.11/site-packages/torchao/profiler/performance_counter.py:361: UserWarning: Device bandwidth is not specified. Please specify the device bandwidth to enable io latency calculation
    warnings.warn(

test/profiler/test_performance_counter.py::test_performance_stats[no_device]
  /home/marksaroufim/Dev/myvenv/lib/python3.11/site-packages/torchao/profiler/performance_counter.py:391: UserWarning: Device flops_per_s is not specified. Please specify the device throughput to enable flops utilization calculation
    warnings.warn(

test/profiler/test_performance_counter.py::test_performance_stats[no_device]
  /home/marksaroufim/Dev/myvenv/lib/python3.11/site-packages/torchao/profiler/performance_counter.py:371: UserWarning: Device flops_per_s is not specified. Please specify the device throughput to enable compute latency calculation
    warnings.warn(

test/prototype/test_low_bit_optim.py: 12 warnings
  /home/marksaroufim/Dev/ao/test/prototype/test_low_bit_optim.py:103: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
    state_dict = torch.load(f.name, map_location="cpu")

test/prototype/test_quantized_training.py::TestQuantizedTraining::test_int8_weight_only_compile_leading_dims0_bias_False_device_cpu
test/prototype/test_quantized_training.py::TestQuantizedTraining::test_int8_weight_only_compile_leading_dims0_bias_True_device_cpu
test/prototype/test_quantized_training.py::TestQuantizedTraining::test_int8_weight_only_compile_leading_dims1_bias_False_device_cpu
test/prototype/test_quantized_training.py::TestQuantizedTraining::test_int8_weight_only_compile_leading_dims1_bias_True_device_cpu
test/prototype/test_quantized_training.py::TestQuantizedTraining::test_int8_weight_only_compile_leading_dims2_bias_False_device_cpu
test/prototype/test_quantized_training.py::TestQuantizedTraining::test_int8_weight_only_compile_leading_dims2_bias_True_device_cpu
test/prototype/test_quantized_training.py::TestQuantizedTraining::test_int8_weight_only_training_compile_True_device_cpu
test/prototype/test_quantized_training.py::TestQuantizedTraining::test_int8_weight_only_training_compile_True_device_cpu
  /home/marksaroufim/Dev/myvenv/lib/python3.11/site-packages/torch/_dynamo/variables/misc.py:651: UserWarning: The config.capture_autograd_function flag is deprecated, it's now always true.
    warnings.warn(

test/sparsity/test_parametrization.py::TestFakeSparsity::test_jit_trace
  /home/marksaroufim/Dev/myvenv/lib/python3.11/site-packages/torchao/sparsity/prototype/sparsifier/utils.py:129: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
    assert self.mask.shape == x.shape

test/sparsity/test_scheduler.py::TestScheduler::test_lambda_scheduler
test/sparsity/test_scheduler.py::TestCubicScheduler::test_step
  /home/marksaroufim/Dev/myvenv/lib/python3.11/site-packages/torchao/sparsity/prototype/scheduler/base_scheduler.py:122: UserWarning: Detected call of `scheduler.step()` before `sparsifier.step()`. You have to make sure you run the sparsifier.step() BEFORE any calls to the scheduler.step().
    warnings.warn("Detected call of `scheduler.step()` before `sparsifier.step()`. "

test/sparsity/test_wanda.py::TestWandaSparsifier::test_one_layer_mlp_2x4
  /home/marksaroufim/Dev/myvenv/lib/python3.11/site-packages/torchao/sparsity/wanda.py:42: UserWarning: WandaSparsifier got semi_structured_bock_size=4, sparsity_level fixed to 50% (2:4) sparsity
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
========== 330 passed, 1601 skipped, 37 warnings in 436.75s (0:07:16) ==========
(myvenv) marksaroufim@rpi5:~/Dev/ao $ 
drisspg commented 1 month ago

What does the output code look like on a Raspberry Pi? Just the normal C++ codegen?

msaroufim commented 1 month ago

Yeah, just the regular C++ codegen. For context, as far as optimal performance on ARM goes, there are a few parallel efforts:

  1. The inductor cpp codegen path
  2. A triton ARM backend
  3. Custom triton kernels like the ones in torchao/experimental/

So I'm still just poking around here and trying to get a basic hello world working for each. My suspicion is that what might be missing from the C++ codegen path is calls into highly optimized ARM matmul kernels.

# AOT ID: ['0_forward']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall

# Short aliases for op namespaces and inductor/dynamo runtime helpers that the
# generated wrapper code may reference. Several are emitted unconditionally and are
# not used anywhere in this CPU-only module; examples are the CUDA/XPU allocators,
# reinterpret_tensor and alloc_from_pool.
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
# Compiler handle used to build the C++ kernel below; deleted once compilation is done.
async_compile = AsyncCompile()

# Inductor-generated C++ kernel for the int8 weight-only linear. The multiply,
# accumulate and scale/offset steps are fused into one function. Reading the C++
# source below:
#   * The parallel `omp for` loop computes a float matrix-vector product against an
#     int8 matrix that is converted to float on the fly. For each of the 2000
#     outputs x0:
#       out_ptr0[x0] = sum over x1 < 1000 of in_ptr0[x1] * float(in_ptr1[1000*x0 + x1])
#   * The `omp single` loop then computes, elementwise over the 2000 outputs:
#       out_ptr1 = out_ptr0 * in_ptr2 + in_ptr3
# Both loops step 8 elements at a time with at::vec::Vectorized. Both trip counts
# (1000 and 2000) are multiples of 8, and there is no scalar tail loop.
# The reduction is a plain vectorized multiply-accumulate. There is no call into an
# optimized BLAS-style matmul library here.
# `tid` is computed but never used.
# NOTE(review): `num_threads(4)` is fixed at codegen time. It is presumably the core
# count of the generating machine (Raspberry Pi 5); confirm before reusing elsewhere.
# NOTE(review): the #include points into a per-user /tmp inductor cache directory, so
# this module presumably only compiles where that cached header exists.
cpp_fused_add_mm_mul_0 = async_compile.cpp_pybinding(['const float*', 'const int8_t*', 'const float*', 'const float*', 'float*', 'float*'], '''
#include "/tmp/torchinductor_marksaroufim/bj/cbjw7ceat7dc2vi3uyeg6jzl5efohgngn7rvf7ckrkhewlgxvsvc.h"
extern "C"  void kernel(const float* in_ptr0,
                       const int8_t* in_ptr1,
                       const float* in_ptr2,
                       const float* in_ptr3,
                       float* out_ptr0,
                       float* out_ptr1)
{
    #pragma omp parallel num_threads(4)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for
            for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(2000L); x0+=static_cast<int64_t>(1L))
            {
                {
                    float tmp_acc0 = 0;
                    at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
                    for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(1000L); x1+=static_cast<int64_t>(8L))
                    {
                        auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x1), static_cast<int64_t>(8));
                        auto tmp1 = at::vec::Vectorized<int8_t>::loadu(in_ptr1 + static_cast<int64_t>(x1 + (1000L*x0)), static_cast<int64_t>(8));
                        auto tmp2 = at::vec::convert<float>(tmp1);
                        auto tmp3 = tmp0 * tmp2;
                        tmp_acc0_vec = tmp_acc0_vec + tmp3;
                    }
                    tmp_acc0 = tmp_acc0 + at::vec::vec_reduce_all<float, 1>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>& y) { return x + y; }, tmp_acc0_vec);
                    out_ptr0[static_cast<int64_t>(x0)] = static_cast<float>(tmp_acc0);
                }
            }
        }
        #pragma omp single
        {
            {
                for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(2000L); x0+=static_cast<int64_t>(8L))
                {
                    auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(8));
                    auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<int64_t>(x0), static_cast<int64_t>(8));
                    auto tmp3 = at::vec::Vectorized<float>::loadu(in_ptr3 + static_cast<int64_t>(x0), static_cast<int64_t>(8));
                    auto tmp2 = tmp0 * tmp1;
                    auto tmp4 = tmp2 + tmp3;
                    tmp4.store(out_ptr1 + static_cast<int64_t>(x0));
                }
            }
        }
    }
}
''')

# Wait for every kernel submitted to async_compile above to finish compiling.
# Passing globals() presumably lets it bind the compiled callables into this module.
# The compiler handle is then no longer needed.
async_compile.wait(globals())
del async_compile

def call(args):
    """Run the compiled forward graph on a list of five input tensors.

    ``args`` holds ``primals_1`` .. ``primals_5`` in order. It is cleared in
    place, so callers must pass a fresh list on every call, as
    ``benchmark_compiled_module`` does.

    The shape/stride guards and the kernel call below establish these roles:
      * primals_1: (2000, 1000) contiguous. It is passed as the ``int8_t``
        matrix (int8 in ``benchmark_compiled_module``).
      * primals_2, primals_4: (2000,) vectors passed as ``float`` inputs.
        Presumably they are the per-output-channel scale and the bias; TODO confirm.
      * primals_3: (2000,) vector that is shape-checked but never passed to the
        kernel. It is int64 in ``benchmark_compiled_module``; presumably a zero-point.
      * primals_5: (1, 1000) contiguous activation row.

    Returns:
        A 1-tuple holding a (1, 2000) float32 tensor equal to
        ``(primals_5 @ float(primals_1).T) * primals_2 + primals_4``. The multiply
        and add are elementwise over the 2000 outputs.
    """
    primals_1, primals_2, primals_3, primals_4, primals_5 = args
    # Empty the caller's list so it no longer references the inputs. Together with
    # the `del`s below, this presumably lets input memory be released early.
    args.clear()
    # Guard the exact sizes/strides the kernel was specialized for.
    assert_size_stride(primals_1, (2000, 1000), (1000, 1))
    assert_size_stride(primals_2, (2000, ), (1, ))
    assert_size_stride(primals_3, (2000, ), (1, ))
    assert_size_stride(primals_4, (2000, ), (1, ))
    assert_size_stride(primals_5, (1, 1000), (1000, 1))
    # buf0 receives the raw mat-vec result; buf1 receives the final scaled+offset output.
    buf0 = empty_strided_cpu((1, 2000), (2000, 1), torch.float32)
    buf1 = empty_strided_cpu((1, 2000), (2000, 1), torch.float32)
    cpp_fused_add_mm_mul_0(primals_5, primals_1, primals_2, primals_4, buf0, buf1)
    # Drop the intermediate buffer and the consumed inputs (primals_3 is not deleted here).
    del buf0
    del primals_1
    del primals_2
    del primals_4
    del primals_5
    return (buf1, )

def benchmark_compiled_module(times=10, repeat=10):
    """Time ``call`` on synthetic inputs that satisfy its shape/stride guards.

    Builds ``rand_strided`` CPU tensors with the shapes, strides and dtypes the
    compiled graph expects. It then passes a zero-argument closure to
    ``print_performance`` with the given ``times``/``repeat`` counts and returns
    whatever that returns.
    """
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    primals_1 = rand_strided((2000, 1000), (1000, 1), device='cpu', dtype=torch.int8)
    primals_2 = rand_strided((2000, ), (1, ), device='cpu', dtype=torch.float32)
    primals_3 = rand_strided((2000, ), (1, ), device='cpu', dtype=torch.int64)
    primals_4 = rand_strided((2000, ), (1, ), device='cpu', dtype=torch.float32)
    primals_5 = rand_strided((1, 1000), (1000, 1), device='cpu', dtype=torch.float32)
    # Build a new list on every invocation because `call` clears its argument list.
    fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5])
    return print_performance(fn, times=times, repeat=repeat)

if __name__ == "__main__":
    # When this generated module is run as a script, hand off to inductor's
    # wrapper-benchmark driver with this module's benchmark hook.
    # NOTE(review): the literal string 'None' is what codegen emitted for the first
    # argument; it is presumably a placeholder benchmark name. Confirm.
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)