dajiaji / pyhpke

A Python Implementation of HPKE (Hybrid Public Key Encryption)
MIT License
7 stars 4 forks source link

Implement DeriveKeyPair for KEM. #9

Closed dajiaji closed 1 year ago

Aurvandill commented 1 year ago

I implemented DeriveKeyPair in a somewhat awkward way. Maybe it'll help you :)


def derive_key_pair(ikm: bytes, kem, kdf) -> KEMInterface:
    suite_id = bytearray(b"KEM")
    suite_id.extend(kem.id.value.to_bytes(2,"big"))
    kdf._suite_id = bytes(suite_id)

    match kem.id:
        case KEMId.DHKEM_P256_HKDF_SHA256 | KEMId.DHKEM_P384_HKDF_SHA384:
            sk_raw = _ec_derive_key_pair(ikm, kdf, kem)
        case KEMId.DHKEM_P521_HKDF_SHA512:
            kem._nsecret = 66
            sk_raw = _ec_derive_key_pair(ikm, kdf, kem)
        case KEMId.DHKEM_X25519_HKDF_SHA256:
            sk_raw = _x_derive_key_pair(ikm, kdf, kem)
        case KEMId.DHKEM_X448_HKDF_SHA512:
            kem._nsecret = 56
            sk_raw = _x_derive_key_pair(ikm, kdf, kem)
        case _:
            raise ValueError("could not derive secret key")

    # return the kemkeyinterface of the deserialized private key.
    return kem.deserialize_private_key(sk_raw)

def _x_derive_key_pair(ikm: bytes, kdf: KDF, kem: KEMInterface) -> bytes:
    dkp_prk = kdf.labeled_extract(b"", b"dkp_prk", ikm)
    sk = kdf.labeled_expand(dkp_prk, b"sk", b"", kem._nsecret)
    return sk

def _ec_derive_key_pair(ikm: bytes, kdf: KDF, kem: KEMInterface) -> bytes:
    # see https://www.rfc-editor.org/rfc/rfc9180#section-7.1.3-4

    dkp_prk = kdf.labeled_extract(b"", b"dkp_prk", ikm)
    match kem.id:
        case KEMId.DHKEM_P256_HKDF_SHA256:
            bitmask = 0xFF
            order = 0xFFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551
        case KEMId.DHKEM_P384_HKDF_SHA384:
            bitmask = 0xFF
            order = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFC7634D81F4372DDF581A0DB248B0A77AECEC196ACCC52973
        case KEMId.DHKEM_P521_HKDF_SHA512:
            bitmask = 0x01
            order = 0x01FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFA51868783BF2F966B7FCC0148F709A5D03BB5C9B8899C47AEBB6FB71E91386409
        case _:
            raise ValueError(f"Unknown KEMid {kem.id}")

    sk = 0
    counter = 0
    while sk == 0 or sk >= order:
        if counter > 255:
            raise ValueError("could not derive keypair")
        raw_key = bytearray(
            kdf.labeled_expand(dkp_prk, b"candidate", I2OSP(counter, 1), kem._nsecret)
        )

        raw_key[0] = raw_key[0] & bitmask
        sk = OS2IP(raw_key)
        counter = counter + 1
    return I2OSP(sk, kem._nsecret)

def I2OSP(n: int, w: int) -> bytes:
    """Encode the non-negative integer ``n`` as a ``w``-byte big-endian string.

    This is the I2OSP primitive (RFC 8017, section 4.1) used throughout
    RFC 9180.

    Raises:
        ValueError: if ``n`` is negative.
        OverflowError: if ``n`` does not fit in ``w`` bytes (raised by
            ``int.to_bytes``).
    """
    # Zero is a valid input, so the requirement is "non-negative".
    if n < 0:
        raise ValueError("number must be non-negative")
    return n.to_bytes(w, "big")

def OS2IP(x: bytes) -> int:
    """Decode a big-endian byte string into a non-negative integer.

    This is the OS2IP primitive (RFC 8017, section 4.2). It also accepts
    ``bytearray`` input.
    """
    # Pass byteorder explicitly: the "big" default for int.from_bytes only
    # exists on Python 3.11+, while this code otherwise needs only 3.10.
    return int.from_bytes(x, "big")

With best regards,
Aurvandill

dajiaji commented 1 year ago

@Aurvandill Thanks for the suggestion!

Sorry, I couldn't make time to implement it.

Could you send a PR for the issue?

I can review and merge it.

Aurvandill commented 1 year ago

Will do so tomorrow :)

dajiaji commented 1 year ago

Resolved by #133