LeventErkok / sbv

SMT Based Verification in Haskell. Express properties about Haskell programs and automatically prove them using SMT solvers.
https://github.com/LeventErkok/sbv
Other
240 stars 33 forks source link

Make folds easier to use? #641

Closed LeventErkok closed 1 year ago

LeventErkok commented 1 year ago

Currently, the following works:

ghci> import Data.SBV
ghci> import Data.SBV.List
ghci> import Prelude hiding(length)
ghci> :set -XOverloadedLists
ghci> sat $ \x -> x .== sfoldl "(lambda ((x Int) (y Int)) (+ x y))" (0 :: SInteger) ([1,2,3 :: Integer])
ghci> sat $ \l -> 8 .== sfoldl "(lambda ((x Int) (y Int)) (+ x y))" 0 (l :: SList Integer) .&& length l .>= 3

But that lambda is ugly! Can we make this better somehow? Maybe a little local splice? Or something simpler and less error-prone to use?

LeventErkok commented 1 year ago

Here's one poor-man's solution. It's nowhere near where I want it to be: no binding and no arbitrary types (though it can be extended to cover most), but here is a start:

module T where

import Prelude hiding(length)
import Data.SBV
import Data.SBV.List hiding ((++))

-- | Poor-man's SMT-Lib term builder: a term is its SMT-Lib text paired with
-- its kind, so ordinary Haskell arithmetic such as @x + y@ renders to the text
-- @(+ x y)@. Literals are always given kind 'KUnbounded', and every operator
-- below emits the SMT-Lib integer-theory symbol; the kind of the left operand
-- is carried along unchanged.
instance Num (String, Kind) where
  -- SMT-Lib has no negative numerals: -3 must be written as (- 3), otherwise
  -- the solver sees an (undeclared) symbol named "-3".
  fromInteger i
    | i < 0     = ("(- " ++ show (negate i) ++ ")", KUnbounded)
    | otherwise = (show i, KUnbounded)
  (a, k) + (b, _) = ("(+ " ++ a ++ " " ++ b ++ ")", k)
  (a, k) - (b, _) = ("(- " ++ a ++ " " ++ b ++ ")", k)
  (a, k) * (b, _) = ("(* " ++ a ++ " " ++ b ++ ")", k)
  -- 'abs' is part of the SMT-Lib Ints theory.
  abs    (a, k)   = ("(abs " ++ a ++ ")", k)
  -- signum spelled out with nested if-then-else: 1, -1 or 0.
  signum (a, k)   = ("(ite (> " ++ a ++ " 0) 1 (ite (< " ++ a ++ " 0) (- 1) 0))", k)

-- | The SMT-Lib sort name used when declaring a lambda binder of the given
-- kind. Only kinds with a direct SMT-Lib sort are supported; anything else is
-- reported with an error naming the offending kind.
smtType :: Kind -> String
smtType KBool          = "Bool"
smtType KUnbounded     = "Int"
smtType KReal          = "Real"
smtType (KBounded _ w) = "(_ BitVec " ++ show w ++ ")"
smtType k              = error $ "smtType: don't know how to translate kind: " ++ show k

