miekg / pkcs11

pkcs11 wrapper for Go
BSD 3-Clause "New" or "Revised" License
375 stars 133 forks source link

Using AWS CloudHSM to sign ethereum transactions #168

Closed — ThreeAndTwo closed this issue 1 year ago

ThreeAndTwo commented 1 year ago
  1. Create a secp256k1 ECC key pair on AWS CloudHSM:
    genECCKeyPair -i 16 -l ether_test

    (screenshot: genECC)

  2. Code:
package main

import (
    "bytes"
    "crypto/ecdsa"
    "encoding/asn1"
    "fmt"
    "github.com/btcsuite/btcd/btcec/v2"
    "github.com/ethereum/go-ethereum/common/hexutil"
    ethCrypto "github.com/ethereum/go-ethereum/crypto"
    "github.com/miekg/pkcs11"
    "log"
    "math/big"
    "strconv"
)

type (
	// dsaSignature is an ECDSA (R, S) pair whose field order matches the
	// DER form SEQUENCE { r INTEGER, s INTEGER } used by encoding/asn1.
	dsaSignature struct {
		R *big.Int
		S *big.Int
	}

	// ecdsaSignature has exactly the same layout as dsaSignature.
	ecdsaSignature dsaSignature
)

// path is the location of the AWS CloudHSM PKCS#11 shared library.
const path = "/opt/cloudhsm/lib/libcloudhsm_pkcs11.so"

// main is a standalone demo. It logs in to the first slot of the PKCS#11
// library at path and reads the secp256k1 public key. It then signs an
// EIP-191 personal message with the matching private key via CKM_ECDSA and
// prints the Ethereum-style [R || S || V] signature.
func main() {
	// Open PKCS#11 library
	lib := path
	p := pkcs11.New(lib)
	if p == nil {
		panic("Failed to load PKCS#11 library")
	}
	defer p.Destroy()

	// Initialize PKCS#11 library
	err := p.Initialize()
	if err != nil {
		panic(fmt.Sprintf("Failed to initialize PKCS#11 library: %s", err))
	}
	defer p.Finalize()

	// Find the first slot
	slots, err := p.GetSlotList(true)
	if err != nil {
		panic(fmt.Sprintf("Failed to get slots: %s", err))
	}
	if len(slots) == 0 {
		panic("No slots found")
	}
	slot := slots[0]
	fmt.Println("slot: ", slot)

	// Open a session
	session, err := p.OpenSession(slot, pkcs11.CKF_SERIAL_SESSION|pkcs11.CKF_RW_SESSION)
	if err != nil {
		panic(fmt.Sprintf("Failed to open session: %s", err))
	}
	defer p.CloseSession(session)

	// Login. NOTE(review): credentials are hard-coded for the demo.
	err = p.Login(session, pkcs11.CKU_USER, "user:password")
	if err != nil {
		panic(fmt.Sprintf("Failed to login: %s", err))
	}
	defer p.Logout(session)

	// Object handles of the key pair on this particular HSM; adjust per deployment.
	const (
		pubKeyHandle  = 262159
		privKeyHandle = 262158
	)

	// PKCS#11 returns CKA_EC_POINT as a DER OCTET STRING that wraps the
	// uncompressed point 0x04 || X || Y. Strip the DER header before slicing
	// out the coordinates; indexing the raw attribute misreads both X and Y.
	pubKeyAttrs, err := p.GetAttributeValue(session, pubKeyHandle, []*pkcs11.Attribute{pkcs11.NewAttribute(pkcs11.CKA_EC_POINT, nil)})
	if err != nil {
		fmt.Println("Error getting public key bytes:", err)
		return
	}
	var point []byte
	if _, err := asn1.Unmarshal(pubKeyAttrs[0].Value, &point); err != nil {
		fmt.Println("Error decoding CKA_EC_POINT:", err)
		return
	}
	if len(point) != 65 || point[0] != 0x04 {
		fmt.Println("Error: CKA_EC_POINT is not an uncompressed 65-byte point")
		return
	}

	curve := btcec.S256()
	pubkey := &ecdsa.PublicKey{
		Curve: curve,
		X:     new(big.Int).SetBytes(point[1:33]),
		Y:     new(big.Int).SetBytes(point[33:65]),
	}
	if !curve.IsOnCurve(pubkey.X, pubkey.Y) {
		fmt.Println("Error: public key is not on secp256k1")
		return
	}

	addr := ethCrypto.PubkeyToAddress(*pubkey)
	fmt.Println("addr:", addr)

	hsmPubkey := ethCrypto.FromECDSAPub(pubkey)

	// EIP-191 personal-message digest:
	// keccak256("\x19Ethereum Signed Message:\n" + len(msg) + msg).
	message := []byte("Sign me!")
	ethMessage := append([]byte("\x19Ethereum Signed Message:\n"+strconv.Itoa(len(message))), message...)
	hashData := ethCrypto.Keccak256(ethMessage)

	// CKM_ECDSA signs the supplied digest as-is and returns r || s.
	if err := p.SignInit(session, []*pkcs11.Mechanism{pkcs11.NewMechanism(pkcs11.CKM_ECDSA, nil)}, privKeyHandle); err != nil {
		panic(fmt.Sprintf("failed to initialize signing: %s", err))
	}
	hsmSig, err := p.Sign(session, hashData)
	if err != nil {
		panic(fmt.Sprintf("failed to sign: %s", err))
	}
	if len(hsmSig) != 64 {
		panic(fmt.Sprintf("unexpected HSM signature length %d, want 64", len(hsmSig)))
	}
	r := new(big.Int).SetBytes(hsmSig[:32])
	s := new(big.Int).SetBytes(hsmSig[32:])

	// Ethereum accepts only low-S signatures (s <= N/2). (r, N-s) is an
	// equally valid ECDSA signature, so normalize instead of re-signing
	// until the HSM happens to produce a low S.
	n := curve.Params().N
	if s.Cmp(new(big.Int).Rsh(n, 1)) > 0 {
		s.Sub(n, s)
	}

	// r and s must each occupy exactly 32 big-endian bytes. big.Int.Bytes
	// drops leading zeros, which would give a short, invalid signature.
	ethFormatSig := make([]byte, 64)
	r.FillBytes(ethFormatSig[:32])
	s.FillBytes(ethFormatSig[32:])
	fmt.Println(hexutil.Encode(ethFormatSig))

	if !ethCrypto.VerifySignature(hsmPubkey, hashData, ethFormatSig) {
		panic("HSM signature does not verify against the HSM public key")
	}

	// The HSM does not report the recovery id. Find the v in {0, 1} for
	// which ecrecover yields our public key.
	for v := byte(0); v <= 1; v++ {
		// The three-index slice forces a fresh array, so candidates never alias.
		candidate := append(ethFormatSig[:64:64], v)
		recovered, err := ethCrypto.Ecrecover(hashData, candidate)
		if err != nil {
			panic(fmt.Sprintf("ecrecover with v=%d: %s", v, err))
		}
		if bytes.Equal(recovered, hsmPubkey) {
			log.Print(hexutil.Encode(candidate))
			return
		}
	}
	// panic (rather than log.Fatal) so the deferred Logout/CloseSession/Finalize run.
	panic("sign error with 0 | 1 v value")
}
  3. miekg/pkcs11 provides a Sign method, but the signatures it produces do not satisfy Ethereum's signature rules. How can I fix this, or is there another way to achieve it?
ThreeAndTwo commented 1 year ago

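The issue has been resolved; working sample code is below. The main change is decoding the CKA_EC_POINT attribute as a DER OCTET STRING (see parseECDSAPublicKey) instead of slicing the raw attribute bytes at fixed offsets.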

package main

import (
    "bytes"
    "context"
    "crypto/ecdsa"
    "encoding/asn1"
    "errors"
    "fmt"
    "github.com/ethereum/go-ethereum/common"
    "github.com/ethereum/go-ethereum/common/hexutil"
    "github.com/ethereum/go-ethereum/core/types"
    ethCrypto "github.com/ethereum/go-ethereum/crypto"
    "github.com/ethereum/go-ethereum/ethclient"
    "github.com/miekg/pkcs11"
    "github.com/sirupsen/logrus"
    "log"
    "math/big"
)

type (
	// dsaSignature is an ECDSA (R, S) pair whose field order matches the
	// DER form SEQUENCE { r INTEGER, s INTEGER } used by encoding/asn1.
	dsaSignature struct {
		R *big.Int
		S *big.Int
	}

	// ecdsaSignature has exactly the same layout as dsaSignature.
	ecdsaSignature dsaSignature
)

// path is the location of the AWS CloudHSM PKCS#11 shared library.
const path = "/opt/cloudhsm/lib/libcloudhsm_pkcs11.so"

// parseECDSAPublicKey decodes a PKCS#11 CKA_EC_POINT attribute value into a
// secp256k1 public key.
//
// PKCS#11 encodes CKA_EC_POINT as a DER OCTET STRING wrapping the
// uncompressed point 0x04 || X || Y (67 bytes in total for secp256k1). A bare
// 65-byte uncompressed point is also accepted, for tokens that omit the DER
// wrapper. It returns an error if the encoding is malformed, the point is not
// uncompressed, or the point is not on the curve.
func parseECDSAPublicKey(pubKeyBytes []byte) (*ecdsa.PublicKey, error) {
	point := pubKeyBytes

	// A DER-wrapped 65-byte point is always 67 bytes long, so a 65-byte value
	// that starts with 0x04 can only be a bare uncompressed point.
	if !(len(pubKeyBytes) == 65 && pubKeyBytes[0] == 0x04) {
		var raw asn1.RawValue
		rest, err := asn1.Unmarshal(pubKeyBytes, &raw)
		if err != nil {
			return nil, fmt.Errorf("asn1 unmarshal error: %w", err)
		}
		if len(rest) != 0 {
			return nil, fmt.Errorf("%d bytes of trailing data after CKA_EC_POINT", len(rest))
		}
		if raw.Class != asn1.ClassUniversal || raw.Tag != asn1.TagOctetString || raw.IsCompound {
			return nil, errors.New("CKA_EC_POINT is not a DER OCTET STRING")
		}
		point = raw.Bytes
	}

	if len(point) != 65 || point[0] != 0x04 {
		return nil, errors.New("invalid uncompressed public key format")
	}

	curve := ethCrypto.S256() // the secp256k1 curve used by Ethereum
	pubKey := &ecdsa.PublicKey{
		Curve: curve,
		X:     new(big.Int).SetBytes(point[1:33]),
		Y:     new(big.Int).SetBytes(point[33:]),
	}

	if !curve.IsOnCurve(pubKey.X, pubKey.Y) {
		return nil, errors.New("public key is not on the curve")
	}

	return pubKey, nil
}

// HsmSign signs a 32-byte digest with the secp256k1 private key held in the
// HSM. It returns the signature in Ethereum's 65-byte [R || S || V] form,
// with V = 0 or 1.
//
// pbKeyId and pvKeyId are the PKCS#11 object handles of the public and
// private halves of the key pair. The public key is needed to work out V,
// which the HSM does not report.
//
// Errors are returned rather than terminating the process.
func HsmSign(p *pkcs11.Ctx, session pkcs11.SessionHandle, pbKeyId, pvKeyId int, hashData []byte) ([]byte, error) {
	if len(hashData) != 32 {
		return nil, fmt.Errorf("hash must be 32 bytes, got %d", len(hashData))
	}

	attrs, err := p.GetAttributeValue(session, pkcs11.ObjectHandle(pbKeyId), []*pkcs11.Attribute{pkcs11.NewAttribute(pkcs11.CKA_EC_POINT, nil)})
	if err != nil {
		return nil, fmt.Errorf("reading public key (handle %d): %w", pbKeyId, err)
	}
	if len(attrs) == 0 {
		return nil, fmt.Errorf("no CKA_EC_POINT returned for handle %d", pbKeyId)
	}

	pubkey, err := parseECDSAPublicKey(attrs[0].Value)
	if err != nil {
		return nil, fmt.Errorf("parsing public key: %w", err)
	}
	fmt.Println("addr:", ethCrypto.PubkeyToAddress(*pubkey))
	hsmPubkey := ethCrypto.FromECDSAPub(pubkey)

	// CKM_ECDSA signs the digest as-is and returns r || s.
	mech := []*pkcs11.Mechanism{pkcs11.NewMechanism(pkcs11.CKM_ECDSA, nil)}
	if err := p.SignInit(session, mech, pkcs11.ObjectHandle(pvKeyId)); err != nil {
		return nil, fmt.Errorf("sign init (handle %d): %w", pvKeyId, err)
	}
	hsmSig, err := p.Sign(session, hashData)
	if err != nil {
		return nil, fmt.Errorf("sign: %w", err)
	}
	if len(hsmSig) != 64 {
		return nil, fmt.Errorf("unexpected HSM signature length %d, want 64", len(hsmSig))
	}

	r := new(big.Int).SetBytes(hsmSig[:32])
	s := new(big.Int).SetBytes(hsmSig[32:])

	// Ethereum accepts only low-S signatures (s <= N/2). (r, N-s) is an
	// equally valid signature, so normalize here. Previously the code
	// re-signed in an unbounded loop until a low S happened to appear.
	n := ethCrypto.S256().Params().N
	if s.Cmp(new(big.Int).Rsh(n, 1)) > 0 {
		s.Sub(n, s)
	}

	// r and s must each be exactly 32 big-endian bytes. big.Int.Bytes drops
	// leading zeros, which produced short signatures that never verified.
	sig := make([]byte, 65)
	r.FillBytes(sig[:32])
	s.FillBytes(sig[32:64])

	if !ethCrypto.VerifySignature(hsmPubkey, hashData, sig[:64]) {
		return nil, errors.New("HSM signature does not verify against the HSM public key")
	}

	// Pick the recovery id V that makes ecrecover return our public key.
	for _, v := range []byte{0, 1} {
		sig[64] = v
		recovered, err := ethCrypto.Ecrecover(hashData, sig)
		if err != nil {
			return nil, fmt.Errorf("ecrecover with v=%d: %w", v, err)
		}
		if bytes.Equal(recovered, hsmPubkey) {
			log.Print(hexutil.Encode(sig))
			return sig, nil
		}
	}
	return nil, errors.New("sign error with 0 | 1 v value")
}

// SignTx signs tx for chainId using a London signer whose hash is signed by
// the HSM key pair (pbKeyId, pvKeyId) via HsmSign. It returns a copy of tx
// carrying the signature, or HsmSign's error unchanged.
func SignTx(p *pkcs11.Ctx, session pkcs11.SessionHandle, chainId *big.Int, pbKeyId, pvKeyId int, tx *types.Transaction) (*types.Transaction, error) {
	londonSigner := types.NewLondonSigner(chainId)
	txHash := londonSigner.Hash(tx)

	signature, signErr := HsmSign(p, session, pbKeyId, pvKeyId, txHash.Bytes())
	if signErr != nil {
		return nil, signErr
	}

	logrus.Info("tx with signature: ", common.Bytes2Hex(signature))
	return tx.WithSignature(londonSigner, signature)
}

// main logs in to the first HSM slot and builds a legacy value transfer. It
// signs the transfer with the HSM key via SignTx, prints the raw encoding,
// and broadcasts it.
//
// NOTE(review): the HSM login, key handles, and RPC URL (which embeds an
// Infura project key) are hard-coded for the demo; move them to
// configuration before real use.
func main() {
	// Open PKCS#11 library
	lib := path
	p := pkcs11.New(lib)
	if p == nil {
		panic("Failed to load PKCS#11 library")
	}
	defer p.Destroy()

	// Initialize PKCS#11 library
	err := p.Initialize()
	if err != nil {
		panic(fmt.Sprintf("Failed to initialize PKCS#11 library: %s", err))
	}
	defer p.Finalize()

	// Find the first slot
	slots, err := p.GetSlotList(true)
	if err != nil {
		panic(fmt.Sprintf("Failed to get slots: %s", err))
	}
	if len(slots) == 0 {
		panic("No slots found")
	}
	slot := slots[0]
	fmt.Println("slot: ", slot)

	// Open a session
	session, err := p.OpenSession(slot, pkcs11.CKF_SERIAL_SESSION|pkcs11.CKF_RW_SESSION)
	if err != nil {
		panic(fmt.Sprintf("Failed to open session: %s", err))
	}
	defer p.CloseSession(session)

	// Login
	err = p.Login(session, pkcs11.CKU_USER, "user:password")
	if err != nil {
		panic(fmt.Sprintf("Failed to login: %s", err))
	}
	defer p.Logout(session)

	pbKeyId := 262159
	pvKeyId := 262158
	// NOTE(review): presumably the address of the HSM key itself. It is used
	// both to look up the sender nonce and as the recipient; confirm.
	addr := common.HexToAddress("0x4AAFD7998d6e3a6359e3b7FE02b8c24Ea6D1944D")
	ctx := context.Background()

	// wrap tx
	ec, err := ethclient.Dial("https://goerli.infura.io/v3/9aa3d95b3bc440fa88ea12eaa4456161")
	if err != nil {
		log.Fatal(err)
	}
	defer ec.Close()

	nonce, err := ec.PendingNonceAt(ctx, addr)
	if err != nil {
		log.Fatal(err)
	}
	log.Println("nonce: ", nonce)

	// The signer needs the EIP-155 chain ID. The p2p network ID returned by
	// NetworkID is a different value that only coincides on some networks.
	chainID, err := ec.ChainID(ctx)
	if err != nil {
		log.Fatal(err)
	}
	log.Println("chain id:", chainID)

	gasPrice, err := ec.SuggestGasPrice(ctx)
	if err != nil {
		log.Fatal(err)
	}
	// Bid three times the node's suggested gas price.
	gasPrice = gasPrice.Mul(gasPrice, big.NewInt(3))

	tx := types.NewTransaction(nonce, addr, big.NewInt(2000000000), 21000, gasPrice, nil)
	signedTx, err := SignTx(p, session, chainID, pbKeyId, pvKeyId, tx)
	if err != nil {
		log.Fatal(err)
	}

	// Encode before broadcasting so the raw transaction is shown even if
	// sending fails. (Previously the encoded string was computed and discarded.)
	binary, err := signedTx.MarshalBinary()
	if err != nil {
		panic(err)
	}
	fmt.Println("raw tx:", hexutil.Encode(binary))

	err = ec.SendTransaction(ctx, signedTx)
	if err != nil {
		panic(err)
	}
	fmt.Println("tx hash:", signedTx.Hash())
}