pq-code-package / mlkem-native

High-assurance, high-performance ML-KEM implementation for mobile, pc, and server targets
https://pq-code-package.github.io/mlkem-native/dev/bench/
Apache License 2.0

Enable rej_uniform native implementation in x86_64 #409

Closed: mkannwischer closed this pull request 6 days ago.

mkannwischer commented 1 week ago

This fixes a bug that caused rej_uniform to always use the C implementation instead of the native implementation, even when a native one was available. The bug was introduced when we added the x86_64 backend: the flag enabling the native rej_uniform was renamed from MLKEM_USE_NATIVE_AARCH64 to MLKEM_USE_NATIVE_REJ_UNIFORM, but rej_uniform.c was not updated to use the new name.
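To make the failure mode concrete, here is a minimal sketch of how a stale preprocessor guard silently selects the C fallback. The two flag names come from the description above; the functions `uses_native_before_fix` and `uses_native_after_fix` are hypothetical, and this is not the actual contents of rej_uniform.c. It assumes that the active native backend defines the renamed flag.

```c
/* Hypothetical illustration, not the real rej_uniform.c.
   Assumption: the active native backend defines the renamed flag. */
#define MLKEM_USE_NATIVE_REJ_UNIFORM

/* Before the fix: the guard still tests the old AArch64-specific name,
   which no backend defines any more, so the C fallback is always chosen.
   Returns 1 if the native path is compiled in, 0 for the C fallback. */
static int uses_native_before_fix(void)
{
#if defined(MLKEM_USE_NATIVE_AARCH64)
  return 1;
#else
  return 0;
#endif
}

/* After the fix: the guard tests the renamed flag. */
static int uses_native_after_fix(void)
{
#if defined(MLKEM_USE_NATIVE_REJ_UNIFORM)
  return 1;
#else
  return 0;
#endif
}
```

Because an undefined macro simply makes `#if defined(...)` false, a rename like this compiles without any warning; only benchmarks (or a test that checks which path was taken) reveal it.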

Below are the performance results (cycle counts for keypair, encaps, and decaps) on my 13th Gen Intel i7-1360P (Raptor Lake), using gcc 14.2.1 from the Arch Linux repositories.

Before (6aa6118e):

ML-KEM-512:  22353, 27820, 35663
ML-KEM-768:  39626, 43605, 54916
ML-KEM-1024: 58983, 65402, 80370

After:

ML-KEM-512:  19652, 25165, 32945
ML-KEM-768:  33383, 37343, 48652
ML-KEM-1024: 47658, 54087, 69003
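The before/after numbers can be turned into relative savings with a few lines of arithmetic. The sketch below copies the cycle counts from this comment; the function names are illustrative. It also divides the absolute saving by k*k (k = 2, 3, 4), since keypair, encaps and decaps (via re-encryption) each sample the k×k matrix with rejection sampling. The near-constant per-entry figure is consistent with, but does not by itself prove, the saving coming from matrix sampling.

```c
/* Cycle counts from the comment above (i7-1360P, gcc 14.2.1).
   Rows: ML-KEM-512, -768, -1024; columns: keypair, encaps, decaps. */
static const double before_fix[3][3] = {
    {22353, 27820, 35663}, {39626, 43605, 54916}, {58983, 65402, 80370}};
static const double after_fix[3][3] = {
    {19652, 25165, 32945}, {33383, 37343, 48652}, {47658, 54087, 69003}};

/* Percentage of cycles saved relative to the "before" measurement. */
static double saving_pct(double before, double after)
{
  return 100.0 * (before - after) / before;
}

/* Absolute saving divided by the k*k matrix entries sampled per operation,
   where k = row + 2 (k = 2, 3, 4 for ML-KEM-512/768/1024). */
static double saving_per_matrix_entry(int row, int col)
{
  int k = row + 2;
  return (before_fix[row][col] - after_fix[row][col]) / (double)(k * k);
}
```

For example, `saving_pct(39626, 33383)` is about 15.8% for ML-KEM-768 keypair, and `saving_per_matrix_entry` lands between roughly 660 and 710 cycles in all nine cells.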


mkannwischer commented 1 week ago

Together with #410, this PR achieves this performance on my machine:

ML-KEM-512:  19590, 23861, 30397
ML-KEM-768:  33387, 35734, 45130
ML-KEM-1024: 47644, 51798, 64815

That outperforms the code from the official Kyber repository.

rod-chapman commented 1 week ago

Great result!

mkannwischer commented 1 week ago

This may be due to alignment issues: the Kyber code aligns its buffer, while we do not. I don't have time to investigate now, but I'll look at it later.
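For reference, a C11 sketch of declaring a byte buffer aligned to 32 bytes (the width of one AVX2 register) and checking that alignment. The type name `aligned_buf`, the helper `is_aligned`, and the buffer size of three SHAKE128 blocks are illustrative assumptions; this is not the Kyber or mlkem-native code.

```c
#include <stdalign.h>
#include <stddef.h>
#include <stdint.h>

#define SHAKE128_RATE 168 /* bytes per SHAKE128 block */
#define NBLOCKS 3         /* illustrative buffer size, not taken from the PR */

/* Byte buffer whose first byte is aligned to 32 bytes. */
typedef struct
{
  alignas(32) uint8_t bytes[NBLOCKS * SHAKE128_RATE];
} aligned_buf;

/* Returns 1 if p is a multiple of `alignment` (a power of two), else 0. */
static int is_aligned(const void *p, size_t alignment)
{
  return ((uintptr_t)p & (alignment - 1)) == 0;
}
```

Note that alignment only controls where the buffer starts; it does not add any bytes after the end, so it is a separate question from whether a vector loop may read past the buffer.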

hanno-becker commented 6 days ago

@mkannwischer We can't have undocumented padding assumptions: we either need to complicate our native API here, or fix this in a different way.
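One way to drop the padding assumption is to bound the vector loop so that every wide load stays inside the buffer, and let a scalar loop consume whatever is left. Below is a portable sketch of that idea, not the AVX2 code from this PR. A 32-byte `memcpy` stands in for one 256-bit load of which only 24 bytes (16 twelve-bit candidates) are consumed; that 32/24 split is an assumption modeled on typical AVX2 rejection sampling, and all function names are hypothetical. The scalar loop follows the ML-KEM rule of turning every 3 bytes into two 12-bit candidates and keeping those below q = 3329.

```c
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#define MLKEM_Q 3329
#define WIDE_LOAD_BYTES 32 /* assumed: one 256-bit load */
#define WIDE_USED_BYTES 24 /* consumed per load: 8 x 3 bytes = 16 candidates */

/* Scalar rejection sampling: every 3 bytes yield two 12-bit candidates,
   kept if below q. Never reads beyond buf[buflen - 1]. */
static unsigned rej_uniform_scalar(int16_t *r, unsigned len,
                                   const uint8_t *buf, size_t buflen)
{
  unsigned ctr = 0;
  size_t pos = 0;
  while (ctr < len && pos + 3 <= buflen)
  {
    uint16_t d1 = (uint16_t)((buf[pos] | (buf[pos + 1] << 8)) & 0xFFF);
    uint16_t d2 = (uint16_t)((buf[pos + 1] >> 4) | (buf[pos + 2] << 4));
    pos += 3;
    if (d1 < MLKEM_Q)
      r[ctr++] = (int16_t)d1;
    if (ctr < len && d2 < MLKEM_Q)
      r[ctr++] = (int16_t)d2;
  }
  return ctr;
}

/* Blocked variant: the wide path runs only while a full 32-byte load stays
   inside buf and the output has room for all 16 candidates. The remaining
   bytes go through the scalar loop, so no padding after buf is needed. */
static unsigned rej_uniform_no_overread(int16_t *r, unsigned len,
                                        const uint8_t *buf, size_t buflen)
{
  unsigned ctr = 0;
  size_t pos = 0;
  while (len - ctr >= 16 && pos + WIDE_LOAD_BYTES <= buflen)
  {
    uint8_t lane[WIDE_LOAD_BYTES];
    memcpy(lane, buf + pos, WIDE_LOAD_BYTES); /* stand-in for the wide load */
    ctr += rej_uniform_scalar(r + ctr, 16, lane, WIDE_USED_BYTES);
    pos += WIDE_USED_BYTES;
  }
  ctr += rej_uniform_scalar(r + ctr, len - ctr, buf + pos, buflen - pos);
  return ctr;
}
```

Both functions produce identical output for any input. The bounds check `pos + WIDE_LOAD_BYTES <= buflen` is what removes the padding requirement; the price is that, once it fails, fewer than 32 trailing bytes are handled by the slower scalar loop.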

I changed the AVX2 code so that it no longer overreads, and removed the AVX 128-bit tail. Here are the performance numbers without and with the 128-bit tail:

Measured on an AWS c7i.xlarge instance (cycle counts):

Without 128-bit tail:

ML-KEM-512:

   keypair cycles = 22463
    encaps cycles = 29165
    decaps cycles = 38154

             percentile      1     10     20     30     40     50     60     70     80     90     99
   keypair percentiles:  22186  22274  22331  22366  22403  22463  22549  22805  22987  23263  24488
    encaps percentiles:  28850  28950  29005  29054  29102  29165  29247  29386  29595  30027  31129
    decaps percentiles:  37851  37958  38018  38055  38100  38154  38250  38417  38642  38936  40530

ML-KEM-768:

   keypair cycles = 38740
    encaps cycles = 42251
    decaps cycles = 55424

             percentile      1     10     20     30     40     50     60     70     80     90     99
   keypair percentiles:  38027  38278  38421  38549  38634  38740  38864  39007  39252  39813  41013
    encaps percentiles:  41790  41937  42038  42111  42185  42251  42356  42480  42690  43198  44812
    decaps percentiles:  54762  54974  55090  55212  55294  55424  55554  55718  55926  56517  57886

ML-KEM-1024:

   keypair cycles = 53546
    encaps cycles = 59473
    decaps cycles = 77068

             percentile      1     10     20     30     40     50     60     70     80     90     99
   keypair percentiles:  52816  53152  53248  53352  53453  53546  53623  53746  53919  55186  56879
    encaps percentiles:  58921  59120  59223  59312  59377  59473  59592  59721  59982  61101  62263
    decaps percentiles:  76332  76636  76764  76878  76972  77068  77191  77341  77592  78606  80070

====================

With 128-bit tail:

ML-KEM-512:

   keypair cycles = 22046
    encaps cycles = 28421
    decaps cycles = 37342

             percentile      1     10     20     30     40     50     60     70     80     90     99
   keypair percentiles:  21872  21930  21955  21977  22001  22046  22218  22421  22631  23021  24730
    encaps percentiles:  28252  28331  28360  28383  28400  28421  28446  28491  28661  29045  30517
    decaps percentiles:  37144  37204  37246  37286  37311  37342  37413  37522  37746  38312  39502

ML-KEM-768:

   keypair cycles = 37742
    encaps cycles = 41222
    decaps cycles = 54438

             percentile      1     10     20     30     40     50     60     70     80     90     99
   keypair percentiles:  37065  37232  37378  37494  37615  37742  37902  38186  38579  39220  40190
    encaps percentiles:  41046  41107  41138  41162  41190  41222  41302  41525  41631  42267  43848
    decaps percentiles:  53825  53927  54006  54088  54266  54438  54583  54783  55070  55799  57065

ML-KEM-1024:

   keypair cycles = 52383
    encaps cycles = 57990
    decaps cycles = 75675

             percentile      1     10     20     30     40     50     60     70     80     90     99
   keypair percentiles:  51865  52079  52182  52248  52312  52383  52465  52540  52690  54124  56110
    encaps percentiles:  57702  57809  57851  57892  57939  57990  58073  58174  58368  59857  60602
    decaps percentiles:  74983  75234  75360  75478  75582  75675  75801  75953  76295  77272  78286

We lose a little performance without the 128-bit tail (about 1.8% to 2.6% in median cycles), so I'll have another look at whether there's a simple way to keep it. If not, I think this is still a more robust approach than relying on padding.
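The size of that loss can be checked directly from the median (50th percentile) cycle counts of the two runs above. The sketch below copies those medians; the function name `tail_slowdown_range` is illustrative.

```c
/* Median cycle counts on c7i.xlarge, copied from the two runs above.
   Rows: ML-KEM-512, -768, -1024; columns: keypair, encaps, decaps. */
static const double without_tail[3][3] = {
    {22463, 29165, 38154}, {38740, 42251, 55424}, {53546, 59473, 77068}};
static const double with_tail[3][3] = {
    {22046, 28421, 37342}, {37742, 41222, 54438}, {52383, 57990, 75675}};

/* Smallest and largest relative slowdown (in percent) caused by dropping
   the 128-bit tail, taken over all nine operation/parameter-set cells. */
static void tail_slowdown_range(double *lo, double *hi)
{
  *lo = 1e9;
  *hi = -1e9;
  for (int i = 0; i < 3; i++)
    for (int j = 0; j < 3; j++)
    {
      double pct = 100.0 * (without_tail[i][j] - with_tail[i][j]) / with_tail[i][j];
      if (pct < *lo)
        *lo = pct;
      if (pct > *hi)
        *hi = pct;
    }
}
```

The smallest slowdown is ML-KEM-768 decaps (about 1.81%) and the largest is ML-KEM-768 keypair (about 2.64%).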

hanno-becker commented 6 days ago

With shortened 128-bit tail:

ML-KEM-512:

   keypair cycles = 22046
    encaps cycles = 28440
    decaps cycles = 37288

             percentile      1     10     20     30     40     50     60     70     80     90     99
   keypair percentiles:  21874  21938  21961  21985  22008  22046  22163  22412  22623  23001  24065
    encaps percentiles:  28328  28366  28388  28406  28423  28440  28464  28509  28646  29177  30683
    decaps percentiles:  37116  37169  37199  37235  37261  37288  37331  37413  37576  38241  39548

ML-KEM-768:

   keypair cycles = 37763
    encaps cycles = 41192
    decaps cycles = 54268

             percentile      1     10     20     30     40     50     60     70     80     90     99
   keypair percentiles:  37112  37252  37424  37552  37653  37763  37900  38129  38438  39192  40556
    encaps percentiles:  41027  41069  41096  41126  41153  41192  41277  41470  41594  42163  43596
    decaps percentiles:  53790  53875  53948  54005  54129  54268  54393  54606  54863  55519  56650

ML-KEM-1024:

   keypair cycles = 52031
    encaps cycles = 58043
    decaps cycles = 75697

             percentile      1     10     20     30     40     50     60     70     80     90     99
   keypair percentiles:  51544  51701  51828  51906  51961  52031  52106  52221  52329  53832  55546
    encaps percentiles:  57752  57845  57886  57930  57987  58043  58174  58317  58547  59925  61176
    decaps percentiles:  74945  75227  75359  75477  75544  75697  75860  76065  76412  77446  79046

mkannwischer commented 6 days ago

Nice, thanks @hanno-becker. This is indeed much cleaner; I am happy with this change now.