diff --git a/crucible-syntax/README.txt b/crucible-syntax/README.txt index 5be8e7c4f..731587571 100644 --- a/crucible-syntax/README.txt +++ b/crucible-syntax/README.txt @@ -79,8 +79,17 @@ first may be relaxed. Types -t ::= 'Any' | 'Unit' | 'Bool' | 'Nat' | 'Real' | 'ComplexReal' | 'Char' | 'String' - | '(' 'Vector' t ')' | '(' 'BitVector' n ')' | '(' '->' t ... t ')' +si ::= 'Unicode' | 'Char16' | 'Char8' + +fi ::= 'Half' | 'Float' | 'Double' | 'Quad' + | 'X86_80' | 'DoubleDouble' + +t ::= 'Any' | 'Unit' | 'Bool' | 'Nat' | 'Integer' | 'Real' + | 'ComplexReal' | 'Char' | '(' 'String' si ')' + | '(' 'FP' fi ')' | '(' 'BitVector' n ')' + | '(' '->' t ... t ')' | '(' 'Maybe' t ')' + | '(' 'Sequence' t ')' | '(' 'Vector' t ')' | '(' 'Ref' t ')' + | '(' 'Struct' t ... t ')' | '(' 'Variant' t ... t ')' Expressions diff --git a/crucible-syntax/src/Lang/Crucible/Syntax/Atoms.hs b/crucible-syntax/src/Lang/Crucible/Syntax/Atoms.hs index dc589f1c2..376a46887 100644 --- a/crucible-syntax/src/Lang/Crucible/Syntax/Atoms.hs +++ b/crucible-syntax/src/Lang/Crucible/Syntax/Atoms.hs @@ -19,7 +19,6 @@ import Control.Applicative import Data.Char import Data.Functor import Data.Ratio -import Data.Semigroup ( (<>) ) import Data.Text (Text) import qualified Data.Text as T @@ -50,7 +49,7 @@ data Keyword = Defun | DefBlock | DefGlobal | Just_ | Nothing_ | FromJust | Inj | Proj | AnyT | UnitT | BoolT | NatT | IntegerT | RealT | ComplexRealT | CharT | StringT - | BitvectorT | VectorT | FPT | FunT | MaybeT | VariantT | RefT + | BitvectorT | VectorT | SequenceT | FPT | FunT | MaybeT | VariantT | StructT | RefT | Half_ | Float_ | Double_ | Quad_ | X86_80_ | DoubleDouble_ | Unicode_ | Char8_ | Char16_ | The @@ -64,6 +63,10 @@ data Keyword = Defun | DefBlock | DefGlobal | ToAny | FromAny | VectorLit_ | VectorReplicate_ | VectorIsEmpty_ | VectorSize_ | VectorGetEntry_ | VectorSetEntry_ | VectorCons_ + | MkStruct_ | GetField_ | SetField_ + | SequenceNil_ | SequenceCons_ | SequenceAppend_ + | SequenceIsNil_ | SequenceLength_ + | SequenceHead_ | SequenceTail_ | SequenceUncons_ | Deref | Ref | EmptyRef | Jump_ | Return_ | Branch_ | MaybeBranch_ | TailCall_ | Error_ | Output_ | Case | Print_ | PrintLn_ @@ -130,9 +133,11 @@ keywords = , ("String" , StringT) , ("Bitvector" , BitvectorT) , ("Vector", VectorT) + , ("Sequence", SequenceT) , ("->", FunT) , ("Maybe", MaybeT) , ("Variant", VariantT) + , ("Struct", StructT) -- string sorts , ("Unicode", Unicode_) @@ -182,6 +187,11 @@ keywords = , ("inj", Inj) , ("proj", Proj) + -- Structs + , ("struct", MkStruct_) + , ("get-field", GetField_) + , ("set-field", SetField_) + -- Maybe , ("just" , Just_) , ("nothing" , Nothing_) @@ -196,6 +206,16 @@ keywords = , ("vector-set", VectorSetEntry_) , ("vector-cons", VectorCons_) + -- Sequences + , ("seq-nil", SequenceNil_) + , ("seq-cons", SequenceCons_) + , ("seq-append", SequenceAppend_) + , ("seq-nil?", SequenceIsNil_) + , ("seq-length", SequenceLength_) + , ("seq-head", SequenceHead_) + , ("seq-tail", SequenceTail_) + , ("seq-uncons", SequenceUncons_) + -- strings , ("show", Show) , ("string-concat", StringConcat_) diff --git a/crucible-syntax/src/Lang/Crucible/Syntax/Concrete.hs b/crucible-syntax/src/Lang/Crucible/Syntax/Concrete.hs index 48b9194f8..b3d7444cc 100644 --- a/crucible-syntax/src/Lang/Crucible/Syntax/Concrete.hs +++ b/crucible-syntax/src/Lang/Crucible/Syntax/Concrete.hs @@ -244,16 +244,16 @@ repUntilLast sp = describe "zero or more followed by one" $ repUntilLast' sp (cons p emptyList <&> \(x, ()) -> ([], x)) <|> (cons p (repUntilLast' p) <&> \(x, (xs, lst)) -> (x:xs, lst)) -isBaseType :: MonadSyntax Atomic m => m (Some BaseTypeRepr) -isBaseType = +_isBaseType :: MonadSyntax Atomic m => m (Some BaseTypeRepr) +_isBaseType = describe "base type" $ do Some tp <- isType case asBaseType tp of NotBaseType -> empty AsBaseType bt -> return (Some bt) -isFloatingType :: MonadSyntax Atomic m => m (Some FloatInfoRepr) -isFloatingType = +_isFloatingType :: MonadSyntax Atomic m => m (Some FloatInfoRepr) +_isFloatingType = describe "floating-point type" $ do Some tp <- isType case tp of @@ -288,7 +288,7 @@ stringSort = isType :: MonadSyntax Atomic m => m (Some TypeRepr) isType = describe "type" $ call - (atomicType <|> stringT <|> vector <|> ref <|> bv <|> fp <|> fun <|> maybeT <|> var) + (atomicType <|> stringT <|> vector <|> seqt <|> ref <|> bv <|> fp <|> fun <|> maybeT <|> var <|> struct) where atomicType = @@ -303,6 +303,7 @@ isType = , kw CharT $> Some CharRepr ] vector = unary VectorT isType <&> \(Some t) -> Some (VectorRepr t) + seqt = unary SequenceT isType <&> \(Some t) -> Some (SequenceRepr t) ref = unary RefT isType <&> \(Some t) -> Some (ReferenceRepr t) bv :: MonadSyntax Atomic m => m (Some TypeRepr) bv = do BoundedNat len <- unary BitvectorT posNat @@ -323,6 +324,8 @@ isType = var :: MonadSyntax Atomic m => m (Some TypeRepr) var = cons (kw VariantT) (rep isType) <&> \((), toCtx -> Some tys) -> Some (VariantRepr tys) + struct :: MonadSyntax Atomic m => m (Some TypeRepr) + struct = cons (kw StructT) (rep isType) <&> \((), toCtx -> Some tys) -> Some (StructRepr tys) someExprType :: SomeExpr s -> Maybe (Some TypeRepr) someExprType (SomeE tpr _) = Just (Some tpr) @@ -439,6 +442,9 @@ synthExpr typeHint = toAny <|> fromAny <|> stringAppend <|> stringEmpty <|> stringLength <|> showExpr <|> just <|> nothing <|> fromJust_ <|> injection <|> projection <|> vecLit <|> vecCons <|> vecRep <|> vecLen <|> vecEmptyP <|> vecGet <|> vecSet <|> + struct <|> getField <|> setField <|> + seqNil <|> seqCons <|> seqAppend <|> seqNilP <|> seqLen <|> + seqHead <|> seqTail <|> seqUncons <|> ite <|> intLit <|> rationalLit <|> intp <|> binaryToFp <|> fpToBinary <|> realToFp <|> fpToReal <|> ubvToFloat <|> floatToUBV <|> sbvToFloat <|> floatToSBV <|> @@ -449,7 +455,6 @@ synthExpr typeHint = -- Syntactic constructs still to add (see issue #74) -- BvToInteger, SbvToInteger, BvToNat --- MkStruct, GetStruct, SetStruct -- NatToInteger, IntegerToReal -- RealRound, RealFloor, RealCeil -- IntegerToBV, RealToNat @@ -919,6 +924,132 @@ synthExpr typeHint = return $ SomeE (VectorRepr elemT) $ EApp $ VectorSetEntry elemT vec n elt _ -> later $ describe "argument with vector type" empty) + struct :: m (SomeExpr s) + struct = describe "struct literal" $ followedBy (kw MkStruct_) (commit *> + do ls <- case typeHint of + Just (Some (StructRepr ctx)) -> + list (toListFC (\t -> forceSynth =<< synthExpr (Just (Some t))) ctx) + Just (Some t) -> later $ describe ("value of type " <> T.pack (show t) <> " but got struct") empty + Nothing -> rep (forceSynth =<< synthExpr Nothing) + pure $! buildStruct ls) + + getField :: m (SomeExpr s) + getField = + describe "struct field projection" $ + followedBy (kw GetField_) (commit *> + depCons int (\n -> + depCons synth (\(Pair t e) -> + case t of + StructRepr ts -> + case Ctx.intIndex (fromInteger n) (Ctx.size ts) of + Nothing -> + describe (T.pack (show n) <> " is an invalid index into " <> T.pack (show ts)) empty + Just (Some idx) -> + do let ty = ts^.ixF' idx + return $ SomeE ty $ EApp $ GetStruct e idx ty + _ -> describe ("struct type (got " <> T.pack (show t) <> ")") empty))) + + setField :: m (SomeExpr s) + setField = describe "update to a struct type" $ + followedBy (kw SetField_) (commit *> + depConsCond (forceSynth =<< synthExpr typeHint) (\ (Pair tp e) -> + case tp of + StructRepr ts -> Right <$> + depConsCond int (\n -> + case Ctx.intIndex (fromInteger n) (Ctx.size ts) of + Nothing -> pure (Left (T.pack (show n) <> " is an invalid index into " <> T.pack (show ts))) + Just (Some idx) -> Right <$> + do let ty = ts^.ixF' idx + (v,()) <- cons (check ty) emptyList + pure $ SomeE (StructRepr ts) $ EApp $ SetStruct ts e idx v) + _ -> pure $ Left $ ("struct type, but got " <> T.pack (show tp)))) + + seqNil :: m (SomeExpr s) + seqNil = + do Some t <- unary SequenceNil_ isType + return $ SomeE (SequenceRepr t) $ EApp $ SequenceNil t + <|> + kw SequenceNil_ *> + case typeHint of + Just (Some (SequenceRepr t)) -> + return $ SomeE (SequenceRepr t) $ EApp $ SequenceNil t + Just (Some t) -> + later $ describe ("value of type " <> T.pack (show t)) empty + Nothing -> + later $ describe ("unambiguous nil value") empty + + seqCons :: m (SomeExpr s) + seqCons = + do let newhint = case typeHint of + Just (Some (SequenceRepr t)) -> Just (Some t) + _ -> Nothing + (a, d) <- binary SequenceCons_ (later (synthExpr newhint)) (later (synthExpr typeHint)) + let g Nothing = Nothing + g (Just (Some t)) = Just (Some (SequenceRepr t)) + case join (find isJust [ typeHint, g (someExprType a), someExprType d ]) of + Just (Some (SequenceRepr t)) -> + SomeE (SequenceRepr t) . EApp <$> (SequenceCons t <$> evalSomeExpr t a <*> evalSomeExpr (SequenceRepr t) d) + _ -> later $ describe "unambiguous sequence cons (add a type ascription to disambiguate)" empty + + seqAppend :: m (SomeExpr s) + seqAppend = + do (x, y) <- binary SequenceAppend_ (later (synthExpr typeHint)) (later (synthExpr typeHint)) + case join (find isJust [ typeHint, someExprType x, someExprType y ]) of + Just (Some (SequenceRepr t)) -> + SomeE (SequenceRepr t) . EApp <$> + (SequenceAppend t <$> evalSomeExpr (SequenceRepr t) x <*> evalSomeExpr (SequenceRepr t) y) + _ -> later $ describe "unambiguous sequence append (add a type ascription to disambiguate)" empty + + seqNilP :: m (SomeExpr s) + seqNilP = + do Pair t e <- unary SequenceIsNil_ synth + case t of + SequenceRepr t' -> return $ SomeE BoolRepr $ EApp $ SequenceIsNil t' e + other -> later $ describe ("sequence (found " <> T.pack (show other) <> ")") empty + + seqLen :: m (SomeExpr s) + seqLen = + do Pair t e <- unary SequenceLength_ synth + case t of + SequenceRepr t' -> return $ SomeE NatRepr $ EApp $ SequenceLength t' e + other -> later $ describe ("sequence (found " <> T.pack (show other) <> ")") empty + + seqHead :: m (SomeExpr s) + seqHead = + do let newhint = case typeHint of + Just (Some (MaybeRepr t)) -> Just (Some (SequenceRepr t)) + _ -> Nothing + (Pair t e) <- + unary SequenceHead_ (forceSynth =<< synthExpr newhint) + case t of + SequenceRepr elemT -> return $ SomeE (MaybeRepr elemT) $ EApp $ SequenceHead elemT e + other -> later $ describe ("sequence (found " <> T.pack (show other) <> ")") empty + + seqTail :: m (SomeExpr s) + seqTail = + do let newhint = case typeHint of + Just (Some (MaybeRepr t)) -> Just (Some t) + _ -> Nothing + (Pair t e) <- + unary SequenceTail_ (forceSynth =<< synthExpr newhint) + case t of + SequenceRepr elemT -> return $ SomeE (MaybeRepr (SequenceRepr elemT)) $ EApp $ SequenceTail elemT e + other -> later $ describe ("sequence (found " <> T.pack (show other) <> ")") empty + + seqUncons :: m (SomeExpr s) + seqUncons = + do let newhint = case typeHint of + Just (Some (MaybeRepr (StructRepr (Ctx.Empty Ctx.:> t Ctx.:> _)))) -> + Just (Some (SequenceRepr t)) + _ -> Nothing + (Pair t e) <- + unary SequenceUncons_ (forceSynth =<< synthExpr newhint) + case t of + SequenceRepr elemT -> + return $ SomeE (MaybeRepr (StructRepr (Ctx.Empty Ctx.:> elemT Ctx.:> SequenceRepr elemT))) $ + EApp $ SequenceUncons elemT e + other -> later $ describe ("sequence (found " <> T.pack (show other) <> ")") empty + showExpr :: m (SomeExpr s) showExpr = do Pair t1 e <- unary Show synth @@ -933,6 +1064,14 @@ synthExpr typeHint = return $ SomeE (StringRepr UnicodeRepr) $ EApp $ ShowValue bt e _ -> later $ describe ("base or floating point type, but got " <> T.pack (show t1)) empty + +buildStruct :: [Pair TypeRepr (E s)] -> SomeExpr s +buildStruct = loop Ctx.Empty Ctx.Empty + where + loop :: Ctx.Assignment TypeRepr ctx -> Ctx.Assignment (E s) ctx -> [Pair TypeRepr (E s)] -> SomeExpr s + loop tps vs [] = SomeE (StructRepr tps) (EApp (MkStruct tps vs)) + loop tps vs (Pair tp x:xs) = loop (tps Ctx.:> tp) (vs Ctx.:> x) xs + data NatHint = NoHint | forall w. (1 <= w) => NatHint (NatRepr w) diff --git a/crucible-syntax/src/Lang/Crucible/Syntax/ExprParse.hs b/crucible-syntax/src/Lang/Crucible/Syntax/ExprParse.hs index 0d2e73b67..fc4488db9 100644 --- a/crucible-syntax/src/Lang/Crucible/Syntax/ExprParse.hs +++ b/crucible-syntax/src/Lang/Crucible/Syntax/ExprParse.hs @@ -66,7 +66,6 @@ module Lang.Crucible.Syntax.ExprParse import Control.Applicative import Control.Lens hiding (List, cons, backwards) -import Control.Monad (ap) import Control.Monad.Reader import qualified Control.Monad.State.Strict as Strict import qualified Control.Monad.State.Lazy as Lazy @@ -79,7 +78,6 @@ import Data.Foldable as Foldable import Data.List import qualified Data.List.NonEmpty as NE import Data.List.NonEmpty (NonEmpty(..)) -import Data.Semigroup (Semigroup(..)) import Data.String import Data.Text (Text) import qualified Data.Text as T diff --git a/crucible-syntax/src/Lang/Crucible/Syntax/Overrides.hs b/crucible-syntax/src/Lang/Crucible/Syntax/Overrides.hs index 3d9e910f0..659b7c679 100644 --- a/crucible-syntax/src/Lang/Crucible/Syntax/Overrides.hs +++ b/crucible-syntax/src/Lang/Crucible/Syntax/Overrides.hs @@ -28,7 +28,6 @@ import Lang.Crucible.Backend import Lang.Crucible.Types import Lang.Crucible.FunctionHandle import Lang.Crucible.Simulator -import Lang.Crucible.Simulator.SimError (ppSimError) setupOverrides :: diff --git a/crucible-syntax/src/Lang/Crucible/Syntax/Prog.hs b/crucible-syntax/src/Lang/Crucible/Syntax/Prog.hs index 720f431e8..7b4c6375a 100644 --- a/crucible-syntax/src/Lang/Crucible/Syntax/Prog.hs +++ b/crucible-syntax/src/Lang/Crucible/Syntax/Prog.hs @@ -106,13 +106,13 @@ simulateProgram fn theInput outh profh opts setup = case find isMain cs of Just (ACFG Ctx.Empty retType mn) -> do let mainHdl = cfgHandle mn - let fnBindings = fnBindingsFromList + let fns = fnBindingsFromList [ case toSSA g of C.SomeCFG ssa -> FnBinding (cfgHandle g) (UseCFG ssa (postdomInfo ssa)) | ACFG _ _ g <- cs ] - let simCtx = initSimContext sym emptyIntrinsicTypes ha outh fnBindings emptyExtensionImpl () + let simCtx = initSimContext sym emptyIntrinsicTypes ha outh fns emptyExtensionImpl () let simSt = InitialState simCtx emptyGlobals defaultAbortHandler retType $ runOverrideSim retType $ do mapM_ (registerFnBinding . fst) ovrs diff --git a/crucible-syntax/src/Lang/Crucible/Syntax/SExpr.hs b/crucible-syntax/src/Lang/Crucible/Syntax/SExpr.hs index 78886eecc..908e526f0 100644 --- a/crucible-syntax/src/Lang/Crucible/Syntax/SExpr.hs +++ b/crucible-syntax/src/Lang/Crucible/Syntax/SExpr.hs @@ -27,7 +27,6 @@ module Lang.Crucible.Syntax.SExpr ) where import Data.Char (isDigit, isLetter) -import Data.Semigroup (Semigroup(..)) import Data.Text (Text) import qualified Data.Text as T import Data.Void diff --git a/crucible-syntax/test-data/parser-tests/structs.cbl b/crucible-syntax/test-data/parser-tests/structs.cbl new file mode 100644 index 000000000..d3d781806 --- /dev/null +++ b/crucible-syntax/test-data/parser-tests/structs.cbl @@ -0,0 +1,12 @@ +(defun @structs ((x (Struct Bool Integer))) (Struct Unit Nat Bool) + (start st: + (let b (get-field 0 x)) + (let i (get-field 1 x)) + + (let r1 (struct () (the Nat 5) b)) + (let r2 (set-field r1 1 42)) + (let r3 (set-field r2 2 #f)) + + (return r3) + ) +) diff --git a/crucible-syntax/test-data/parser-tests/structs.out.good b/crucible-syntax/test-data/parser-tests/structs.out.good new file mode 100644 index 000000000..38cde2530 --- /dev/null +++ b/crucible-syntax/test-data/parser-tests/structs.out.good @@ -0,0 +1,35 @@ +(defun + @structs + ((x (Struct Bool Integer))) + (Struct Unit Nat Bool) + (start st: + (let b (get-field 0 x)) + (let i (get-field 1 x)) + (let r1 (struct () (the Nat 5) b)) + (let r2 (set-field r1 1 42)) + (let r3 (set-field r2 2 #f)) + (return r3))) + +structs +%0 + % 3:12 + $1 = getStruct($0, 0, BoolRepr) + % 4:12 + $2 = getStruct($0, 1, IntegerRepr) + % 6:13 + $3 = emptyApp() + % 6:13 + $4 = natLit(5) + % 6:13 + $5 = mkStruct([UnitRepr, NatRepr, BoolRepr], [$3, $4, $1]) + % 7:13 + $6 = natLit(42) + % 7:13 + $7 = setStruct([UnitRepr, NatRepr, BoolRepr], $5, 1, $6) + % 8:13 + $8 = boolLit(False) + % 8:13 + $9 = setStruct([UnitRepr, NatRepr, BoolRepr], $7, 2, $8) + % 10:5 + return $9 + % no postdom diff --git a/crucible-syntax/test-data/simulator-tests/seq-test1.cbl b/crucible-syntax/test-data/simulator-tests/seq-test1.cbl new file mode 100644 index 000000000..609e6f799 --- /dev/null +++ b/crucible-syntax/test-data/simulator-tests/seq-test1.cbl @@ -0,0 +1,32 @@ +(defun @main () Unit + (start start: + (let n (the (Sequence Nat) seq-nil)) + (assert! (seq-nil? n) "nil test") + (assert! (equal? 0 (seq-length n)) "nil length test") + + (let s1 (seq-cons 5 n)) + + (assert! (not (seq-nil? s1)) "cons test") + (assert! (equal? 1 (seq-length s1)) "cons length test") + + (let v (from-just (seq-head s1) "head s1")) + (let t (from-just (seq-tail s1) "tail s1")) + (let u (from-just (seq-uncons s1) "uncons s1")) + + (let v2 (get-field 0 u)) + (let t2 (get-field 1 u)) + + (assert! (equal? 5 v) "head value test") + (assert! (equal? v v2) "head equal test") + (assert! (seq-nil? t) "tail nil test") + (assert! (seq-nil? t2) "uncons tail nil test") + + (let s2 (seq-append s1 (seq-cons 42 (seq-nil Nat)))) + (assert! (equal? 2 (seq-length s2)) "append length") + (assert! (not (seq-nil? s2)) "append non-nil") + + (let v3 (from-just (seq-head (from-just (seq-tail s2) "cdr s2")) "cadr s2")) + (assert! (equal? 42 v3) "cadr s2 test") + + (return ())) +) diff --git a/crucible-syntax/test-data/simulator-tests/seq-test1.out.good b/crucible-syntax/test-data/simulator-tests/seq-test1.out.good new file mode 100644 index 000000000..a0dc0d91d --- /dev/null +++ b/crucible-syntax/test-data/simulator-tests/seq-test1.out.good @@ -0,0 +1,4 @@ +==== Begin Simulation ==== + +==== Finish Simulation ==== +==== No proof obligations ==== diff --git a/crucible-syntax/test-data/simulator-tests/seq-test2.cbl b/crucible-syntax/test-data/simulator-tests/seq-test2.cbl new file mode 100644 index 000000000..f772bf2d6 --- /dev/null +++ b/crucible-syntax/test-data/simulator-tests/seq-test2.cbl @@ -0,0 +1,47 @@ +(defun @main () Unit + (registers + ($s (Sequence Integer)) + ) + + (start start: + (set-register! $s seq-nil) + (let b (fresh Bool)) + (let x (fresh Integer)) + (let y (fresh Integer)) + (let z (fresh Integer)) + (branch b l1: l2:)) + + (defblock l1: + (set-register! $s (seq-cons x (seq-cons y $s))) + (jump l3:) + ) + + (defblock l2: + (set-register! $s (seq-cons z $s)) + (jump l3:) + ) + + (defblock l3: + (assert! (<= 1 (seq-length $s)) "length test") + (let u (from-just (seq-uncons $s) "uncons")) + + (let v (if b x z)) + (assert! (equal? v (get-field 0 u)) "head check") + (assert! (equal? (seq-nil? (get-field 1 u)) (not b)) "tail check") + + (let mu2 (seq-uncons (get-field 1 u))) + (maybe-branch mu2 j: n:)) + + (defblock (j: u2 (Struct Integer (Sequence Integer))) + (let v2 (get-field 0 u2)) + (let t2 (get-field 1 u2)) + (assert! b "tail 2 condition check") + (assert! (equal? y v2) "tail 2 value test") + (assert! (seq-nil? t2) "tail 2 nil test") + (return ()) + ) + + (defblock n: + (assert! (not b) "tail 2 none check") + (return ())) +) diff --git a/crucible-syntax/test-data/simulator-tests/seq-test2.out.good b/crucible-syntax/test-data/simulator-tests/seq-test2.out.good new file mode 100644 index 000000000..ca82ee72e --- /dev/null +++ b/crucible-syntax/test-data/simulator-tests/seq-test2.out.good @@ -0,0 +1,26 @@ +==== Begin Simulation ==== + +==== Finish Simulation ==== +==== Proof obligations ==== + +Prove: + test-data/simulator-tests/seq-test2.cbl:29:5: error: in main + head check + eq (ite cb@0:b cx@1:i cz@3:i) (ite cb@0:b cx@1:i cz@3:i) +PROVED +Assuming: +* The branch in main from test-data/simulator-tests/seq-test2.cbl:33:5 to test-data/simulator-tests/seq-test2.cbl:36:13 + cb@0:b +Prove: + test-data/simulator-tests/seq-test2.cbl:38:5: error: in main + tail 2 condition check + cb@0:b +PROVED +Assuming: +* The branch in main from test-data/simulator-tests/seq-test2.cbl:33:5 to test-data/simulator-tests/seq-test2.cbl:45:14 + not cb@0:b +Prove: + test-data/simulator-tests/seq-test2.cbl:45:5: error: in main + tail 2 none check + not cb@0:b +PROVED diff --git a/crucible-syntax/test-data/simulator-tests/seq-test3.cbl b/crucible-syntax/test-data/simulator-tests/seq-test3.cbl new file mode 100644 index 000000000..271221e7c --- /dev/null +++ b/crucible-syntax/test-data/simulator-tests/seq-test3.cbl @@ -0,0 +1,106 @@ +(defun @main () Unit + (registers + ($s1 (Sequence Integer)) + ($s2 (Sequence Integer)) + ) + + (start start: + (set-register! $s1 (seq-cons 42 seq-nil)) + (set-register! $s2 seq-nil) + + (let b1 (fresh Bool)) + (let b2 (fresh Bool)) + (let x (fresh Integer)) + (let y (fresh Integer)) + (let z (fresh Integer)) + (let w (fresh Integer)) + + (branch b1 l1: l2:)) + + (defblock l1: + (set-register! $s1 (seq-cons x $s1)) + (jump l3:) + ) + + (defblock l2: + (set-register! $s1 (seq-cons y $s1)) + (jump l3:) + ) + + (defblock l3: + (branch b2 l4: l5:) + ) + + + (defblock l4: + (set-register! $s2 (seq-cons z $s2)) + (jump l6:) + ) + + (defblock l5: + (set-register! $s2 (seq-cons w $s2)) + (jump l6:) + ) + + (defblock l6: + (let s (seq-append $s1 $s2)) + + (let _0 (funcall @eqseq s + (seq-cons (if b1 x y) (seq-cons 42 (seq-cons (if b2 z w) (seq-nil Integer)))))) + + (let u1 (from-just (seq-uncons s) "uncons 1")) + (let v1 (get-field 0 u1)) + (let t1 (get-field 1 u1)) + + (let v1alt (from-just (seq-head s) "head 1")) + (assert! (equal? v1 v1alt) "head 1 eq check") + (assert! (equal? v1 (if b1 x y)) "head 1 check") + + (let t1alt (from-just (seq-tail s) "tail 1")) + (let _1 (funcall @eqseq t1 t1alt)) + (let _2 (funcall @eqseq t1 (seq-cons 42 (seq-cons (if b2 z w) (seq-nil Integer))))) + + (let u2 (from-just (seq-uncons t1) "uncons 2")) + (let v2 (get-field 0 u2)) + (let t2 (get-field 1 u2)) + + (let v2alt (from-just (seq-head t1) "head 2")) + (assert! (equal? v2 v2alt) "head 2 eq check") + (assert! (equal? v2 42) "head 2 check") + + (let t2alt (from-just (seq-tail t1) "tail 2")) + (let _3 (funcall @eqseq t2 t2alt)) + (let _4 (funcall @eqseq t2 (seq-cons (if b2 z w) (seq-nil Integer)))) + + (return ()) + ) +) + +(defun @eqseq ( (xs (Sequence Integer)) (ys (Sequence Integer)) ) Unit + (registers + ($xs (Sequence Integer)) + ($ys (Sequence Integer))) + + (start st: + (set-register! $xs xs) + (set-register! $ys ys) + (jump loop:)) + + (defblock loop: + (maybe-branch (seq-uncons $xs) xsj: xsn:) + ) + + (defblock (xsj: uxs (Struct Integer (Sequence Integer))) + (let uys (from-just (seq-uncons $ys) "sequence length mismatch!")) + (assert! (equal? (get-field 0 uxs) (get-field 0 uys)) "value mismatch!") + + (set-register! $xs (get-field 1 uxs)) + (set-register! $ys (get-field 1 uys)) + (jump loop:) + ) + + (defblock xsn: + (assert! (seq-nil? $ys) "sequence length mismatch!") + (return ()) + ) +) diff --git a/crucible-syntax/test-data/simulator-tests/seq-test3.out.good b/crucible-syntax/test-data/simulator-tests/seq-test3.out.good new file mode 100644 index 000000000..9651b034b --- /dev/null +++ b/crucible-syntax/test-data/simulator-tests/seq-test3.out.good @@ -0,0 +1,52 @@ +==== Begin Simulation ==== + +==== Finish Simulation ==== +==== Proof obligations ==== +Assuming: +Prove: + test-data/simulator-tests/seq-test3.cbl:95:5: error: in eqseq + value mismatch! + eq (ite cb1@0:b cx@2:i cy@3:i) (ite cb1@0:b cx@2:i cy@3:i) +PROVED +Assuming: +Prove: + test-data/simulator-tests/seq-test3.cbl:95:5: error: in eqseq + value mismatch! + eq (ite cb2@1:b cz@4:i cw@5:i) (ite cb2@1:b cz@4:i cw@5:i) +PROVED +Assuming: +Prove: + test-data/simulator-tests/seq-test3.cbl:56:5: error: in main + head 1 eq check + eq (ite cb1@0:b cx@2:i cy@3:i) (ite cb1@0:b cx@2:i cy@3:i) +PROVED +Assuming: +Prove: + test-data/simulator-tests/seq-test3.cbl:57:5: error: in main + head 1 check + eq (ite cb1@0:b cx@2:i cy@3:i) (ite cb1@0:b cx@2:i cy@3:i) +PROVED +Assuming: +Prove: + test-data/simulator-tests/seq-test3.cbl:95:5: error: in eqseq + value mismatch! + eq (ite cb2@1:b cz@4:i cw@5:i) (ite cb2@1:b cz@4:i cw@5:i) +PROVED +Assuming: +Prove: + test-data/simulator-tests/seq-test3.cbl:95:5: error: in eqseq + value mismatch! + eq (ite cb2@1:b cz@4:i cw@5:i) (ite cb2@1:b cz@4:i cw@5:i) +PROVED +Assuming: +Prove: + test-data/simulator-tests/seq-test3.cbl:95:5: error: in eqseq + value mismatch! + eq (ite cb2@1:b cz@4:i cw@5:i) (ite cb2@1:b cz@4:i cw@5:i) +PROVED +Assuming: +Prove: + test-data/simulator-tests/seq-test3.cbl:95:5: error: in eqseq + value mismatch! + eq (ite cb2@1:b cz@4:i cw@5:i) (ite cb2@1:b cz@4:i cw@5:i) +PROVED diff --git a/crucible/crucible.cabal b/crucible/crucible.cabal index 58c8b294f..7392dcef8 100644 --- a/crucible/crucible.cabal +++ b/crucible/crucible.cabal @@ -110,6 +110,7 @@ library Lang.Crucible.Simulator.RegMap Lang.Crucible.Simulator.RegValue Lang.Crucible.Simulator.SimError + Lang.Crucible.Simulator.SymSequence Lang.Crucible.Syntax Lang.Crucible.Types Lang.Crucible.Vector diff --git a/crucible/src/Lang/Crucible/CFG/Expr.hs b/crucible/src/Lang/Crucible/CFG/Expr.hs index 795f30226..cf90a0248 100644 --- a/crucible/src/Lang/Crucible/CFG/Expr.hs +++ b/crucible/src/Lang/Crucible/CFG/Expr.hs @@ -487,6 +487,50 @@ data App (ext :: Type) (f :: CrucibleType -> Type) (tp :: CrucibleType) where -> !(f (RecursiveType nm ctx)) -> App ext f (UnrollType nm ctx) + ---------------------------------------------------------------------- + -- Sequences + + -- Create an empty sequence + SequenceNil :: !(TypeRepr tp) -> App ext f (SequenceType tp) + + -- Add a new value to the front of a sequence + SequenceCons :: !(TypeRepr tp) + -> !(f tp) + -> !(f (SequenceType tp)) + -> App ext f (SequenceType tp) + + -- Append two sequences + SequenceAppend :: !(TypeRepr tp) + -> !(f (SequenceType tp)) + -> !(f (SequenceType tp)) + -> App ext f (SequenceType tp) + + -- Test if a sequence is nil + SequenceIsNil :: !(TypeRepr tp) + -> !(f (SequenceType tp)) + -> App ext f BoolType + + -- Return the length of a sequence + SequenceLength :: !(TypeRepr tp) + -> !(f (SequenceType tp)) + -> App ext f NatType + + -- Return the head of a sesquence, if it is non-nil. + SequenceHead :: !(TypeRepr tp) + -> !(f (SequenceType tp)) + -> App ext f (MaybeType tp) + + -- Return the tail of a sequence, if it is non-nil. + SequenceTail :: !(TypeRepr tp) + -> !(f (SequenceType tp)) + -> App ext f (MaybeType (SequenceType tp)) + + -- Deconstruct a sequence. Return nothing if nil, + -- return the head and tail if non-nil. + SequenceUncons :: !(TypeRepr tp) + -> !(f (SequenceType tp)) + -> App ext f (MaybeType (StructType (EmptyCtx ::> tp ::> SequenceType tp))) + ---------------------------------------------------------------------- -- Vector @@ -1173,6 +1217,18 @@ instance TypeApp (ExprExtension ext) => TypeApp (App ext) where VectorSetEntry tp _ _ _ -> VectorRepr tp VectorCons tp _ _ -> VectorRepr tp + ---------------------------------------------------------------------- + -- Sequence + SequenceNil tpr -> SequenceRepr tpr + SequenceCons tpr _ _ -> SequenceRepr tpr + SequenceAppend tpr _ _ -> SequenceRepr tpr + SequenceIsNil _ _ -> knownRepr + SequenceHead tpr _ -> MaybeRepr tpr + SequenceUncons tpr _ -> + MaybeRepr (StructRepr (Ctx.Empty Ctx.:> tpr Ctx.:> SequenceRepr tpr)) + SequenceLength{} -> knownRepr + SequenceTail tpr _ -> MaybeRepr (SequenceRepr tpr) + ---------------------------------------------------------------------- -- SymbolicArrayType diff --git a/crucible/src/Lang/Crucible/Simulator/Evaluation.hs b/crucible/src/Lang/Crucible/Simulator/Evaluation.hs index e818c6876..f1cbbcdab 100644 --- a/crucible/src/Lang/Crucible/Simulator/Evaluation.hs +++ b/crucible/src/Lang/Crucible/Simulator/Evaluation.hs @@ -66,6 +66,7 @@ import Lang.Crucible.CFG.Expr import Lang.Crucible.Simulator.Intrinsics import Lang.Crucible.Simulator.RegMap import Lang.Crucible.Simulator.SimError +import Lang.Crucible.Simulator.SymSequence import Lang.Crucible.Types import Lang.Crucible.Utils.MuxTree @@ -470,6 +471,27 @@ evalApp sym itefns _logFn evalExt (evalSub :: forall tp. f tp -> IO (RegValue sy v <- evalSub v_expr return $ V.cons e v + -------------------------------------------------------------------- + -- Sequence + + SequenceNil _tpr -> nilSymSequence sym + SequenceCons _tpr x xs -> + join $ consSymSequence sym <$> evalSub x <*> evalSub xs + SequenceAppend _tpr xs ys -> + join $ appendSymSequence sym <$> evalSub xs <*> evalSub ys + SequenceIsNil _tpr xs -> + isNilSymSequence sym =<< evalSub xs + SequenceLength _tpr xs -> + lengthSymSequence sym =<< evalSub xs + SequenceHead tpr xs -> + headSymSequence sym (muxRegForType sym itefns tpr) =<< evalSub xs + SequenceTail _tpr xs -> + tailSymSequence sym =<< evalSub xs + SequenceUncons tpr xs -> + do xs' <- evalSub xs + mu <- unconsSymSequence sym (muxRegForType sym itefns tpr) xs' + traverse (\ (h,tl) -> pure (Ctx.Empty Ctx.:> RV h Ctx.:> RV tl)) mu + -------------------------------------------------------------------- -- Symbolic Arrays diff --git a/crucible/src/Lang/Crucible/Simulator/RegMap.hs b/crucible/src/Lang/Crucible/Simulator/RegMap.hs index feb844369..41dd17815 100644 --- a/crucible/src/Lang/Crucible/Simulator/RegMap.hs +++ b/crucible/src/Lang/Crucible/Simulator/RegMap.hs @@ -233,6 +233,7 @@ muxRegForType s itefns p = MaybeRepr r -> mergePartExpr s (muxRegForType s itefns r) VectorRepr r -> muxVector s (muxRegForType s itefns r) + SequenceRepr _r -> muxSymSequence s StringMapRepr r -> muxStringMap s (muxRegForType s itefns r) SymbolicArrayRepr{} -> arrayIte s SymbolicStructRepr{} -> structIte s diff --git a/crucible/src/Lang/Crucible/Simulator/RegValue.hs b/crucible/src/Lang/Crucible/Simulator/RegValue.hs index 1bcba099e..aae256ea2 100644 --- a/crucible/src/Lang/Crucible/Simulator/RegValue.hs +++ b/crucible/src/Lang/Crucible/Simulator/RegValue.hs @@ -25,8 +25,6 @@ module Lang.Crucible.Simulator.RegValue ( RegValue , CanMux(..) , RegValue'(..) - , VariantBranch(..) - , injectVariant , MuxFn -- * Register values @@ -34,6 +32,10 @@ module Lang.Crucible.Simulator.RegValue , FnVal(..) , fnValType , RolledType(..) + , SymSequence(..) + + , VariantBranch(..) + , injectVariant -- * Value mux functions , ValMuxFn @@ -44,6 +46,7 @@ module Lang.Crucible.Simulator.RegValue , muxStruct , muxVariant , muxVector + , muxSymSequence , muxHandle ) where @@ -70,6 +73,7 @@ import What4.WordMap import Lang.Crucible.FunctionHandle import Lang.Crucible.Simulator.Intrinsics import Lang.Crucible.Simulator.SimError +import Lang.Crucible.Simulator.SymSequence import Lang.Crucible.Types import Lang.Crucible.Utils.MuxTree import Lang.Crucible.Backend @@ -87,6 +91,7 @@ type family RegValue (sym :: Type) (tp :: CrucibleType) :: Type where RegValue sym (FunctionHandleType a r) = FnVal sym a r RegValue sym (MaybeType tp) = PartExpr (Pred sym) (RegValue sym tp) RegValue sym (VectorType tp) = V.Vector (RegValue sym tp) + RegValue sym (SequenceType tp) = SymSequence sym (RegValue sym tp) RegValue sym (StructType ctx) = Ctx.Assignment (RegValue' sym) ctx RegValue sym (VariantType ctx) = Ctx.Assignment (VariantBranch sym) ctx RegValue sym (ReferenceType tp) = MuxTree sym (RefCell tp) diff --git a/crucible/src/Lang/Crucible/Simulator/SymSequence.hs b/crucible/src/Lang/Crucible/Simulator/SymSequence.hs new file mode 100644 index 000000000..6c5a8e52e --- /dev/null +++ b/crucible/src/Lang/Crucible/Simulator/SymSequence.hs @@ -0,0 +1,512 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeOperators #-} + +-- Needed for Pretty instance +{-# LANGUAGE UndecidableInstances #-} +module Lang.Crucible.Simulator.SymSequence +( SymSequence(..) +, nilSymSequence +, consSymSequence +, appendSymSequence +, muxSymSequence +, isNilSymSequence +, lengthSymSequence +, headSymSequence +, tailSymSequence +, unconsSymSequence +, traverseSymSequence +, concreteizeSymSequence +, prettySymSequence +) where + +import Control.Monad.State +import Data.Functor.Const +import Data.Kind (Type) +import Data.IORef +import Data.Maybe (isJust) +import Data.Map (Map) +import qualified Data.Map as Map +import Data.Parameterized.Nonce +import qualified Data.Parameterized.Map as MapF +import Prettyprinter (Doc) +import qualified Prettyprinter as PP + +import Lang.Crucible.Types +import What4.Interface +import What4.Partial + +------------------------------------------------------------------------ +-- SymSequence + +-- | A symbolic sequence of values supporting efficent merge operations. +-- Semantically, these are essentially cons-lists, and designed to +-- support access from the front only. Nodes carry nonce values +-- that allow DAG-based traversal, which efficently supports the common +-- case where merged nodes share a common sublist. +data SymSequence sym a where + SymSequenceNil :: SymSequence sym a + + SymSequenceCons :: + !(Nonce GlobalNonceGenerator a) -> + a -> + !(SymSequence sym a) -> + SymSequence sym a + + SymSequenceAppend :: + !(Nonce GlobalNonceGenerator a) -> + !(SymSequence sym a) -> + !(SymSequence sym a) -> + SymSequence sym a + + SymSequenceMerge :: + !(Nonce GlobalNonceGenerator a) -> + !(Pred sym) -> + !(SymSequence sym a) -> + !(SymSequence sym a) -> + SymSequence sym a + +instance Eq (SymSequence sym a) where + SymSequenceNil == SymSequenceNil = True + (SymSequenceCons n1 _ _) == (SymSequenceCons n2 _ _) = + isJust (testEquality n1 n2) + (SymSequenceMerge n1 _ _ _) == (SymSequenceMerge n2 _ _ _) = + isJust (testEquality n1 n2) + (SymSequenceAppend n1 _ _) == (SymSequenceAppend n2 _ _) = + isJust (testEquality n1 n2) + _ == _ = False + +-- | Compute an if/then/else on symbolic sequences. +-- This will simply produce an internal merge node +-- except in the special case where the then and +-- else branches are sytactically identical. +muxSymSequence :: + sym -> + Pred sym -> + SymSequence sym a -> + SymSequence sym a -> + IO (SymSequence sym a) +muxSymSequence _sym p x y + | x == y = pure x + | otherwise = + do n <- freshNonce globalNonceGenerator + pure (SymSequenceMerge n p x y) + +newtype SeqCache (f :: Type -> Type) + = SeqCache (IORef (MapF.MapF (Nonce GlobalNonceGenerator) f)) + +newSeqCache :: IO (SeqCache f) +newSeqCache = SeqCache <$> newIORef MapF.empty + +-- | Compute the nonce of a sequence, if it has one +symSequenceNonce :: SymSequence sym a -> Maybe (Nonce GlobalNonceGenerator a) +symSequenceNonce SymSequenceNil = Nothing +symSequenceNonce (SymSequenceCons n _ _ ) = Just n +symSequenceNonce (SymSequenceAppend n _ _) = Just n +symSequenceNonce (SymSequenceMerge n _ _ _) = Just n + +{-# SPECIALIZE + evalWithFreshCache :: + ((SymSequence sym a -> IO (f a)) -> SymSequence sym a -> IO (f a)) -> + (SymSequence sym a -> IO (f a)) + #-} + +evalWithFreshCache :: MonadIO m => + ((SymSequence sym a -> m (f a)) -> SymSequence sym a -> m (f a)) -> + (SymSequence sym a -> m (f a)) +evalWithFreshCache fn s = + do c <- liftIO newSeqCache + evalWithCache c fn s + +{-# SPECIALIZE + evalWithCache :: + SeqCache f -> + ((SymSequence sym a -> IO (f a)) -> SymSequence sym a -> IO (f a)) -> + (SymSequence sym a -> IO (f a)) + #-} + +evalWithCache :: MonadIO m => + SeqCache f -> + ((SymSequence sym a -> m (f a)) -> SymSequence sym a -> m (f a)) -> + (SymSequence sym a -> m (f a)) +evalWithCache (SeqCache ref) fn = loop + where + loop s + | Just n <- symSequenceNonce s = + (MapF.lookup n <$> liftIO (readIORef ref)) >>= \case + Just v -> pure v + Nothing -> + do v <- fn loop s + liftIO (modifyIORef ref (MapF.insert n v)) + pure v + + | otherwise = fn loop s + +-- | Generate an empty sequence value +nilSymSequence :: sym -> IO (SymSequence sym a) +nilSymSequence _sym = pure SymSequenceNil + +-- | Cons a new value onto the front of a sequence +consSymSequence :: + sym -> + a -> + SymSequence sym a -> + IO (SymSequence sym a) +consSymSequence _sym x xs = + do n <- freshNonce globalNonceGenerator + pure (SymSequenceCons n x xs) + +-- | Append two sequences +appendSymSequence :: + sym -> + SymSequence sym a {- ^ front sequence -} -> + SymSequence sym a {- ^ back sequence -} -> + IO (SymSequence sym a) + +-- special cases, nil is the unit for append +appendSymSequence _ xs SymSequenceNil = pure xs +appendSymSequence _ SymSequenceNil ys = pure ys +-- special case, append of a singleton is cons +appendSymSequence sym (SymSequenceCons _ v SymSequenceNil) xs = + consSymSequence sym v xs +appendSymSequence _sym xs ys = + do n <- freshNonce globalNonceGenerator + pure (SymSequenceAppend n xs ys) + + +-- | Test if a sequence is nil (is empty) +isNilSymSequence :: forall sym a. + IsExprBuilder sym => + sym -> + SymSequence sym a -> + IO (Pred sym) +isNilSymSequence sym = \s -> getConst <$> evalWithFreshCache f s + where + f :: (SymSequence sym tp -> IO (Const (Pred sym) tp)) -> (SymSequence sym tp -> IO (Const (Pred sym) tp)) + f _loop SymSequenceNil{} = pure (Const (truePred sym)) + f _loop SymSequenceCons{} = pure (Const (falsePred sym)) + f loop (SymSequenceAppend _ xs ys) = + do px <- getConst <$> loop xs + Const <$> itePredM sym px (getConst <$> loop ys) (pure (falsePred sym)) + f loop (SymSequenceMerge _ p xs ys) = + Const <$> itePredM sym p (getConst <$> loop xs) (getConst <$> loop ys) + + +-- | Compute the length of a sequence +lengthSymSequence :: forall sym a. + IsExprBuilder sym => + sym -> + SymSequence sym a -> + IO (SymNat sym) +lengthSymSequence sym = \s -> getConst <$> evalWithFreshCache f s + where + f :: (SymSequence sym a -> IO (Const (SymNat sym) a)) -> (SymSequence sym a -> IO (Const (SymNat sym) a)) + f _loop SymSequenceNil = Const <$> natLit sym 0 + f loop (SymSequenceCons _ _ tl) = + do x <- getConst <$> loop tl + one <- natLit sym 1 + Const <$> natAdd sym one x + f loop (SymSequenceMerge _ p xs ys) = + do x <- getConst <$> loop xs + y <- getConst <$> loop ys + Const <$> natIte sym p x y + f loop (SymSequenceAppend _ xs ys) = + do x <- getConst <$> loop xs + y <- getConst <$> loop ys + Const <$> natAdd sym x y + + +newtype SeqHead sym a = SeqHead { getSeqHead :: PartExpr (Pred sym) a } + +-- | Compute the head of a sequence, if it has one +headSymSequence :: forall sym a. + IsExprBuilder sym => + sym -> + (Pred sym -> a -> a -> IO a) {- ^ mux function on values -} -> + SymSequence sym a -> + IO (PartExpr (Pred sym) a) +headSymSequence sym mux = \s -> getSeqHead <$> evalWithFreshCache f s + where + f' :: Pred sym -> a -> a -> PartialT sym IO a + f' c x y = PartialT (\_ p -> PE p <$> mux c x y) + + f :: (SymSequence sym a -> IO (SeqHead sym a)) -> (SymSequence sym a -> IO (SeqHead sym a)) + f _loop SymSequenceNil = pure (SeqHead Unassigned) + f _loop (SymSequenceCons _ v _) = pure (SeqHead (justPartExpr sym v)) + f loop (SymSequenceMerge _ p xs ys) = + do mhx <- getSeqHead <$> loop xs + mhy <- getSeqHead <$> loop ys + SeqHead <$> mergePartial sym f' p mhx mhy + + f loop (SymSequenceAppend _ xs ys) = + loop xs >>= \case + SeqHead Unassigned -> loop ys + SeqHead (PE px hx) + | Just True <- asConstantPred px -> pure (SeqHead (PE px hx)) + | otherwise -> + loop ys >>= \case + SeqHead Unassigned -> pure (SeqHead (PE px hx)) + SeqHead (PE py hy) -> + do p <- orPred sym px py + SeqHead <$> runPartialT sym p (f' px hx hy) + +newtype SeqUncons sym a = + SeqUncons + { getSeqUncons :: PartExpr (Pred sym) (a, SymSequence sym a) + } + +-- | Compute both the head and the tail of a sequence, if it is nonempty +unconsSymSequence :: forall sym a. + IsExprBuilder sym => + sym -> + (Pred sym -> a -> a -> IO a) {- ^ mux function on values -} -> + SymSequence sym a -> + IO (PartExpr (Pred sym) (a, SymSequence sym a)) +unconsSymSequence sym mux = \s -> getSeqUncons <$> evalWithFreshCache f s + where + f' :: Pred sym -> + (a, SymSequence sym a) -> + (a, SymSequence sym a) -> + PartialT sym IO (a, SymSequence sym a) + f' c x y = PartialT $ \_ p -> PE p <$> + do h <- mux c (fst x) (fst y) + tl <- muxSymSequence sym c (snd x) (snd y) + pure (h, tl) + + f :: (SymSequence sym a -> IO (SeqUncons sym a)) -> (SymSequence sym a -> IO (SeqUncons sym a)) + f _loop SymSequenceNil = pure (SeqUncons Unassigned) + f _loop (SymSequenceCons _ v tl) = pure (SeqUncons (justPartExpr sym (v, tl))) + f loop (SymSequenceMerge _ p xs ys) = + do ux <- getSeqUncons <$> loop xs + uy <- getSeqUncons <$> loop ys + SeqUncons <$> mergePartial sym f' p ux uy + + f loop (SymSequenceAppend _ xs ys) = + loop xs >>= \case + SeqUncons Unassigned -> loop ys + SeqUncons (PE px ux) + | Just True <- asConstantPred px -> + do t <- appendSymSequence sym (snd ux) ys + pure (SeqUncons (PE px (fst ux, t))) + + | otherwise -> + loop ys >>= \case + SeqUncons Unassigned -> pure (SeqUncons (PE px ux)) + SeqUncons (PE py uy) -> + do p <- orPred sym px py + t <- appendSymSequence sym (snd ux) ys + let ux' = (fst ux, t) + SeqUncons <$> runPartialT sym p (f' px ux' uy) + +newtype SeqTail sym tp = + SeqTail + { getSeqTail :: PartExpr (Pred sym) (SymSequence sym tp) } + +-- | Compute the tail of a sequence, if it has one +tailSymSequence :: forall sym a. + IsExprBuilder sym => + sym -> + SymSequence sym a -> + IO (PartExpr (Pred sym) (SymSequence sym a)) +tailSymSequence sym = \s -> getSeqTail <$> evalWithFreshCache f s + where + f' :: Pred sym -> + SymSequence sym a -> + SymSequence sym a -> + PartialT sym IO (SymSequence sym a) + f' c x y = PartialT $ \_ p -> PE p <$> muxSymSequence sym c x y + + f :: (SymSequence sym a -> IO (SeqTail sym a)) -> (SymSequence sym a -> IO (SeqTail sym a)) + f _loop SymSequenceNil = pure (SeqTail Unassigned) + f _loop (SymSequenceCons _ _v tl) = pure (SeqTail (justPartExpr sym tl)) + f loop (SymSequenceMerge _ p xs ys) = + do tx <- getSeqTail <$> loop xs + ty <- getSeqTail <$> loop ys + SeqTail <$> mergePartial sym f' p tx ty + f loop (SymSequenceAppend _ xs ys) = + loop xs >>= \case + SeqTail Unassigned -> loop ys + SeqTail (PE px tx) + | Just True <- asConstantPred px -> + do t <- appendSymSequence sym tx ys + pure (SeqTail (PE px t)) + + | otherwise -> + loop ys >>= \case + SeqTail Unassigned -> pure (SeqTail (PE px tx)) + SeqTail (PE py ty) -> + do p <- orPred sym px py + t <- appendSymSequence sym tx ys + SeqTail <$> runPartialT sym p (f' px t ty) + + +{-# SPECIALIZE + traverseSymSequence :: + sym -> + (a -> IO b) -> + SymSequence sym a -> + IO (SymSequence sym b) + #-} + +-- | Visit every element in the given symbolic sequence, +-- applying the given action, and constructing a new +-- sequence. The traversal is memoized, so any given +-- subsequence will be visited at most once. +traverseSymSequence :: forall m sym a b. + MonadIO m => + sym -> + (a -> m b) -> + SymSequence sym a -> + m (SymSequence sym b) +traverseSymSequence sym f = \s -> getConst <$> evalWithFreshCache fn s + where + fn :: (SymSequence sym a -> m (Const (SymSequence sym b) a)) -> + (SymSequence sym a -> m (Const (SymSequence sym b) a)) + fn _loop SymSequenceNil = pure (Const SymSequenceNil) + fn loop (SymSequenceCons _ v tl) = + do v' <- f v + tl' <- getConst <$> loop tl + liftIO (Const <$> consSymSequence sym v' tl') + fn loop (SymSequenceAppend _ xs ys) = + do xs' <- getConst <$> loop xs + ys' <- getConst <$> loop ys + liftIO (Const <$> appendSymSequence sym xs' ys') + fn loop (SymSequenceMerge _ p xs ys) = + do xs' <- getConst <$> loop xs + ys' <- getConst <$> loop ys + liftIO (Const <$> muxSymSequence sym p xs' ys') + + +-- | Using the given evaluation function for booleans, and an evaluation +-- function for values, compute a concrete sequence corresponding +-- to the given symbolic sequence. +concreteizeSymSequence :: + (Pred sym -> IO Bool) {- ^ evaluation for booleans -} -> + (a -> IO b) {- ^ evaluation for values -} -> + SymSequence sym a -> IO [b] +concreteizeSymSequence conc eval = loop + where + loop SymSequenceNil = pure [] + loop (SymSequenceCons _ v tl) = (:) <$> eval v <*> loop tl + loop (SymSequenceAppend _ xs ys) = (++) <$> loop xs <*> loop ys + loop (SymSequenceMerge _ p xs ys) = + do b <- conc p + if b then loop xs else loop ys + +instance (IsExpr (SymExpr sym), PP.Pretty a) => PP.Pretty (SymSequence sym a) where + pretty = prettySymSequence PP.pretty + +-- | Given a pretty printer for elements, +-- print a symbolic sequence. +prettySymSequence :: IsExpr (SymExpr sym) => + (a -> Doc ann) -> + SymSequence sym a -> + Doc ann +prettySymSequence ppa s = if Map.null bs then x else letlayout + where + occMap = computeOccMap s mempty + (x,bs) = runState (prettyAux ppa occMap s) mempty + letlayout = PP.vcat + ["let" PP.<+> (PP.align (PP.vcat [ letbind n d | (n,d) <- Map.toList bs ])) + ," in" PP.<+> x + ] + letbind n d = ppSeqNonce n PP.<+> "=" PP.<+> PP.align d + +computeOccMap :: + SymSequence sym a -> + Map (Nonce GlobalNonceGenerator a) Integer -> + Map (Nonce GlobalNonceGenerator a) Integer +computeOccMap = loop + where + visit n k m + | Just i <- Map.lookup n m = Map.insert n (i+1) m + | otherwise = k (Map.insert n 1 m) + + loop SymSequenceNil = id + loop (SymSequenceCons n _ tl) = visit n (loop tl) + loop (SymSequenceAppend n xs ys) = visit n (loop xs . loop ys) + loop (SymSequenceMerge n _ xs ys) = visit n (loop xs . loop ys) + +ppSeqNonce :: Nonce GlobalNonceGenerator a -> Doc ann +ppSeqNonce n = "s" <> PP.viaShow (indexValue n) + +prettyAux :: + IsExpr (SymExpr sym) => + (a -> Doc ann) -> + Map (Nonce GlobalNonceGenerator a) Integer -> + SymSequence sym a -> + State (Map (Nonce GlobalNonceGenerator a) (Doc ann)) (Doc ann) +prettyAux ppa occMap = goTop + where + goTop SymSequenceNil = pure (PP.list []) + goTop (SymSequenceCons _ v tl) = pp [] [v] [tl] + goTop (SymSequenceAppend _ xs ys) = pp [] [] [xs,ys] + goTop (SymSequenceMerge _ p xs ys) = + do xd <- pp [] [] [xs] + yd <- pp [] [] [ys] + pure $ {- PP.group $ -} PP.hang 2 $ PP.vsep + [ "if" PP.<+> printSymExpr p + , "then" PP.<+> xd + , "else" PP.<+> yd + ] + + visit n s = + do dm <- get + case Map.lookup n dm of + Just _ -> return () + Nothing -> + do d <- goTop s + modify (Map.insert n d) + return (ppSeqNonce n) + + finalize [] = PP.list [] + finalize [x] = x + finalize xs = PP.sep (PP.punctuate (PP.space <> "<>") (reverse xs)) + + elemSeq rs = PP.list (map ppa (reverse rs)) + + addSeg segs [] seg = (seg : segs) + addSeg segs rs seg = (seg : elemSeq rs : segs) + + -- @pp@ accumulates both "segments" of sequences (segs) + -- and individual values (rs) to be output. Both are + -- in reversed order. Segments represent sequences + -- and must be combined with the append operator, + -- and rs represent individual elements that must be combined + -- with cons (or, in actuality, list syntax with brackets and commas). + + -- @pp@ works over a list of SymSequence values, which represent a worklist + -- of segments to process. Morally, the invariant of @pp@ is that the + -- arguments always represent the same sequence, which is computed as + -- @concat (reverse segs) ++ reverse rs ++ concat ss@ + + pp segs [] [] = pure (finalize segs) + pp segs rs [] = pure (finalize ( elemSeq rs : segs )) + + pp segs rs (SymSequenceNil:ss) = pp segs rs ss + + pp segs rs (s@(SymSequenceCons n v tl) : ss) + | Just i <- Map.lookup n occMap, i > 1 + = do x <- visit n s + pp (addSeg segs rs x) [] ss + + | otherwise + = pp segs (v : rs) (tl : ss) + + pp segs rs (s@(SymSequenceAppend n xs ys) : ss) + | Just i <- Map.lookup n occMap, i > 1 + = do x <- visit n s + pp (addSeg segs rs x) [] ss + + | otherwise + = pp segs rs (xs:ys:ss) + + pp segs rs (s@(SymSequenceMerge n _ _ _) : ss) + = do x <- visit n s + pp (addSeg segs rs x) [] ss diff --git a/crucible/src/Lang/Crucible/Types.hs b/crucible/src/Lang/Crucible/Types.hs index 6a89ac324..6fda59456 100644 --- a/crucible/src/Lang/Crucible/Types.hs +++ b/crucible/src/Lang/Crucible/Types.hs @@ -64,6 +64,7 @@ module Lang.Crucible.Types , RecursiveType , IntrinsicType , VectorType + , SequenceType , StructType , VariantType , ReferenceType @@ -165,8 +166,17 @@ data CrucibleType where -- The Maybe type lifted into crucible expressions MaybeType :: CrucibleType -> CrucibleType - -- A finite (one-dimensional) sequence of values + + -- A finite (one-dimensional) sequence of values. Vectors are + -- optimized for random-access indexing and updating. Vectors + -- of different lengths may not be combined at join points. VectorType :: CrucibleType -> CrucibleType + + -- Sequences of values, represented as linked lists of cons cells. Sequences + -- only allow access to the front. Unlike Vectors, sequences of + -- different lengths may be combined at join points. + SequenceType :: CrucibleType -> CrucibleType + -- A structure is an aggregate type consisting of a sequence of values. -- The type of each value is known statically. StructType :: Ctx CrucibleType -> CrucibleType @@ -183,7 +193,7 @@ data CrucibleType where WordMapType :: Nat -> BaseType -> CrucibleType -- Named recursive types, named by the given symbol. To use recursive types - -- you must provide an instances of the IsRecursiveType class that gives + -- you must provide an instance of the IsRecursiveType class that gives -- the unfolding of this recursive type. The RollRecursive and UnrollRecursive -- operations witness the isomorphism between a recursive type and its one-step -- unrolling. Similar to Haskell's newtype, recursive types do not necessarily @@ -191,11 +201,13 @@ data CrucibleType where -- is simply a new named type which is isomorphic to its definition. RecursiveType :: Symbol -> Ctx CrucibleType -> CrucibleType - -- Named intrinsic types. Intrinsic types are a way to extend the crucible - -- type system after-the-fact and add new type implementations. - -- Core crucible provides no operations on intrinsic types; they must be provided - -- as built-in override functions. See the `IntrinsicClass` typeclass - -- and the `Intrinsic` type family defined in "Lang.Crucible.Simulator.Intrinsics". + -- Named intrinsic types. Intrinsic types are a way to extend the + -- crucible type system after-the-fact and add new type + -- implementations. Core crucible provides no operations on + -- intrinsic types; they must be provided as built-in override + -- functions, or via the language extension mechanism. See the + -- `IntrinsicClass` typeclass and the `Intrinsic` type family + -- defined in "Lang.Crucible.Simulator.Intrinsics". -- -- The context of crucible types are type arguments to the intrinsic type. IntrinsicType :: Symbol -> Ctx CrucibleType -> CrucibleType @@ -273,9 +285,16 @@ type NatType = 'NatType -- ^ @:: 'CrucibleType'@. -- | A variant is a disjoint union of the types listed in the context. type VariantType = 'VariantType -- ^ @:: 'Ctx' 'CrucibleType' -> 'CrucibleType'@. --- | A finite (one-dimensional) sequence of values. +-- | A finite (one-dimensional) sequence of values. Vectors are +-- optimized for random-access indexing and updating. Vectors +-- of different lengths may not be combined at join points. type VectorType = 'VectorType -- ^ @:: 'CrucibleType' -> 'CrucibleType'@. +-- | Sequences of values, represented as linked lists of cons cells. Sequences +-- only allow access to the front. Unlike Vectors, sequences of +-- different lengths may be combined at join points. +type SequenceType = 'SequenceType -- ^ @:: 'CrucibleType' -> 'CrucibleType'@. + -- | A finite map from bitvector values to the given Crucible type. -- The 'Nat' index gives the width of the bitvector values used to -- index the map. @@ -354,11 +373,12 @@ data TypeRepr (tp::CrucibleType) where -> !(TypeRepr ret) -> TypeRepr (FunctionHandleType ctx ret) - MaybeRepr :: !(TypeRepr tp) -> TypeRepr (MaybeType tp) - VectorRepr :: !(TypeRepr tp) -> TypeRepr (VectorType tp) - StructRepr :: !(CtxRepr ctx) -> TypeRepr (StructType ctx) + MaybeRepr :: !(TypeRepr tp) -> TypeRepr (MaybeType tp) + SequenceRepr:: !(TypeRepr tp) -> TypeRepr (SequenceType tp) + VectorRepr :: !(TypeRepr tp) -> TypeRepr (VectorType tp) + StructRepr :: !(CtxRepr ctx) -> TypeRepr (StructType ctx) VariantRepr :: !(CtxRepr ctx) -> TypeRepr (VariantType ctx) - ReferenceRepr :: TypeRepr a -> TypeRepr (ReferenceType a) + ReferenceRepr :: !(TypeRepr a) -> TypeRepr (ReferenceType a) WordMapRepr :: (1 <= n) => !(NatRepr n) @@ -418,6 +438,9 @@ instance KnownRepr FloatPrecisionRepr ps => KnownRepr TypeRepr (IEEEFloatType ps instance KnownRepr TypeRepr tp => KnownRepr TypeRepr (VectorType tp) where knownRepr = VectorRepr knownRepr +instance KnownRepr TypeRepr tp => KnownRepr TypeRepr (SequenceType tp) where + knownRepr = SequenceRepr knownRepr + instance KnownRepr TypeRepr tp => KnownRepr TypeRepr (MaybeType tp) where knownRepr = MaybeRepr knownRepr