-- | Render an SMT-Lib lambda term, one line per piece: the opening
-- @(lambda (@, one @(name sort)@ line per binder, the closing paren of the
-- binder list, and finally the body followed by the closing paren of the
-- lambda. The kind attached to the body is ignored.
lambda :: [(String, Kind)] -> (String, Kind) -> String
lambda binders (body, _) = unlines (opening : map declare binders ++ closing)
  where opening          = "(lambda ("
        declare (v, k)   = "   (" ++ v ++ " " ++ smtType k ++ ")"
        closing          = ["   )", "   " ++ body ++ ")"]

-- | Demo: ask z3 (verbosely) for an integer list of length at least 3 whose
-- left fold with a textual SMT-Lib "x + y" lambda, starting from 0, equals 2.
t :: IO SatResult
t = satWith z3{verbose=True} $ do
       lst :: SList Integer <- sList_

       -- Both binders are unbounded integers; '+' comes from the
       -- Num (String, Kind) instance and renders as "(+ x y)".
       let intVar :: String -> (String, Kind)
           intVar v = (v, KUnbounded)
           plus     = lambda [intVar "x", intVar "y"] (intVar "x" + intVar "y")

       constrain $ length lst .>= 3
       pure $ sfoldl plus (0 :: SInteger) lst .== 2
LeventErkok commented 1 year ago

Another somewhat fun, but not really successful, attempt: add the following to Data.SBV.List.hs.

First export:

      -- * Experiment
     , smap2, Expr(..)

then import:

import Data.SBV.Core.Kind       (smtType)
import Data.SBV.Utils.PrettyNum (cvToSMTLib)

and then:

-- | A lambda-bound variable, represented only by its name. The phantom
-- parameter @a@ is the Haskell type whose kind is used as the binder's sort
-- when the lambda text is rendered.
newtype LVar a = LVar String
-- | An SMT-Lib term kept as raw text. The phantom parameter @a@ is the Haskell
-- type of the value the term denotes; its kind decides how literals and
-- arithmetic are rendered.
newtype Expr a = Expr String

-- | Introduce a variable called @name@: return its binder together with the
-- result of feeding a reference to that same name into @build@.
mkVar :: String -> (Expr a -> b) -> (LVar a, b)
mkVar name build = (binder, build ref)
  where binder = LVar name
        ref    = Expr name

-- | Build SMT-Lib terms with ordinary numeric syntax. The operator emitted
-- depends on the kind of @a@: integer/real arithmetic symbols for unbounded
-- integers and reals, bit-vector symbols for bounded words. Other kinds are
-- not supported yet and raise an error when the term is rendered.
instance (SymVal a, Ord a, Num a) => Num (Expr a) where
  fromInteger i   = Expr $ cvToSMTLib RoundNearestTiesToEven (mkConstCV (kindOf (Proxy @a)) i)
  Expr x + Expr y = Expr $ exprArith (kindOf (Proxy @a)) "+" "bvadd" x y
  Expr x - Expr y = Expr $ exprArith (kindOf (Proxy @a)) "-" "bvsub" x y
  Expr x * Expr y = Expr $ exprArith (kindOf (Proxy @a)) "*" "bvmul" x y
  abs    _        = error "no abs"
  signum _        = error "no signum"

-- | Render the binary application of an arithmetic operator to two SMT-Lib
-- terms of kind @k@. @arithOp@ is used for unbounded integers and reals,
-- @bvOp@ for bounded words (add/sub/mul are sign-agnostic on bit-vectors).
exprArith :: Kind -> String -> String -> String -> String -> String
exprArith k arithOp bvOp x y = case k of
  KUnbounded -> app arithOp
  KReal      -> app arithOp
  KBounded{} -> app bvOp
  _          -> error $ "Expr arithmetic: tbd for kind: " P.++ P.show k
  where app op = "(" P.++ op P.++ " " P.++ x P.++ " " P.++ y P.++ ")"

-- | Render a one-argument SMT-Lib lambda whose binder is declared with the
-- sort of @a@ and whose body is the given expression text.
lambda :: forall a b. HasKind a => LVar a -> Expr b -> String
lambda (LVar var) (Expr body) = unlines [opening, binder, closeBinders, bodyLine, ")"]
  where opening      = "(lambda ("
        binder       = "   (" P.++ var P.++ " " P.++ smtType (kindOf (Proxy @a)) P.++ ")"
        closeBinders = "   )"
        bodyLine     = "   " P.++ body

-- | Map over a symbolic list with a function written as an 'Expr' builder.
-- The builder is applied to the binder @x@, rendered as an SMT-Lib lambda,
-- and emitted as a 'SeqMap' node over the list.
smap2 :: forall a b. (SymVal a, SymVal b) => (Expr a -> Expr b) -> SList a -> SList b
smap2 f l = SBV $ SVal k $ Right $ cache r
  where (a, e) = mkVar "x" f
        op = lambda a e

        -- The result is a list of @b@, not of @a@: taking the kind from the
        -- input list would mis-type every map that changes the element type.
        k = kindOf (Proxy @[b])
        r st = do sva <- sbvToSV st l
                  newExpr st k (SBVApp (SeqOp (SeqMap op)) [sva])

{-
import Data.SBV.List
import Prelude ((+))
:set -XOverloadedLists
:set -XDataKinds
sat $ \l -> l .== smap2 (+1) [1,2,3::Integer]
sat $ \l -> l .== smap2 (+1) [1,2,3::WordN 8]
-- Can't type-check this. Why? Because:
--   SList (SList a) /= SList [[a]]
--   SList (SList a) *IS*  SBV [SBV [a]]
--                   *NOT* SBV [[a]]
-- I don't see how to make this work then.
sat $ \l -> l .== (smap2 (const (Expr "(seq.unit x)") :: Expr Integer -> Expr (SList Integer)) [1,2,3::Integer] :: SList [[Integer]])
-}
LeventErkok commented 1 year ago

This is now handled just fine!