Closed LeventErkok closed 1 year ago
Here's a basic implementation. This one doesn't do the "endianness" business, but handles 1 and 2 above:
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ParallelListComp #-}
{-# OPTIONS_GHC -Wall -Werror -Wno-incomplete-uni-patterns #-}
module AES(cgAESLibrary
-- Testing
, runTests
-- Test vectors
, aes128Key, aes192Key, aes256Key, aes128InvKey, aes192InvKey, aes192InvKeyExtended, aes256InvKey
, commonPT, aes128CT, aes192CT, aes256CT
-- Helper
, chop4
) where
import Control.Monad
import Data.SBV
import Data.SBV.Tools.CodeGen
import Documentation.SBV.Examples.Crypto.AES hiding (aesLibComponents, cgAESLibrary, aes128IsCorrect, u0, u1, u2, u3, u0Func)
import Test.QuickCheck hiding (classify, verbose)
import Data.List
import Data.Maybe
-- | The kinds of functions in the generated AES library. The kind of a
-- function decides which driver values it is instantiated with.
data Func
  = SBVEnc     -- ^ Block encryption
  | SBVDec     -- ^ Block decryption
  | SBVKey     -- ^ Key schedule
  | SBVInvKey  -- ^ Inverse key schedule
-- | Generate a C library for AES with the given key size in bits. The size must
-- be 128, 192, or 256; anything else is an error. The optional file path is
-- handed to 'compileToCLib' unchanged. Every generated function has its driver
-- seeded with values derived from the known test vectors.
cgAESLibrary :: Int -> Maybe FilePath -> IO ()
cgAESLibrary sz mbd
  | sz `notElem` [128, 192, 256]
  = error $ "cgAESLibrary: Size must be one of 128, 192, or 256, received: " ++ show sz
  | True
  = void $ compileToCLib mbd libName components
  where
    libName = "aes" ++ show sz ++ "Lib"

    components = [(fnm, withDriver fnm code) | (fnm, code) <- aesLibComponents sz]

    -- Overwrite any existing files, and set the driver values that match the
    -- kind of the function before running its code generator.
    withDriver fnm code = do
      cgOverwriteFiles True
      cgSetDriverValues (drivers (classify fnm))
      code

    -- Determine the kind of a function from its name. "InvKeySchedule" has to
    -- be tested before "KeySchedule", since the latter is a suffix of the former.
    classify :: String -> Func
    classify s
      | has "BlockEncrypt"   = SBVEnc
      | has "BlockDecrypt"   = SBVDec
      | has "InvKeySchedule" = SBVInvKey
      | has "KeySchedule"    = SBVKey
      | True                 = error $ "Can't classify the function: " ++ show s
      where has suffix = suffix `isSuffixOf` s

    -- Instantiate with known-test vectors
    drivers :: Func -> [Integer]
    drivers f = map concrete $ case f of
        SBVEnc    -> commonPT ++ flatten (fst (aesKeySchedule key))
        SBVDec    -> ct ++ flatten (aesInvKeySchedule invKey)
        SBVKey    -> swapWords key
        SBVInvKey -> swapWords invKey
      where
        -- The key, cipher-text, and inverse key that go with the chosen size
        (key, ct, invKey) = case sz of
          128 -> (aes128Key, aes128CT, aes128InvKey)
          192 -> (aes192Key, aes192CT, aes192InvKey)
          256 -> (aes256Key, aes256CT, aes256InvKey)
          _   -> error $ "Unexpected size: " ++ show sz

        -- All the test vectors are literals, so 'unliteral' succeeds on them
        concrete = fromIntegral . fromJust . unliteral

        -- Lay out a key schedule as a flat list of words
        flatten (firstKey, middleKeys, lastKey) = firstKey ++ concat middleKeys ++ lastKey

        -- Reverse the words within each group of four
        swapWords = concatMap reverse . chop4
-- | The named code generators making up the AES library for the given key size
-- in bits: key schedule, inverse key schedule, block encryption, and block
-- decryption. The size must be 128, 192, or 256.
aesLibComponents :: Int -> [(String, SBVCodeGen ())]
aesLibComponents sz =
    [ (named "KeySchedule",    keySchedule)
    , (named "InvKeySchedule", invKeySchedule)
    , (named "BlockEncrypt",   enc)
    , (named "BlockDecrypt",   dec)
    ]
  where
    named suffix = "aes" ++ show sz ++ suffix

    -- Number of 32-bit words in the key
    nk :: Int
    nk = case sz of
           128 -> 4
           192 -> 6
           256 -> 8
           _   -> error $ "aesLibComponents: Size must be one of 128, 192, or 256; received: " ++ show sz

    -- We get 4*(nr+1) keys, where nr = nk + 6
    nr :: Int
    nr = nk + 6

    xk :: Int
    xk = 4 * (nr + 1)

    -- key-schedule: takes the key, produces the flattened encryption schedule
    keySchedule = do
      key <- cgInputArr nk "key"
      cgOutputArr "encKS" $ ksToXKey (fst (aesKeySchedule key))

    -- inverse key-schedule: the words of the input are reversed within each
    -- group of four before being handed to 'aesInvKeySchedule'
    invKeySchedule = do
      key <- cgInputArr nk "key"
      cgOutputArr "decKS" $ ksToXKey (aesInvKeySchedule (concatMap reverse (chop4 key)))

    -- encryption: plain-text and expanded key in, cipher-text out
    enc = do
      pt   <- cgInputArr 4  "pt"
      xkey <- cgInputArr xk "xkey"
      cgOutputArr "ct" $ aesEncrypt pt (xkeyToKS xkey)

    -- decryption: cipher-text and expanded key in, plain-text out
    dec = do
      ct   <- cgInputArr 4  "ct"
      xkey <- cgInputArr xk "xkey"
      cgOutputArr "pt" $ aesDecryptUnwoundKey ct (xkeyToKS xkey)

    -- Transforming back and forth from our KS type to a flat array used by the
    -- generated C code.

    -- Turn a KS to a series of expanded key words
    ksToXKey :: KS -> [SWord 32]
    ksToXKey (firstKey, middleKeys, lastKey) = firstKey ++ concat middleKeys ++ lastKey

    -- Turn a series of expanded keys to our internal KS type: the first four
    -- words are the first round key, the last four the last round key, and
    -- everything in between is chopped into the middle round keys.
    xkeyToKS :: [SWord 32] -> KS
    xkeyToKS xs = (firstKey, chop4 middle, lastKey)
      where (firstKey, rest)  = splitAt 4 xs
            (middle, lastKey) = splitAt (xk - 8) rest
-- | The plain-text block shared by all the tests.
-- (These vectors appear to be the FIPS-197 Appendix C examples; confirm against the standard.)
commonPT :: Key
commonPT = [0x00112233, 0x44556677, 0x8899aabb, 0xccddeeff]

-- | Expected cipher-text for 'commonPT' under 'aes128Key'.
aes128CT :: Key
aes128CT = [0x69c4e0d8, 0x6a7b0430, 0xd8cdb780, 0x70b4c55a]

-- | Expected cipher-text for 'commonPT' under 'aes192Key'.
aes192CT :: Key
aes192CT = [0xdda97ca4, 0x864cdfe0, 0x6eaf70a0, 0xec0d7191]

-- | Expected cipher-text for 'commonPT' under 'aes256Key'.
aes256CT :: Key
aes256CT = [0x8ea2b7ca, 0x516745bf, 0xeafc4990, 0x4b496089]

-- | The 128-bit test key (four words).
aes128Key :: Key
aes128Key = [0x00010203, 0x04050607, 0x08090a0b, 0x0c0d0e0f]

-- | The 192-bit test key: the 128-bit key extended by two more words.
aes192Key :: Key
aes192Key = aes128Key ++ [0x10111213, 0x14151617]

-- | The 256-bit test key: the 192-bit key extended by two more words.
aes256Key :: Key
aes256Key = aes192Key ++ [0x18191a1b, 0x1c1d1e1f]
-- Inverse keys: the starting points for running the key schedule backwards,
-- each one extracted from the end of the corresponding encryption schedule.

-- | Inverse key for 'aes128Key'.
aes128InvKey :: Key
aes128InvKey = extractFinalKey aes128Key

-- | Inverse key for 'aes192Key', cut down to the size of the key itself.
aes192InvKey :: Key
aes192InvKey = extractFinalKey aes192Key

-- | Inverse key for 'aes192Key', without cutting it down to the key size.
aes192InvKeyExtended :: Key
aes192InvKeyExtended = extractFinalKeyExtended aes192Key

-- | Inverse key for 'aes256Key'.
aes256InvKey :: Key
aes256InvKey = extractFinalKey aes256Key
-- | The inverse key for the given key: the result of 'extractFinalKeyExtended',
-- cut down to as many words as the given key has.
extractFinalKey :: [SWord 32] -> [SWord 32]
extractFinalKey initKey = take (length initKey) $ extractFinalKeyExtended initKey
-- | Run the encryption key schedule on the given key, list its round keys from
-- the last one back to the first, and return the leading words of that list
-- with the words reversed within each group of four. Four words are returned
-- for a four-word key, and eight words otherwise.
extractFinalKeyExtended :: [SWord 32] -> [SWord 32]
extractFinalKeyExtended initKey = concatMap reverse . chop4 . take feed $ lastToFirst
  where
    -- How many words to return. Reversing within groups of four does not change
    -- the length, so no further trimming is needed after the 'take'.
    feed | length initKey == 4 = 4
         | True                = 8

    (firstKey, middleKeys, lastKey) = fst (aesKeySchedule initKey)

    lastToFirst = lastKey ++ concat (reverse middleKeys) ++ firstKey
-- | Run the key expansion backwards. The first argument is the number of words
-- in the key, the second holds the words to start from. The result is the full
-- list of 4*(nk+7) words, chopped into groups of four, with the words of each
-- group reversed.
invKeyExpansion :: Int -> Key -> [Key]
invKeyExpansion nk rkey = map reverse (chop4 keys)
  where
    totalWords = 4 * (nk + 6 + 1)
    remaining  = totalWords - nk

    -- The starting words, followed by the words computed from them. Each new
    -- word is computed from two words that appear earlier in this same list,
    -- while the index counts down from remaining-1 to 0.
    keys :: [SWord 32]
    keys = rkey ++ zipWith3 invNextWord countDown (drop 1 keys) keys
      where countDown = [remaining - 1, remaining - 2 .. 0]

    -- Every new word is the old word xor'ed with a mask; the mask depends on
    -- where the index falls within a group of nk words.
    invNextWord :: Int -> SWord 32 -> SWord 32 -> SWord 32
    invNextWord i prev old = old `xor` mask
      where
        mask
          | i `mod` nk == 0           = subWordRcon (prev `rotateL` 8) (roundConstants !! (1 + i `div` nk))
          | i `mod` nk == 4 && nk > 6 = subWordRcon prev 0
          | True                      = prev

    -- Apply the S-box to every byte of the word, and xor the given constant
    -- into the first byte.
    subWordRcon :: SWord 32 -> GF28 -> SWord 32
    subWordRcon w rc = fromBytes [a `xor` rc, b, c, d]
      where [a, b, c, d] = map sbox $ toBytes w
-- | Compute the decryption key schedule from an inverse key, i.e., by running
-- the key expansion backwards via 'invKeyExpansion'. The key must have 4, 6, or
-- 8 words; any other length is an error.
--
-- The resulting schedule holds the first round key, the @nr-1@ middle round
-- keys, and the last round key, where @nr = nk + 6@.
aesInvKeySchedule :: Key -> KS
aesInvKeySchedule key
  | nk `elem` [4, 6, 8]
  = decKS
  | True
  = error "aesInvKeySchedule: Invalid key size"
  where
    nk = length key
    nr = nk + 6

    -- 'invKeyExpansion' produces nr+1 round keys. Splitting at nr leaves the
    -- first key and the middle keys on one side and the last key on the other.
    -- Pattern matching avoids the partial functions head and tail, which GHC
    -- warns about (-Wx-partial) and which therefore break a -Werror build.
    decKS = case splitAt nr (invKeyExpansion nk key) of
              (firstKey : middleKeys, lastKey : _) -> (firstKey, middleKeys, lastKey)
              _ -> error "aesInvKeySchedule: Impossible happened, not enough round keys"
-- | Block decryption, starting from the unwound key. That is, start from the final key.
-- Also; we don't use the T-box implementation. Just pure AES inverse cipher.
--
-- The cipher-text must have exactly four words; any other length is an error.
-- The rounds themselves are driven by 'doRounds', using the given key schedule.
aesDecryptUnwoundKey :: [SWord 32] -> KS -> [SWord 32]
aesDecryptUnwoundKey ct decKS
  | length ct == 4
  = doRounds aesInvRoundRegular decKS ct
  | True
  = error "aesDecryptUnwoundKey: Invalid cipher-text size"
  where
    -- One round: the flag tells whether this is the final round, followed by
    -- the current state and the round key. Each of the four output words is
    -- computed separately.
    aesInvRoundRegular :: Bool -> State -> Key -> State
    aesInvRoundRegular isFinal s key = map (column isFinal) [0 .. 3]
      where
        stateBytes = map toBytes s
        keyBytes   = map toBytes key

        -- Inverse S-box of byte number col of state word number (j+off) mod 4
        pick :: Int -> Int -> Int -> GF28
        pick j off col = unSBox (stateBytes !! ((j + off) `mod` 4) !! col)

        -- Final round: substitute the bytes, then xor in the round key word
        column True j = fromBytes [pick j 0 0, pick j 3 1, pick j 2 2, pick j 1 3] `xor` (key !! j)
        -- Other rounds: xor the key bytes into the substituted bytes first,
        -- then combine the results of the four look-up tables
        column False j = e0 `xor` e1 `xor` e2 `xor` e3
          where e0 = u0 $ pick j 0 0 `xor` (keyBytes !! j !! 0)
                e1 = u1 $ pick j 3 1 `xor` (keyBytes !! j !! 1)
                e2 = u2 $ pick j 2 2 `xor` (keyBytes !! j !! 2)
                e3 = u3 $ pick j 1 3 `xor` (keyBytes !! j !! 3)
-- | T-box table generating function for decryption: the given element
-- multiplied by each of 0xE, 0x9, 0xD, and 0xB, in that order.
u0Func :: GF28 -> [GF28]
u0Func s = map (s `gf28Mult`) [0xE, 0x9, 0xD, 0xB]
-- | First look-up table used in decryption: entry @a@ holds the bytes of
-- @u0Func a@, packed into a word.
u0 :: GF28 -> SWord 32
u0 = select table 0
  where table = map (fromBytes . u0Func) [0 .. 255]

-- | Second look-up table used in decryption: like 'u0', with the bytes rotated by one.
u1 :: GF28 -> SWord 32
u1 = select table 0
  where table = map (\a -> fromBytes (u0Func a `rotR` 1)) [0 .. 255]

-- | Third look-up table used in decryption: like 'u0', with the bytes rotated by two.
u2 :: GF28 -> SWord 32
u2 = select table 0
  where table = map (\a -> fromBytes (u0Func a `rotR` 2)) [0 .. 255]

-- | Fourth look-up table used in decryption: like 'u0', with the bytes rotated by three.
u3 :: GF28 -> SWord 32
u3 = select table 0
  where table = map (\a -> fromBytes (u0Func a `rotR` 3)) [0 .. 255]
-- | The keys used by the tests. These are the very same values as the exported
-- test vectors 'aes128Key', 'aes192Key', and 'aes256Key'.
aes128K, aes192K, aes256K :: Key
aes128K = aes128Key
aes192K = aes192Key
aes256K = aes256Key
-- | Split a list into consecutive chunks of four elements. The last chunk is
-- shorter if the length is not a multiple of four; the empty list has no chunks.
chop4 :: [a] -> [[a]]
chop4 = unfoldr step
  where step [] = Nothing
        step xs = Just (splitAt 4 xs)
-- | Run all the tests: the inverse key expansion test, encryption and
-- decryption of the known test vectors at each key size, and finally the
-- quick-check round-trip properties. A failing comparison stops the run by
-- calling 'error'.
runTests :: IO ()
runTests = do
    testInvKeyExpansion
    check "AES128" aes128K aes128InvKey aes128CT
    check "AES192" aes192K aes192InvKey aes192CT
    check "AES256" aes256K aes256InvKey aes256CT
    putStrLn "Quick-check AES128 roundtrip" >> quickCheck aes128IsCorrect
    putStrLn "Quick-check AES192 roundtrip" >> quickCheck aes192IsCorrect
    putStrLn "Quick-check AES256 roundtrip" >> quickCheck aes256IsCorrect
  where
    -- Encrypt the common plain-text and compare against the expected
    -- cipher-text; then decrypt the expected cipher-text both with the regular
    -- decryption schedule and with the schedule computed from the inverse key.
    check :: String -> Key -> Key -> [SWord 32] -> IO ()
    check what key invKey ctExpected = do
        eq ("Encryption "     ++ what) ctExpected ctGot
        eq ("Decryption "     ++ what) commonPT   ptGot
        eq ("Decryption-OTF " ++ what) commonPT   ptGotInv
      where
        (encKS, decKS) = aesKeySchedule key
        ctGot          = aesEncrypt commonPT encKS
        ptGot          = aesDecrypt ctExpected decKS
        ptGotInv       = aesDecryptUnwoundKey ctExpected (aesInvKeySchedule invKey)

    -- Compare two lists of literal words, reporting under the given tag
    eq :: String -> [SWord 32] -> [SWord 32] -> IO ()
    eq tag expected got
      | length expected /= length got
      = error $ unlines [ "BAD!: " ++ tag
                        , "Comparing different sized lists:"
                        , "Expected: " ++ show expected
                        , "Got : " ++ show got
                        ]
      | map extract expected == map extract got
      = putStrLn $ "GOOD: " ++ tag
      | True
      = error $ unlines [ "BAD!: " ++ tag
                        , "Expected: " ++ unwords (map hex8 expected)
                        , "Got : " ++ unwords (map hex8 got)
                        ]
      where extract x = case unliteral x of
                          Just v  -> v
                          Nothing -> error $ "Can't extract value from: " ++ show x

    testInvKeyExpansion :: IO ()
    testInvKeyExpansion = do
        goTestInvKey "128" aes128K
        goTestInvKey "192" aes192K
        goTestInvKey "256" aes256K

    -- Check that running the key expansion backwards, starting from the end of
    -- the encryption schedule, yields the encryption round keys from the last
    -- one back to the first.
    goTestInvKey :: String -> Key -> IO ()
    goTestInvKey what k = do
        when (lexpected /= lresult) $
          error $ what ++ ": BAD! Mismatching lengths: " ++ show (lexpected, lresult)
        if expected == result
          then if verbose
                 then putStrLn $ unlines $ ("Size " ++ what ++ ": Good") : report
                 else putStrLn $ "GOOD: Key generation AES" ++ what
          else error $ unlines $ ("Size " ++ what ++ ": BAD!") : report
      where
        nk   = length k
        nr   = nk + 6
        feed = case nk of
                 4 -> 4
                 _ -> 8

        ((f, m, l), _) = aesKeySchedule k
        required       = l ++ concat (reverse m) ++ f
        invKeySchedule = take (nr + 1) $ invKeyExpansion nk (take nk (concatMap reverse (chop4 (take feed required))))
        obtained       = concat invKeySchedule

        unlit x = case unliteral x of
                    Just v  -> v
                    Nothing -> error $ "Can't unliteral: " ++ show x

        expected  = map unlit required
        result    = map unlit obtained
        lexpected = length expected
        lresult   = length result

        -- Flip to True to print the full comparison even when the test passes
        verbose = False

        -- One line per round key. The index type is fixed here with an ordinary
        -- expression signature, so no pattern signature (and hence no
        -- ScopedTypeVariables extension) is needed.
        report = zipWith3 sh [0 :: Int ..] (chop4 expected) (chop4 result)

        sh i a b
          | a == b = pad ++ show i ++ " " ++ disp a
          | True   = pad ++ show i ++ " " ++ disp a ++ " |vs| " ++ disp b
          where pad = if i < 10 then " " else ""

        disp = unwords . map (hex8 . literal)
-- | Round-trip property for AES-128: four words of plain-text, four words of key.
aes128IsCorrect :: (SWord32, SWord32, SWord32, SWord32)
                -> (SWord32, SWord32, SWord32, SWord32)
                -> SBool
aes128IsCorrect (i0, i1, i2, i3) (k0, k1, k2, k3) = roundTrip [i0, i1, i2, i3] [k0, k1, k2, k3]

-- | Round-trip property for AES-192: four words of plain-text, six words of key.
aes192IsCorrect :: (SWord32, SWord32, SWord32, SWord32)
                -> (SWord32, SWord32, SWord32, SWord32, SWord32, SWord32)
                -> SBool
aes192IsCorrect (i0, i1, i2, i3) (k0, k1, k2, k3, k4, k5) = roundTrip [i0, i1, i2, i3] [k0, k1, k2, k3, k4, k5]

-- | Round-trip property for AES-256: four words of plain-text, eight words of key.
aes256IsCorrect :: (SWord32, SWord32, SWord32, SWord32)
                -> (SWord32, SWord32, SWord32, SWord32, SWord32, SWord32, SWord32, SWord32)
                -> SBool
aes256IsCorrect (i0, i1, i2, i3) (k0, k1, k2, k3, k4, k5, k6, k7) = roundTrip [i0, i1, i2, i3] [k0, k1, k2, k3, k4, k5, k6, k7]
-- | Encrypt the plain-text with the key, then decrypt the result in two ways:
-- with the regular decryption schedule, and with the schedule computed from
-- the inverse key. The property holds when both give back the plain-text.
roundTrip :: [SWord32] -> [SWord32] -> SBool
roundTrip ptIn keyIn = (pt .== viaSchedule) .&& (pt .== viaUnwound)
  where
    -- Convert the inputs to the sized representation the AES functions use
    pt  = map toSized ptIn
    key = map toSized keyIn

    (encKS, decKS) = aesKeySchedule key
    ct             = aesEncrypt pt encKS

    viaSchedule = aesDecrypt ct decKS
    viaUnwound  = aesDecryptUnwoundKey ct (aesInvKeySchedule (extractFinalKey key))
All improvements are done, except for reconciling the endianness difference between the C and Haskell sides. I think this is an OK compromise: the C library will need a driver anyhow, and users can adjust it accordingly.