@@ -44,7 +44,7 @@ import Data.List (foldl1', intersperse, partition)
44
44
import Data.List.NonEmpty qualified as NonEmpty
45
45
import Data.Map (Map )
46
46
import Data.Map qualified as Map
47
- import Data.Maybe (catMaybes , fromMaybe )
47
+ import Data.Maybe (catMaybes , fromMaybe , mapMaybe )
48
48
import Data.Sequence (Seq (.. ), pattern (:<|) )
49
49
import Data.Sequence qualified as Seq
50
50
import Data.Set (Set )
@@ -161,17 +161,20 @@ data EquationState = EquationState
161
161
, cache :: SimplifierCache
162
162
}
163
163
164
- data SimplifierCache = SimplifierCache { llvm , equations :: Map Term Term }
164
+ data SimplifierCache = SimplifierCache { llvm , equations , pathConditions :: Map Term Term }
165
165
deriving stock (Show )
166
166
167
167
instance Semigroup SimplifierCache where
168
168
cache1 <> cache2 =
169
- SimplifierCache (cache1. llvm <> cache2. llvm) (cache1. equations <> cache2. equations)
169
+ SimplifierCache
170
+ (cache1. llvm <> cache2. llvm)
171
+ (cache1. equations <> cache2. equations)
172
+ (cache1. pathConditions <> cache2. pathConditions)
170
173
171
174
instance Monoid SimplifierCache where
172
- mempty = SimplifierCache mempty mempty
175
+ mempty = SimplifierCache mempty mempty mempty
173
176
174
- data CacheTag = LLVM | Equations
177
+ data CacheTag = LLVM | Equations | PathConditions
175
178
deriving stock (Show )
176
179
177
180
instance ContextFor CacheTag where
@@ -192,9 +195,27 @@ startState cache known =
192
195
, recursionStack = []
193
196
, changed = False
194
197
, predicates = known
195
- , cache
198
+ , -- replacements from predicates are rebuilt from the path conditions every time
199
+ cache = cache{pathConditions = buildReplacements known}
196
200
}
197
201
202
+ buildReplacements :: Set Predicate -> Map Term Term
203
+ buildReplacements = Map. fromList . mapMaybe toReplacement . Set. elems
204
+ where
205
+ toReplacement :: Predicate -> Maybe (Term , Term )
206
+ toReplacement = \ case
207
+ Predicate (EqualsInt (v@ DomainValue {}) t) -> Just (t, v)
208
+ Predicate (EqualsInt t (v@ DomainValue {})) -> Just (t, v)
209
+ Predicate (EqualsBool (v@ DomainValue {}) t) -> Just (t, v)
210
+ Predicate (EqualsBool t (v@ DomainValue {})) -> Just (t, v)
211
+ _otherwise -> Nothing
212
+
213
+ cacheReset :: Monad io => EquationT io ()
214
+ cacheReset = eqState $ do
215
+ st@ EquationState {predicates, cache} <- get
216
+ let newCache = cache{equations = mempty , pathConditions = buildReplacements predicates}
217
+ put st{cache = newCache}
218
+
198
219
eqState :: Monad io => StateT EquationState io a -> EquationT io a
199
220
eqState = EquationT . lift . lift
200
221
@@ -237,6 +258,7 @@ popRecursion = do
237
258
else eqState $ put s{recursionStack = tail s. recursionStack}
238
259
239
260
toCache :: LoggerMIO io => CacheTag -> Term -> Term -> EquationT io ()
261
+ toCache PathConditions _ _ = pure () -- never adding to the replacements
240
262
toCache LLVM orig result = eqState . modify $
241
263
\ s -> s{cache = s. cache{llvm = Map. insert orig result s. cache. llvm}}
242
264
toCache Equations orig result = eqState $ do
@@ -261,6 +283,7 @@ fromCache tag t = eqState $ do
261
283
s <- get
262
284
case tag of
263
285
LLVM -> pure $ Map. lookup t s. cache. llvm
286
+ PathConditions -> pure $ Map. lookup t s. cache. pathConditions
264
287
Equations -> do
265
288
case Map. lookup t s. cache. equations of
266
289
Nothing -> pure Nothing
@@ -377,10 +400,14 @@ iterateEquations direction preference startTerm = do
377
400
-- NB llvmSimplify is idempotent. No need to iterate if
378
401
-- the equation evaluation does not change the term any more.
379
402
resetChanged
403
+ -- apply syntactic replacements of terms by domain values from path condition
404
+ replacedTerm <-
405
+ let simp = cached PathConditions $ traverseTerm BottomUp simp pure
406
+ in simp llvmResult
380
407
-- evaluate functions and simplify (recursively at each level)
381
408
newTerm <-
382
409
let simp = cached Equations $ traverseTerm direction simp (applyHooksAndEquations preference)
383
- in simp llvmResult
410
+ in simp replacedTerm
384
411
changeFlag <- getChanged
385
412
if changeFlag
386
413
then checkForLoop newTerm >> resetChanged >> go newTerm
@@ -913,8 +940,7 @@ applyEquation term rule =
913
940
unless (null ensuredConditions) $ do
914
941
withContextFor Equations . logMessage $
915
942
(" New ensured condition from evaluation, invalidating cache" :: Text )
916
- lift . eqState . modify $
917
- \ s -> s{cache = s. cache{equations = mempty }}
943
+ lift cacheReset
918
944
pure $ substituteInTerm subst rule. rhs
919
945
where
920
946
filterOutKnownConstraints :: Set Predicate -> [Predicate ] -> EquationT io [Predicate ]
0 commit comments