LeventErkok / sbv

SMT Based Verification in Haskell. Express properties about Haskell programs and automatically prove them using SMT solvers.
https://github.com/LeventErkok/sbv
Other
240 stars 33 forks source link

AES improvements #638

Closed LeventErkok closed 1 year ago

LeventErkok commented 1 year ago
  1. The current version implements the alternative decryption method, but it doesn't do on-the-fly (OTF) decryption starting from the unwound key. Add that too.
  2. Change the code generator so that it uses the AES test vectors as the driver values.
  3. The endianness of the input is annoyingly reversed between Haskell and regular C. Can we change the code-generator setup so it does the more "regular" thing, without needing to reverse?
LeventErkok commented 1 year ago

Here's a basic implementation. This one doesn't do the "endianness" business, but handles 1 and 2 above:

{-# LANGUAGE    DataKinds        #-}
{-# LANGUAGE    ParallelListComp #-}

{-# OPTIONS_GHC -Wall -Werror -Wno-incomplete-uni-patterns #-}

module AES(cgAESLibrary

           -- Testing
           , runTests

           -- Test vectors
           , aes128Key, aes192Key, aes256Key, aes128InvKey, aes192InvKey, aes192InvKeyExtended, aes256InvKey
           , commonPT, aes128CT, aes192CT, aes256CT

           -- Helper
           , chop4
           ) where

import Control.Monad

import Data.SBV
import Data.SBV.Tools.CodeGen

import Documentation.SBV.Examples.Crypto.AES hiding (aesLibComponents, cgAESLibrary, aes128IsCorrect, u0, u1, u2, u3, u0Func)

import Test.QuickCheck hiding (classify, verbose)

import Data.List
import Data.Maybe

-- | Which generated library function a driver is being built for; see the name-based
-- classification in 'cgAESLibrary'.
data Func
  = SBVEnc     -- ^ block encryption
  | SBVDec     -- ^ block decryption
  | SBVKey     -- ^ forward key schedule
  | SBVInvKey  -- ^ inverse key schedule

-- | Generate a C library for AES with the given key size in bits (128, 192, or 256), written
-- into the optional directory. The library contains the functions listed by 'aesLibComponents'.
-- Each function's generated driver is instantiated with the known-answer test vectors defined in
-- this module. Any other size is rejected with 'error'.
cgAESLibrary :: Int -> Maybe FilePath -> IO ()
cgAESLibrary sz mbd
  | sz `elem` [128, 192, 256] = void $ compileToCLib mbd nm comps
  | True                      = error $ "cgAESLibrary: Size must be one of 128, 192, or 256, received: " ++ show sz
  where nm = "aes" ++ show sz ++ "Lib"

        -- Every component is configured to overwrite existing files and to use test-vector drivers.
        comps = [(fnm, configure fnm code) | (fnm, code) <- aesLibComponents sz]

        configure fnm code = do cgOverwriteFiles True
                                cgSetDriverValues (drivers (classify fnm))
                                code

        -- Recover which function is being generated from its name. The "InvKeySchedule" test must
        -- precede the "KeySchedule" test, because every "...InvKeySchedule" name also ends in "KeySchedule".
        classify :: String -> Func
        classify s
          | "BlockEncrypt"   `isSuffixOf` s = SBVEnc
          | "BlockDecrypt"   `isSuffixOf` s = SBVDec
          | "InvKeySchedule" `isSuffixOf` s = SBVInvKey
          | "KeySchedule"    `isSuffixOf` s = SBVKey
          | True                            = error $ "Can't classify the function: " ++ show s

        -- Instantiate with known-test vectors
        drivers :: Func -> [Integer]
        drivers f = cvt $ case sz of
                            128 -> aes128 f
                            192 -> aes192 f
                            256 -> aes256 f
                            _   -> error $ "Unexpected size: " ++ show sz

        -- Driver values per function, listed in the same order as the inputs declared in
        -- 'aesLibComponents' (e.g. plain-text first, then the flattened expanded key).
        -- NOTE(review): 'keyDriver' reverses the words within each group of four. That matches the
        -- InvKeySchedule component, which undoes the same reversal on its input. The forward
        -- KeySchedule component uses its input as-is, so its driver expands a word-reordered key
        -- rather than the key 'encDriver' expands. Confirm this is intended.
        aes128 SBVEnc    = encDriver commonPT aes128Key
        aes128 SBVDec    = decDriver aes128CT aes128InvKey
        aes128 SBVKey    = keyDriver          aes128Key
        aes128 SBVInvKey = keyDriver          aes128InvKey

        aes192 SBVEnc    = encDriver commonPT aes192Key
        aes192 SBVDec    = decDriver aes192CT aes192InvKey
        aes192 SBVKey    = keyDriver          aes192Key
        aes192 SBVInvKey = keyDriver          aes192InvKey

        aes256 SBVEnc    = encDriver commonPT aes256Key
        aes256 SBVDec    = decDriver aes256CT aes256InvKey
        aes256 SBVKey    = keyDriver          aes256Key
        aes256 SBVInvKey = keyDriver          aes256InvKey

        -- Flatten a (first, middle, last) key schedule into one word list.
        flatten (f, mid, l) = f ++ concat mid ++ l
        -- Symbolic literals to plain Integers; fromJust assumes every value is concrete.
        cvt = map (fromIntegral . fromJust . unliteral)

        encDriver curPT key = curPT ++ flatten (fst (aesKeySchedule key))
        decDriver curCT key = curCT ++ flatten (aesInvKeySchedule key)
        keyDriver       key = concatMap reverse . chop4 $ key

-- | The C functions making up the AES library for the given key size, as (name, generator)
-- pairs: forward key schedule, inverse key schedule, block encryption, and block decryption.
-- Expanded keys cross the C boundary as flat arrays of @4 * (nr + 1)@ 32-bit words.
aesLibComponents :: Int -> [(String, SBVCodeGen ())]
aesLibComponents sz = [ (prefix ++ "KeySchedule",    keySchedule)
                      , (prefix ++ "InvKeySchedule", invKeySchedule)
                      , (prefix ++ "BlockEncrypt",   enc)
                      , (prefix ++ "BlockDecrypt",   dec)
                      ]
  where prefix = "aes" ++ show sz

        -- Number of 32-bit words in the key.
        nk = case sz of
               128 -> 4
               192 -> 6
               256 -> 8
               _   -> error $ "aesLibComponents: Size must be one of 128, 192, or 256; received: " ++ show sz

        -- Number of rounds, and total words in the expanded key (nr + 1 round keys of 4 words each).
        nr = nk + 6
        xk = 4 * (nr + 1)

        -- Forward key schedule: key in, flattened encryption schedule out.
        keySchedule = do
          key <- cgInputArr nk "key"
          cgOutputArr "encKS" (ksToXKey (fst (aesKeySchedule key)))

        -- Inverse key schedule: the input words are reversed within each group of four before use.
        invKeySchedule = do
          key <- cgInputArr nk "key"
          cgOutputArr "decKS" . ksToXKey . aesInvKeySchedule . concatMap reverse $ chop4 key

        -- Encrypt one 4-word block using a flattened expanded key.
        enc = do
          pt   <- cgInputArr 4  "pt"
          xkey <- cgInputArr xk "xkey"
          cgOutputArr "ct" (aesEncrypt pt (xkeyToKS xkey))

        -- Decrypt one 4-word block using a flattened unwound (inverse) key schedule.
        dec = do
          ct   <- cgInputArr 4  "ct"
          xkey <- cgInputArr xk "xkey"
          cgOutputArr "pt" (aesDecryptUnwoundKey ct (xkeyToKS xkey))

        -- Flatten a KS (first, middle, last round keys) into a single word list.
        ksToXKey :: KS -> [SWord 32]
        ksToXKey (firstRK, midRKs, lastRK) = concat (firstRK : midRKs ++ [lastRK])

        -- Inverse of ksToXKey: split a flat word list back into a KS.
        xkeyToKS :: [SWord 32] -> KS
        xkeyToKS ws = (firstRK, chop4 midWords, lastRK)
          where (firstRK, rest)    = splitAt 4 ws
                (midWords, lastRK) = splitAt (xk - 8) rest

-- | Plain-text shared by the known-answer tests for all three key sizes (FIPS-197, Appendix C).
commonPT :: Key
commonPT = [0x00112233, 0x44556677, 0x8899aabb, 0xccddeeff]

-- | Expected cipher-texts of 'commonPT' under 'aes128Key', 'aes192Key', and 'aes256Key'.
aes128CT, aes192CT, aes256CT :: Key
aes128CT = [0x69c4e0d8, 0x6a7b0430, 0xd8cdb780, 0x70b4c55a]
aes192CT = [0xdda97ca4, 0x864cdfe0, 0x6eaf70a0, 0xec0d7191]
aes256CT = [0x8ea2b7ca, 0x516745bf, 0xeafc4990, 0x4b496089]

-- | Test keys. Each larger key extends the previous one by two more words.
aes128Key, aes192Key, aes256Key :: Key
aes128Key = [0x00010203, 0x04050607, 0x08090a0b, 0x0c0d0e0f]
aes192Key = aes128Key <> [0x10111213, 0x14151617]
aes256Key = aes192Key <> [0x18191a1b, 0x1c1d1e1f]

-- | Unwound keys: the starting material for 'aesInvKeySchedule', taken from the end of the
-- forward key schedule of the matching test key via 'extractFinalKey'.
aes128InvKey :: Key
aes128InvKey = extractFinalKey aes128Key

aes192InvKey :: Key
aes192InvKey = extractFinalKey aes192Key

-- | All 8 words that 'extractFinalKeyExtended' yields for the 192-bit key (not truncated to 6).
aes192InvKeyExtended :: Key
aes192InvKeyExtended = extractFinalKeyExtended aes192Key

aes256InvKey :: Key
aes256InvKey = extractFinalKey aes256Key

-- | The unwound key for a given initial key: the first @length initKey@ words of
-- 'extractFinalKeyExtended'.
extractFinalKey :: [SWord 32] -> [SWord 32]
extractFinalKey initKey = take (length initKey) $ extractFinalKeyExtended initKey

-- | Run the forward key schedule and list its round keys from last to first. Take the leading
-- 4 words (for a 4-word key) or 8 words (otherwise), and reverse the word order within each
-- group of four.
extractFinalKeyExtended :: [SWord 32] -> [SWord 32]
extractFinalKeyExtended initKey = concatMap reverse . chop4 $ take feed unwound
  where feed = if length initKey == 4 then 4 else 8

        ((firstRK, midRKs, lastRK), _) = aesKeySchedule initKey
        unwound                        = lastRK ++ concat (reverse midRKs) ++ firstRK

-- | Key expansion run backwards, seeded with an unwound key of @nk@ words.
--
-- Produces @nk + 7@ round keys of 4 words each (@4 * (nk + 6 + 1)@ words in total). Each group of
-- four words is reversed before being returned. 'runTests' checks that, for the test keys, the
-- concatenated result equals the forward schedule's round keys listed from the last round to
-- the first.
invKeyExpansion :: Int -> Key -> [Key]
invKeyExpansion nk rkey = map reverse (chop4 keys)
   where -- Lazily self-referential: element k of the generated tail is computed from keys !! k
         -- ('old') and keys !! (k + 1) ('prev'), both of which are already defined.
         -- 'i' counts down from remaining - 1 to 0 and selects which transformation applies.
         keys :: [SWord 32]
         keys = rkey ++ [invNextWord i prev old | i <- reverse [0 .. remaining - 1] | prev <- drop 1 keys | old <- keys]

         totalWords = 4 * (nk + 6 + 1)
         remaining  = totalWords - nk

         -- Mirrors the forward expansion's cases: when i is a multiple of nk, apply rotate +
         -- S-box + round constant; for 8-word keys, apply S-box only at i mod nk == 4;
         -- otherwise XOR directly.
         invNextWord :: Int -> SWord 32 -> SWord 32 -> SWord 32
         invNextWord i prev old
           | i `mod` nk == 0           = old `xor` subWordRcon (prev `rotateL` 8) (roundConstants !! (1 + i `div` nk))
           | i `mod` nk == 4 && nk > 6 = old `xor` subWordRcon prev 0
           | True                      = old `xor` prev

         -- Apply the S-box to every byte of the word, then XOR the round constant into the first byte.
         subWordRcon :: SWord 32 -> GF28 -> SWord 32
         subWordRcon w rc = fromBytes [a `xor` rc, b, c, d]
            where [a, b, c, d] = map sbox $ toBytes w

-- | Decryption key schedule built from an unwound key (e.g. one produced by 'extractFinalKey').
-- The key must have 4, 6, or 8 words; otherwise this calls 'error'. The result is the
-- (first, middle, last) split of the @nr + 1@ round keys produced by 'invKeyExpansion'.
aesInvKeySchedule :: Key -> KS
aesInvKeySchedule key
  | nk `notElem` [4, 6, 8] = error "aesInvKeySchedule: Invalid key size"
  | otherwise              = (firstRK, middleRKs, lastRK)
  where nk = length key
        nr = nk + 6
        -- One first round key, nr - 1 middle ones, then the round key at index nr.
        (firstRK : middleRKs, lastRK : _) = splitAt nr (invKeyExpansion nk key)

-- | Block decryption, starting from the unwound key. The schedule is expected to come from
-- 'aesInvKeySchedule', i.e. it starts from the final key.
--
-- This does not use the T-box implementation; each round is the plain AES inverse cipher:
-- InvShiftRows, InvSubBytes and AddRoundKey, followed (except in the final round) by
-- InvMixColumns via the 'u0' .. 'u3' tables.
--
-- The cipher-text must be exactly 4 words; otherwise this calls 'error'.
aesDecryptUnwoundKey :: [SWord 32] -> KS -> [SWord 32]
aesDecryptUnwoundKey ct decKS
  | length ct == 4
  = doRounds aesInvRoundRegular decKS ct
  | True
  = error "aesDecryptUnwoundKey: Invalid cipher-text size"
  where -- One inverse round; 'isFinal' marks the last round, which omits InvMixColumns.
        aesInvRoundRegular isFinal s key = u
          where u :: State
                u = map (f isFinal) [0 .. 3]
                  where a   = map toBytes s
                        kbs = map toBytes key
                        -- Output column j takes byte r from input column (j - r) mod 4 (InvShiftRows)
                        -- and undoes the S-box on it (InvSubBytes).
                        -- Final round: add the round-key word for column j.
                        f True j = fromBytes [ unSBox (a !! ((j+0) `mod` 4) !! 0)
                                             , unSBox (a !! ((j+3) `mod` 4) !! 1)
                                             , unSBox (a !! ((j+2) `mod` 4) !! 2)
                                             , unSBox (a !! ((j+1) `mod` 4) !! 3)
                                             ] `xor` (key !! j)
                        -- Other rounds: add the round key byte-wise first, then apply InvMixColumns by
                        -- XOR-ing the four table look-ups.
                        f False j = e0 `xor` e1 `xor` e2 `xor` e3
                              where e0 = u0 $ unSBox (a !! ((j+0) `mod` 4) !! 0) `xor` (kbs !! j !! 0)
                                    e1 = u1 $ unSBox (a !! ((j+3) `mod` 4) !! 1) `xor` (kbs !! j !! 1)
                                    e2 = u2 $ unSBox (a !! ((j+2) `mod` 4) !! 2) `xor` (kbs !! j !! 2)
                                    e3 = u3 $ unSBox (a !! ((j+1) `mod` 4) !! 3) `xor` (kbs !! j !! 3)

-- | Column generator for the decryption tables: the byte multiplied in GF(2^8) by the
-- InvMixColumns coefficients 0x0E, 0x09, 0x0D, and 0x0B, in that order.
u0Func :: GF28 -> [GF28]
u0Func s = map (gf28Mult s) [0xE, 0x9, 0xD, 0xB]

-- | First decryption look-up table: byte @b@ maps to the word assembled from 'u0Func' @b@.
-- The table has 256 entries, so the default 0 is never selected for a byte index.
u0 :: GF28 -> SWord 32
u0 = select table 0
  where table = map (fromBytes . u0Func) [0 .. 255]

-- | Second decryption look-up table: like 'u0', but with the column from 'u0Func' rotated by
-- one position via 'rotR'.
u1 :: GF28 -> SWord 32
u1 = select [fromBytes (rotR (u0Func b) 1) | b <- [0 .. 255]] 0

-- | Third decryption look-up table: like 'u0', but with the column from 'u0Func' rotated by
-- two positions via 'rotR'.
u2 :: GF28 -> SWord 32
u2 = let table = [fromBytes (rotR (u0Func b) 2) | b <- [0 .. 255]]
     in  select table 0

-- | Fourth decryption look-up table: like 'u0', but with the column from 'u0Func' rotated by
-- three positions via 'rotR'.
u3 :: GF28 -> SWord 32
u3 = select table 0
  where table = map (\b -> fromBytes (rotR (u0Func b) 3)) [0 .. 255]

-- | Keys used by 'runTests'. They have exactly the same values as 'aes128Key', 'aes192Key', and
-- 'aes256Key'. They alias those definitions instead of repeating the hex literals, so the two
-- sets of test keys cannot drift apart.
aes128K, aes192K, aes256K :: Key
aes128K = aes128Key
aes192K = aes192Key
aes256K = aes256Key

-- | Split a list into consecutive chunks of four elements. The last chunk holds whatever
-- remains (1 to 4 elements); the empty list yields no chunks. Works lazily on infinite lists.
chop4 :: [a] -> [[a]]
chop4 xs = case splitAt 4 xs of
             ([],    _   ) -> []
             (chunk, rest) -> chunk : chop4 rest

-- | Self-checks for this module, in order:
--
--   * 'testInvKeyExpansion': for each key size, the inverse key expansion seeded from the unwound
--     key reproduces the forward schedule's round keys in reverse round order.
--   * 'check': encryption, regular decryption, and decryption from the unwound key, compared
--     against the known-answer vectors ('commonPT' and 'aes128CT' / 'aes192CT' / 'aes256CT').
--   * QuickCheck round-trip properties over random plain-texts and keys.
--
-- A mismatch aborts via 'error'; successes print a "GOOD: ..." line.
runTests :: IO ()
runTests = do testInvKeyExpansion

              check "AES128" aes128K aes128InvKey aes128CT
              check "AES192" aes192K aes192InvKey aes192CT
              check "AES256" aes256K aes256InvKey aes256CT

              putStrLn "Quick-check AES128 roundtrip" >> quickCheck aes128IsCorrect
              putStrLn "Quick-check AES192 roundtrip" >> quickCheck aes192IsCorrect
              putStrLn "Quick-check AES256 roundtrip" >> quickCheck aes256IsCorrect

  where -- Known-answer test: 'key' drives encryption and regular decryption; 'invKey' (the unwound
        -- key) drives decryption via 'aesInvKeySchedule'.
        check :: String -> Key -> Key -> [SWord 32] -> IO ()
        check what key invKey ctExpected = do eq ("Encryption     " ++ what) ctExpected ctGot
                                              eq ("Decryption     " ++ what) commonPT   ptGot
                                              eq ("Decryption-OTF " ++ what) commonPT   ptGotInv
           where (encKS, decKS) = aesKeySchedule key
                 ctGot          = aesEncrypt            commonPT   encKS
                 ptGot          = aesDecrypt            ctExpected decKS
                 ptGotInv       = aesDecryptUnwoundKey  ctExpected (aesInvKeySchedule invKey)

                 -- Compare two word lists by their concrete values; abort with a diagnostic on mismatch.
                 eq tag expected got
                   | length expected /= length got
                   = error $ unlines [ "BAD!: " ++ tag
                                     , "Comparing different sized lists:"
                                     , "Expected: " ++ show expected
                                     , "Got     : " ++ show got
                                     ]
                   | map extract expected == map extract got
                   = putStrLn $ "GOOD: " ++ tag
                   | True
                   = error $ unlines [ "BAD!: " ++ tag
                                     , "Expected: " ++ unwords (map hex8 expected)
                                     , "Got     : " ++ unwords (map hex8 got)
                                     ]
                  where extract x = case unliteral x of
                                      Just v  -> v
                                      Nothing -> error $ "Can't extract value from: " ++ show x

        testInvKeyExpansion :: IO ()
        testInvKeyExpansion = do goTestInvKey "128" aes128K
                                 goTestInvKey "192" aes192K
                                 goTestInvKey "256" aes256K
        -- Expected: the forward schedule's round keys from last round to first. Obtained: the
        -- inverse expansion seeded with the unwound key (computed the same way as 'extractFinalKey').
        goTestInvKey what k = do
          let nk = length k
              nr = nk + 6

              feed = case nk of
                       4 -> 4
                       _ -> 8

              ((f, m, l), _) = aesKeySchedule k
              required       = l ++ concat (reverse m) ++ f
              invKeySchedule = take (nr+1) $ invKeyExpansion nk (take nk (concatMap reverse (chop4 (take feed required))))
              obtained       = concat invKeySchedule

              unlit x = case unliteral x of
                          Just v  -> v
                          Nothing -> error $ "Can't unliteral: " ++ show x

              expected = map unlit required
              result   = map unlit obtained

              -- One report line per round key; on mismatch, both values are shown side by side.
              sh (i::Int) a b
               | a == b = pad ++ show i ++ " " ++ disp a
               | True   = pad ++ show i ++ " " ++ disp a ++ " |vs| " ++ disp b
               where pad = if i < 10 then " " else ""

              disp = unwords . map (hex8 . literal)

              lexpected = length expected
              lresult   = length result

          when (lexpected /= lresult) $
             error $ what ++ ": BAD! Mismatching lengths: " ++ show (lexpected, lresult)

          -- Set to True to print the full per-round listing on success as well.
          let verbose = False

          if expected == result
             then if verbose
                     then putStrLn $ unlines $ ("Size " ++ what ++ ": Good") : zipWith3 sh [0..] (chop4 expected) (chop4 result)
                     else putStrLn $ "GOOD: Key generation AES" ++ what
             else error    $ unlines $ ("Size " ++ what ++ ": BAD!") : zipWith3 sh [0..] (chop4 expected) (chop4 result)

        -- QuickCheck entry points: tuples give fixed-size random plain-texts (4 words) and keys.
        aes128IsCorrect (i0, i1, i2, i3) (k0, k1, k2, k3)                 = roundTrip [i0, i1, i2, i3] [k0, k1, k2, k3]
        aes192IsCorrect (i0, i1, i2, i3) (k0, k1, k2, k3, k4, k5)         = roundTrip [i0, i1, i2, i3] [k0, k1, k2, k3, k4, k5]
        aes256IsCorrect (i0, i1, i2, i3) (k0, k1, k2, k3, k4, k5, k6, k7) = roundTrip [i0, i1, i2, i3] [k0, k1, k2, k3, k4, k5, k6, k7]

        -- Encrypt, then require both decryption paths (regular, and from the unwound key) to
        -- recover the original plain-text.
        roundTrip :: [SWord32] -> [SWord32] -> SBool
        roundTrip ptIn keyIn = pt .== pt' .&& pt .== pt''
           where pt  = map toSized ptIn
                 key = map toSized keyIn

                 (encKS, decKS) = aesKeySchedule key
                 ct   = aesEncrypt pt encKS
                 pt'  = aesDecrypt ct decKS
                 pt'' = aesDecryptUnwoundKey ct (aesInvKeySchedule (extractFinalKey key))
LeventErkok commented 1 year ago

All improvements are done, except for the endianness mismatch between C and Haskell. I think this is an OK compromise: the C library will need a driver anyhow, and users can adjust accordingly.