ingonyama-zk / icicle

a GPU Library for Zero-Knowledge Acceleration
MIT License
296 stars 85 forks source link

[elliptic curve membership]: results not matching Sage / sppark #538

Open maciejskorski opened 4 weeks ago

maciejskorski commented 4 weeks ago

Description

The field arithmetic for BLS12-381 does not seem right.

I tested the example below in Sage and with the sppark GPU backend before reporting it here. It checks membership of a valid EC point (found with Sage).

Reproduce

Compile and run the example below. I use this invocation:

nvcc -arch=sm_80 -gencode arch=compute_70,code=sm_70 -t0 -std=c++17 -DCURVE_ID=BLS12_381 -D__ADX__ -D__NVCC__ -I../../deps/icicle/icicle  -I./   curve_test.cu -o curve_test
// curve_test.cu

#include "curves/curve_config.cuh"
#include <iostream>
#include <fstream>
#include <nvml.h>
#include <unistd.h>
#include <cassert>
#include <cuda.h>

using namespace std;
using namespace curve_config;

// Field used for the point coordinates in this reproducer.
// NOTE(review): scalar_t is the BLS12-381 *scalar* field, but curve point
// coordinates live in the base field; this alias is why the membership
// assertion in main() fails (the corrected version later in this thread
// uses Field<fq_config> instead).
using fp_t = scalar_t;

__host__ __device__ fp_t fp_from_binary(const char k[], int N) {
    // Converts a string of N '0'/'1' characters, most significant bit first,
    // into a field element: the character at index pos contributes 2^(N-1-pos).
    // Any character other than '1' is treated as a zero bit.
    fp_t weight = fp_t::one();        // 2^j for the bit currently being read
    fp_t acc = weight - weight;       // additive identity, built as one - one
    // Walk from the last character (least significant bit) towards the first.
    for (int pos = N - 1; pos >= 0; --pos) {
        if (k[pos] == '1') {
            acc = weight + acc;
        }
        weight = weight + weight;
    }
    return acc;
}

int main() {
    // Host-only checks: sanity-test arithmetic on small constants, then verify
    // that the point (x, y), found in Sage, satisfies y^2 == x^3 + 4.
    // NOTE(review): with fp_t = scalar_t the final assertion fails, because the
    // coordinates belong to the base field; the follow-up in this thread
    // switches fp_t to Field<fq_config>.
    printf("testing the curve on the host...\n");

    const int kBits = 384;  // width of the bit strings below
    const char* x_bits = "000010110100111000100010101000111100100100111001000101010101000010001011001110011101111101011100011000110100010110010110010001011010100111110001001100000100100010000000110110101001000111100001000011001110000000100111100101011011011000101011110101010011001101010011110111100000000100000111100101110111010110011010011000000101000101110000100111101001100001100101001100010000101000000100";
    const char* y_bits = "000001110010001011111100111010100001010100010101100000011000101111101111010101010110011001001110001111101110110000010011010101100110110011000010000000011111111000001001011100000001110111010010011100001110000100111010101000100101100010100011000011100111010001011011100110011100000001011000010100011010001010011111011000101110010010111001010011011010001011011000100010001011100010110111";

    fp_t one = fp_t::one();
    fp_t four = one + one + one + one;
    assert( (one + one) * (one + one) == four );
    assert( four == fp_from_binary("0100", 4) );

    fp_t x = fp_from_binary(x_bits, kBits);
    fp_t y = fp_from_binary(y_bits, kBits);
    assert( y * y == x * x * x + four );
}

Expected Behavior

The assertion should pass.

Environment

Please complete the following information:

Linux 94d31facdf27 5.15.0-107-generic #117-Ubuntu SMP Fri Apr 26 12:26:49 UTC 2024 x86_64 x86_64 x86_64 GNU/Linux

Additional context

The same example works in SageCalc:

# BLS12-381 base-field modulus (381 bits) and the corresponding finite field.
p = 0x1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab
F = GF(p)

# Coordinates of the test point as MSB-first bit strings (384 bits each).
xs="000010110100111000100010101000111100100100111001000101010101000010001011001110011101111101011100011000110100010110010110010001011010100111110001001100000100100010000000110110101001000111100001000011001110000000100111100101011011011000101011110101010011001101010011110111100000000100000111100101110111010110011010011000000101000101110000100111101001100001100101001100010000101000000100"
ys="000001110010001011111100111010100001010100010101100000011000101111101111010101010110011001001110001111101110110000010011010101100110110011000010000000011111111000001001011100000001110111010010011100001110000100111010101000100101100010100011000011100111010001011011100110011100000001011000010100011010001010011111011000101110010010111001010011011010001011011000100010001011100010110111"

