Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

crucible-llvm: Refactor and export override pipe-fitting code #1193

Merged
merged 8 commits into from
Mar 27, 2024
1 change: 1 addition & 0 deletions crucible-llvm/crucible-llvm.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ library
Lang.Crucible.LLVM.Extension
Lang.Crucible.LLVM.Globals
Lang.Crucible.LLVM.Intrinsics
Lang.Crucible.LLVM.Intrinsics.Cast
Lang.Crucible.LLVM.Intrinsics.Libc
Lang.Crucible.LLVM.Intrinsics.LLVM
Lang.Crucible.LLVM.MalformedLLVMModule
Expand Down
120 changes: 120 additions & 0 deletions crucible-llvm/src/Lang/Crucible/LLVM/Intrinsics/Cast.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
-- |
-- Module : Lang.Crucible.LLVM.Intrinsics.Cast
-- Description : Cast between bitvectors and pointers in signatures
-- Copyright : (c) Galois, Inc 2024
-- License : BSD3
-- Maintainer : Langston Barrett <[email protected]>
-- Stability : provisional
--
-- The built-in overrides in "Lang.Crucible.LLVM.Intrinsics.Libc" and
-- "Lang.Crucible.LLVM.Intrinsics.LLVM" frequently take arguments of type
-- 'Lang.Crucible.Types.BVType', but at runtime everything is represented as an
-- 'Lang.Crucible.LLVM.MemModel.Pointer.LLVMPtr'. This module contains helpers
-- for \"casting\" between pointers and bitvectors.
------------------------------------------------------------------------