x=F(Integer(xs,2))
y=F(Integer(ys,2))

# The point must lie on y^2 = x^3 + 4. A bare expression here is not displayed
# in a notebook/SageCell cell unless it is the cell's last statement, so make
# the check explicit.
assert y*y - x*x*x - F(4) == 0, "point is not on y^2 = x^3 + 4"

p_len = 381  # bit length of p

# Serialize x and y as fixed-width ((p_len+7)//8 = 48 bytes) big-endian integers.
x_bytes = Integer(x).to_bytes((p_len+7)//8, byteorder='big')
y_bytes = Integer(y).to_bytes((p_len+7)//8, byteorder='big')
with open("test.dat", "wb") as f:
    f.write(x_bytes)
    f.write(y_bytes)
goforashutosh commented 4 weeks ago

I am just reading the ICICLE code and your code, but I think you are using the scalar field in curve_test.cu (scalar_t is defined in https://github.com/ingonyama-zk/icicle/blob/e19a869691af9379ddd21abccb4b37d03726f602/icicle/include/fields/snark_fields/bls12_381_scalar.cuh#L154C28-L154C36 ) when you should be using the base field. I'm not sure yet, but maybe try `typedef Field<fq_config> fp_t;` instead of `typedef scalar_t fp_t;` (fq_config is defined in https://github.com/ingonyama-zk/icicle/blob/e19a869691af9379ddd21abccb4b37d03726f602/icicle/include/fields/snark_fields/bls12_381_base.cuh#L8 ).

goforashutosh commented 4 weeks ago

Yes, you were using the scalar field instead of the base field. This code works:

// curve_test.cu
// icicle/icicle/include/curves/curve_config.cuh

#include "icicle/icicle/include/curves/curve_config.cuh"

#include <cassert>
#include <cstdio>
#include <fstream>
#include <iostream>

#include <cuda.h>
#include <nvml.h>
#include <unistd.h>

using namespace std;
using namespace curve_config;

// typedef scalar_t fp_t;
typedef Field<fq_config> fp_t;

__host__ __device__ fp_t fp_from_binary(const char k[], int N) {
    // Builds a field element from an MSB-first string of N '0'/'1' characters:
    // k[N-1-j] contributes 2^j. Characters other than '1' count as zero bits.
    fp_t bit_value = fp_t::one();           // 2^j for the current position j
    fp_t result = bit_value - bit_value;    // zero, obtained as one - one
    int j = 0;
    while (j < N) {
        const bool bit_set = (k[N - 1 - j] == '1');
        if (bit_set) {
            result = bit_value + result;
        }
        bit_value = bit_value + bit_value;  // advance to 2^(j+1)
        ++j;
    }
    return result;
}

int main() {
    // Host-side validation of BLS12-381 base-field arithmetic: two sanity checks
    // on small constants, then curve membership y^2 == x^3 + 4 for a point whose
    // coordinates were computed in Sage. Returns 0 on success, 1 on failure.
    //
    // The checks are explicit rather than assert(): assert() is compiled out
    // when NDEBUG is defined, which would make the program print "Success"
    // without verifying anything.
    printf("testing the curve on the host...\n");

    bool ok = true;
    auto check = [&ok](bool passed, const char* what) {
        if (!passed) {
            fprintf(stderr, "check failed: %s\n", what);
            ok = false;
        }
    };

    fp_t const1 = fp_t::one();
    fp_t const4 = const1+const1+const1+const1;
    check((const1+const1)*(const1+const1) == const4, "(1+1)*(1+1) == 4");
    check(const4 == fp_from_binary("0100",4), "4 == fp_from_binary(\"0100\", 4)");
    fp_t x = fp_from_binary("000010110100111000100010101000111100100100111001000101010101000010001011001110011101111101011100011000110100010110010110010001011010100111110001001100000100100010000000110110101001000111100001000011001110000000100111100101011011011000101011110101010011001101010011110111100000000100000111100101110111010110011010011000000101000101110000100111101001100001100101001100010000101000000100",384);
    fp_t y = fp_from_binary("000001110010001011111100111010100001010100010101100000011000101111101111010101010110011001001110001111101110110000010011010101100110110011000010000000011111111000001001011100000001110111010010011100001110000100111010101000100101100010100011000011100111010001011011100110011100000001011000010100011010001010011111011000101110010010111001010011011010001011011000100010001011100010110111",384);
    check(y * y == x * x * x + const4, "y^2 == x^3 + 4");

    if (!ok) {
        return 1;
    }
    printf("Success\n");
    return 0;
}