{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Lang.Crucible.LLVM.Intrinsics.Cast
( ValCastError
, printValCastError
, ArgCast(applyArgCast)
, ValCast(applyValCast)
, castLLVMArgs
, castLLVMRet
) where

import Control.Monad.IO.Class (liftIO)
import Control.Lens

import qualified Data.Parameterized.Context as Ctx
import Data.Parameterized.Some (Some(Some))
import Data.Parameterized.TraversableFC (fmapFC)

import Lang.Crucible.Backend
import Lang.Crucible.Simulator.OverrideSim
import Lang.Crucible.Simulator.RegMap
import Lang.Crucible.Types

import Lang.Crucible.LLVM.MemModel

data ValCastError
= -- | Mismatched number of arguments ('castLLVMArgs') or struct fields
-- ('castLLVMRet').
MismatchedShape
-- | Can\'t cast between these types
| ValCastError (Some TypeRepr) (Some TypeRepr)

-- | Turn a 'ValCastError' into a human-readable message (lines).
printValCastError :: ValCastError -> [String]
printValCastError =
\case
MismatchedShape -> ["argument shape mismatch"]
ValCastError (Some ret) (Some ret') ->
[ "Cannot cast types"
, "*** Source type: " ++ show ret
, "*** Target type: " ++ show ret'
]

-- | A function to (infallibly) cast between 'Ctx.Assignment's of 'RegEntry's.
newtype ArgCast p sym ext args args' =
ArgCast { applyArgCast :: (forall rtp l a.
Ctx.Assignment (RegEntry sym) args ->
OverrideSim p sym ext rtp l a (Ctx.Assignment (RegEntry sym) args')) }

-- | A function to (infallibly) cast a value of types @tp@ to @tp'@.
newtype ValCast p sym ext tp tp' =
ValCast { applyValCast :: (forall rtp l a.
RegValue sym tp ->
OverrideSim p sym ext rtp l a (RegValue sym tp')) }

-- | Attempt to construct a function to cast between 'Ctx.Assignment's of
-- 'RegEntry's.
castLLVMArgs :: forall p sym ext bak args args'.
IsSymBackend sym bak =>
bak ->
CtxRepr args' ->
CtxRepr args ->
Either ValCastError (ArgCast p sym ext args args')
castLLVMArgs _ Ctx.Empty Ctx.Empty =
Right (ArgCast (\_ -> return Ctx.Empty))
castLLVMArgs bak (rest' Ctx.:> tp') (rest Ctx.:> tp) =
do ValCast f <- castLLVMRet bak tp tp'
ArgCast fs <- castLLVMArgs bak rest' rest
Right (ArgCast
(\(xs Ctx.:> x) -> do
xs' <- fs xs
x' <- f (regValue x)
pure (xs' Ctx.:> RegEntry tp' x')))
castLLVMArgs _ _ _ = Left MismatchedShape

-- | Attempt to construct a function to cast values of type @ret@ to type
-- @ret'@.
castLLVMRet ::
IsSymBackend sym bak =>
bak ->
TypeRepr ret ->
TypeRepr ret' ->
Either ValCastError (ValCast p sym ext ret ret')
castLLVMRet bak (BVRepr w) (LLVMPointerRepr w')
| Just Refl <- testEquality w w'
= Right (ValCast (liftIO . llvmPointer_bv (backendGetSym bak)))
castLLVMRet bak (LLVMPointerRepr w) (BVRepr w')
| Just Refl <- testEquality w w'
= Right (ValCast (liftIO . projectLLVM_bv bak))
castLLVMRet bak (VectorRepr tp) (VectorRepr tp')
= do ValCast f <- castLLVMRet bak tp tp'
Right (ValCast (traverse f))
castLLVMRet bak (StructRepr ctx) (StructRepr ctx')
= do ArgCast tf <- castLLVMArgs bak ctx' ctx
Right (ValCast (\vals ->
let vals' = Ctx.zipWith (\tp (RV v) -> RegEntry tp v) ctx vals in
fmapFC (\x -> RV (regValue x)) <$> tf vals'))

castLLVMRet _bak ret ret'
| Just Refl <- testEquality ret ret'
= Right (ValCast return)
castLLVMRet _bak ret ret' = Left (ValCastError (Some ret) (Some ret'))
88 changes: 16 additions & 72 deletions crucible-llvm/src/Lang/Crucible/LLVM/Intrinsics/Common.hs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ import qualified System.Info as Info
import qualified ABI.Itanium as ABI
import qualified Data.Parameterized.Context as Ctx
import Data.Parameterized.Some (Some(..))
import Data.Parameterized.TraversableFC (fmapFC)

import Lang.Crucible.Backend
import Lang.Crucible.CFG.Common (GlobalVar)
Expand All @@ -78,6 +77,7 @@ import Lang.Crucible.LLVM.Eval (callStackFromMemVar)
import Lang.Crucible.LLVM.Globals (registerFunPtr)
import Lang.Crucible.LLVM.MemModel
import Lang.Crucible.LLVM.MemModel.CallStack (CallStack)
import qualified Lang.Crucible.LLVM.Intrinsics.Cast as Cast
import Lang.Crucible.LLVM.Translation.Monad
import Lang.Crucible.LLVM.Translation.Types

Expand Down Expand Up @@ -199,74 +199,6 @@ apply this special case to other override functions (e.g.,
------------------------------------------------------------------------
-- ** register_llvm_override

newtype ArgTransformer p sym ext args args' =
ArgTransformer { applyArgTransformer :: (forall rtp l a.
Ctx.Assignment (RegEntry sym) args ->
OverrideSim p sym ext rtp l a (Ctx.Assignment (RegEntry sym) args')) }

newtype ValTransformer p sym ext tp tp' =
ValTransformer { applyValTransformer :: (forall rtp l a.
RegValue sym tp ->
OverrideSim p sym ext rtp l a (RegValue sym tp')) }

transformLLVMArgs :: forall m p sym ext bak args args'.
(IsSymBackend sym bak, Monad m, HasLLVMAnn sym) =>
-- | This function name is only used in panic messages.
FunctionName ->
bak ->
CtxRepr args' ->
CtxRepr args ->
m (ArgTransformer p sym ext args args')
transformLLVMArgs _fnName _ Ctx.Empty Ctx.Empty =
return (ArgTransformer (\_ -> return Ctx.Empty))
transformLLVMArgs fnName bak (rest' Ctx.:> tp') (rest Ctx.:> tp) = do
return (ArgTransformer
(\(xs Ctx.:> x) ->
do (ValTransformer f) <- transformLLVMRet fnName bak tp tp'
(ArgTransformer fs) <- transformLLVMArgs fnName bak rest' rest
xs' <- fs xs
x' <- RegEntry tp' <$> f (regValue x)
pure (xs' Ctx.:> x')))
transformLLVMArgs fnName _ _ _ =
panic "Intrinsics.transformLLVMArgs"
[ "transformLLVMArgs: argument shape mismatch!"
, "in function: " ++ Text.unpack (functionName fnName)
]

transformLLVMRet ::
(IsSymBackend sym bak, Monad m, HasLLVMAnn sym) =>
-- | This function name is only used in panic messages.
FunctionName ->
bak ->
TypeRepr ret ->
TypeRepr ret' ->
m (ValTransformer p sym ext ret ret')
transformLLVMRet _fnName bak (BVRepr w) (LLVMPointerRepr w')
| Just Refl <- testEquality w w'
= return (ValTransformer (liftIO . llvmPointer_bv (backendGetSym bak)))
transformLLVMRet _fnName bak (LLVMPointerRepr w) (BVRepr w')
| Just Refl <- testEquality w w'
= return (ValTransformer (liftIO . projectLLVM_bv bak))
transformLLVMRet fnName bak (VectorRepr tp) (VectorRepr tp')
= do ValTransformer f <- transformLLVMRet fnName bak tp tp'
return (ValTransformer (traverse f))
transformLLVMRet fnName bak (StructRepr ctx) (StructRepr ctx')
= do ArgTransformer tf <- transformLLVMArgs fnName bak ctx' ctx
return (ValTransformer (\vals ->
let vals' = Ctx.zipWith (\tp (RV v) -> RegEntry tp v) ctx vals in
fmapFC (\x -> RV (regValue x)) <$> tf vals'))

transformLLVMRet _fnName _bak ret ret'
| Just Refl <- testEquality ret ret'
= return (ValTransformer return)
transformLLVMRet fnName _bak ret ret'
= panic "Intrinsics.transformLLVMRet"
[ "Cannot transform types"
, "*** Source type: " ++ show ret
, "*** Target type: " ++ show ret'
, "in function: " ++ Text.unpack (functionName fnName)
]

-- | Do some pipe-fitting to match a Crucible override function into the shape
-- expected by the LLVM calling convention. This basically just coerces
-- between values of @BVType w@ and values of @LLVMPointerType w@.
Expand All @@ -283,11 +215,23 @@ build_llvm_override ::
OverrideSim p sym ext rtp l a (Override p sym ext args' ret')
build_llvm_override fnm args ret args' ret' llvmOverride =
ovrWithBackend $ \bak ->
do fargs <- transformLLVMArgs fnm bak args args'
fret <- transformLLVMRet fnm bak ret ret'
do fargs <-
case Cast.castLLVMArgs bak args args' of
Left err ->
panic "Intrinsics.build_llvm_override"
(Cast.printValCastError err ++
[ "in function: " ++ Text.unpack (functionName fnm) ])
Right f -> pure f
fret <-
case Cast.castLLVMRet bak ret ret' of
Left err ->
panic "Intrinsics.build_llvm_override"
(Cast.printValCastError err ++
[ "in function: " ++ Text.unpack (functionName fnm) ])
Right f -> pure f
return $ mkOverride' fnm ret' $
do RegMap xs <- getOverrideArgs
applyValTransformer fret =<< llvmOverride =<< applyArgTransformer fargs xs
Cast.applyValCast fret =<< llvmOverride =<< Cast.applyArgCast fargs xs

polymorphic1_llvm_override :: forall p sym arch wptr l a rtp.
(IsSymInterface sym, HasLLVMAnn sym, HasPtrWidth wptr) =>
Expand Down
Loading