diff --git a/docs/tree_incrementality.py b/docs/tree_incrementality.py new file mode 100644 index 00000000000..a509fb6aec4 --- /dev/null +++ b/docs/tree_incrementality.py @@ -0,0 +1,35 @@ +class TreeIncrementalSolver: + def __init__(self, Solver, max_solvers: int): + self.Solver = Solver + self.max_solvers = max_solvers + self.recently_used = Queue() + self.recycled_solvers = [] + + def reuseOrCreateZ3(self, query): + solver_from_scratch_cost = len(query) + min_cost, min_solver = solver_from_scratch_cost, None + for solver in self.recycled_solvers: + delta = distance(solver.query, query) + if delta < min_cost: + min_cost, min_solver = delta, solver + if min_solver is None: + return self.Solver() + return min_solver + + def find_suitable_solver(self, query): + for solver in self.recently_used: + if solver.query is subsetOf(query): + self.recently_used.remove(solver) + return solver + if len(self.recently_used) < self.max_solvers: + return self.reuseOrCreateZ3(query) + return self.recently_used.pop_back() + + def check_sat(self, query): + solver = self.find_suitable_solver(query) + push(solver, query) + self.recently_used.push_front(solver) + solver.check_sat() + + def recycle_solver(self, id): + raise NotImplementedError() diff --git a/include/klee/Solver/Solver.h b/include/klee/Solver/Solver.h index 30b344efecb..c39d8a2ae39 100644 --- a/include/klee/Solver/Solver.h +++ b/include/klee/Solver/Solver.h @@ -191,6 +191,10 @@ class Solver { virtual char *getConstraintLog(const Query &query); virtual void setCoreSolverTimeout(time::Span timeout); + + /// @brief Notify the solver that the state with specified id has been + /// terminated + void notifyStateTermination(std::uint32_t id); }; /* *** */ @@ -264,6 +268,11 @@ std::unique_ptr createCoreSolver(CoreSolverType cst); std::unique_ptr createConcretizingSolver(std::unique_ptr s, AddressGenerator *addressGenerator); + +/// Return a list of all unique symbolic objects referenced by the +/// given Query. +void findSymbolicObjects(const Query &query, + std::vector &results); } // namespace klee #endif /* KLEE_SOLVER_H */ diff --git a/include/klee/Solver/SolverCmdLine.h b/include/klee/Solver/SolverCmdLine.h index 7fbfac03940..6bd0c285bbd 100644 --- a/include/klee/Solver/SolverCmdLine.h +++ b/include/klee/Solver/SolverCmdLine.h @@ -50,6 +50,8 @@ extern llvm::cl::opt CoreSolverOptimizeDivides; extern llvm::cl::opt UseAssignmentValidatingSolver; +extern llvm::cl::opt MaxSolversApproxTreeInc; + /// The different query logging solvers that can be switched on/off enum QueryLoggingSolverType { ALL_KQUERY, ///< Log all queries in .kquery (KQuery) format @@ -65,6 +67,7 @@ enum CoreSolverType { METASMT_SOLVER, DUMMY_SOLVER, Z3_SOLVER, + Z3_TREE_SOLVER, NO_SOLVER }; diff --git a/include/klee/Solver/SolverImpl.h b/include/klee/Solver/SolverImpl.h index bd5937ed11e..c3d7ed40a2f 100644 --- a/include/klee/Solver/SolverImpl.h +++ b/include/klee/Solver/SolverImpl.h @@ -22,7 +22,7 @@ class ExecutionState; class Expr; struct Query; -/// SolverImpl - Abstract base clase for solver implementations. +/// SolverImpl - Abstract base class for solver implementations. class SolverImpl { public: SolverImpl() = default; @@ -119,6 +119,8 @@ class SolverImpl { } virtual void setCoreSolverTimeout(time::Span timeout){}; + + virtual void notifyStateTermination(std::uint32_t id){}; }; } // namespace klee diff --git a/include/klee/Solver/SolverUtil.h b/include/klee/Solver/SolverUtil.h index d61c188cc0b..e893b811967 100644 --- a/include/klee/Solver/SolverUtil.h +++ b/include/klee/Solver/SolverUtil.h @@ -41,6 +41,9 @@ enum class Validity { True = 1, False = -1, Unknown = 0 }; struct SolverQueryMetaData { /// @brief Costs for all queries issued for this state time::Span queryCost; + + /// @brief Caller state id + std::uint32_t id = 0; }; struct Query { diff --git a/lib/Core/ExecutionState.cpp b/lib/Core/ExecutionState.cpp index 90a2c9ac457..11ce0390f43 100644 --- a/lib/Core/ExecutionState.cpp +++ b/lib/Core/ExecutionState.cpp @@ -179,7 +179,9 @@ ExecutionState::ExecutionState(const ExecutionState &state) returnValue(state.returnValue), gepExprBases(state.gepExprBases), prevTargets_(state.prevTargets_), targets_(state.targets_), prevHistory_(state.prevHistory_), history_(state.history_), - isTargeted_(state.isTargeted_) {} + isTargeted_(state.isTargeted_) { + queryMetaData.id = state.id; +} ExecutionState *ExecutionState::branch() { depth++; diff --git a/lib/Core/ExecutionState.h b/lib/Core/ExecutionState.h index e0d432239bf..fb811cfb8a6 100644 --- a/lib/Core/ExecutionState.h +++ b/lib/Core/ExecutionState.h @@ -438,7 +438,10 @@ class ExecutionState { bool visited(KBlock *block) const; std::uint32_t getID() const { return id; }; - void setID() { id = nextID++; }; + void setID() { + id = nextID++; + queryMetaData.id = id; + }; llvm::BasicBlock *getInitPCBlock() const; llvm::BasicBlock *getPrevPCBlock() const; llvm::BasicBlock *getPCBlock() const; diff --git a/lib/Core/TimingSolver.cpp b/lib/Core/TimingSolver.cpp index edae9a3a5b9..f241b804796 100644 --- a/lib/Core/TimingSolver.cpp +++ b/lib/Core/TimingSolver.cpp @@ -351,3 +351,7 @@ TimingSolver::getRange(const ConstraintSet &constraints, ref expr, metaData.queryCost += timer.delta(); return result; } + +void TimingSolver::notifyStateTermination(std::uint32_t id) { + solver->notifyStateTermination(id); +} diff --git a/lib/Core/TimingSolver.h b/lib/Core/TimingSolver.h index fbf65c4a56c..7692d503778 100644 --- a/lib/Core/TimingSolver.h +++ b/lib/Core/TimingSolver.h @@ -49,6 +49,10 @@ class TimingSolver { return solver->getConstraintLog(query); } + /// @brief Notify the solver that the state with specified id has been + /// terminated + void notifyStateTermination(std::uint32_t id); + bool evaluate(const ConstraintSet &, ref, PartialValidity &result, SolverQueryMetaData &metaData, bool produceValidityCore = false); diff --git a/lib/Solver/AssignmentValidatingSolver.cpp b/lib/Solver/AssignmentValidatingSolver.cpp index dde0f1a23d5..486a7f4a725 100644 --- a/lib/Solver/AssignmentValidatingSolver.cpp +++ b/lib/Solver/AssignmentValidatingSolver.cpp @@ -142,14 +142,8 @@ bool AssignmentValidatingSolver::check(const Query &query, return true; } - ExprHashSet expressions; - assert(!query.containsSymcretes()); - expressions.insert(query.constraints.cs().begin(), - query.constraints.cs().end()); - expressions.insert(query.expr); - std::vector objects; - findSymbolicObjects(expressions.begin(), expressions.end(), objects); + findSymbolicObjects(query, objects); std::vector> values; assert(isa(result)); diff --git a/lib/Solver/CexCachingSolver.cpp b/lib/Solver/CexCachingSolver.cpp index da3303359de..6f631ddef14 100644 --- a/lib/Solver/CexCachingSolver.cpp +++ b/lib/Solver/CexCachingSolver.cpp @@ -286,17 +286,9 @@ bool CexCachingSolver::computeValidity(const Query &query, PartialValidity &result) { TimerStatIncrementer t(stats::cexCacheTime); ref a; - if (!getResponse(query.withFalse(), a)) + ref q; + if (!computeValue(query, q)) return false; - assert(isa(a) && "computeValidity() must have assignment"); - - ref q = cast(a)->evaluate(query.expr); - - if (!isa(q) && solver->impl->computeValue(query, q)) - return false; - - assert(isa(q) && - "assignment evaluation did not result in constant"); if (cast(q)->isTrue()) { if (!getResponse(query, a)) @@ -343,10 +335,13 @@ bool CexCachingSolver::computeValue(const Query &query, ref &result) { TimerStatIncrementer t(stats::cexCacheTime); ref a; - if (!getResponse(query.withFalse(), a)) - return false; - assert(isa(a) && "computeValue() must have assignment"); - result = cast(a)->evaluate(query.expr); + result = ConstantExpr::create(1, query.expr->getWidth()); + if (!query.constraints.cs().empty()) { + if (!getResponse(query.withFalse(), a)) + return false; + assert(isa(a) && "computeValue() must have assignment"); + result = cast(a)->evaluate(query.expr); + } if (!isa(result) && solver->impl->computeValue(query, result)) return false; diff --git a/lib/Solver/CoreSolver.cpp b/lib/Solver/CoreSolver.cpp index b6de1024b86..e6d54a8c4d7 100644 --- a/lib/Solver/CoreSolver.cpp +++ b/lib/Solver/CoreSolver.cpp @@ -27,7 +27,14 @@ DISABLE_WARNING_POP namespace klee { +std::unique_ptr createZ3Solver(bool isTreeSolver, Z3BuilderType type) { + if (isTreeSolver) + return std::make_unique(type, MaxSolversApproxTreeInc); + return std::make_unique(type); +} + std::unique_ptr createCoreSolver(CoreSolverType cst) { + bool isTreeSolver = false; switch (cst) { case STP_SOLVER: #ifdef ENABLE_STP @@ -54,15 +61,17 @@ std::unique_ptr createCoreSolver(CoreSolverType cst) { #endif case DUMMY_SOLVER: return createDummySolver(); + case Z3_TREE_SOLVER: + isTreeSolver = true; case Z3_SOLVER: #ifdef ENABLE_Z3 klee_message("Using Z3 solver backend"); #ifdef ENABLE_FP klee_message("Using Z3 bitvector builder"); - return std::make_unique(KLEE_BITVECTOR); + return createZ3Solver(isTreeSolver, KLEE_BITVECTOR); #else klee_message("Using Z3 core builder"); - return std::make_unique(KLEE_CORE); + return createZ3Solver(isTreeSolver, KLEE_CORE); #endif #else klee_message("Not compiled with Z3 support"); diff --git a/lib/Solver/IncompleteSolver.cpp b/lib/Solver/IncompleteSolver.cpp index 5fd7c74b3b2..2262423a561 100644 --- a/lib/Solver/IncompleteSolver.cpp +++ b/lib/Solver/IncompleteSolver.cpp @@ -119,13 +119,8 @@ bool StagedSolverImpl::computeInitialValues( } bool StagedSolverImpl::check(const Query &query, ref &result) { - ExprHashSet expressions; - expressions.insert(query.constraints.cs().begin(), - query.constraints.cs().end()); - expressions.insert(query.expr); - std::vector objects; - findSymbolicObjects(expressions.begin(), expressions.end(), objects); + findSymbolicObjects(query, objects); std::vector> values; bool hasSolution; diff --git a/lib/Solver/Solver.cpp b/lib/Solver/Solver.cpp index ec30aa9b19f..4aeff5adc8c 100644 --- a/lib/Solver/Solver.cpp +++ b/lib/Solver/Solver.cpp @@ -162,6 +162,10 @@ bool Solver::check(const Query &query, ref &queryResult) { return impl->check(query, queryResult); } +void Solver::notifyStateTermination(std::uint32_t id) { + impl->notifyStateTermination(id); +} + static std::pair, ref> getDefaultRange() { return std::make_pair(ConstantExpr::create(0, 64), ConstantExpr::create(0, 64)); @@ -327,6 +331,15 @@ bool Query::containsSizeSymcretes() const { return false; } +void klee::findSymbolicObjects(const Query &query, + std::vector &results) { + ExprHashSet expressions; + expressions.insert(query.constraints.cs().begin(), + query.constraints.cs().end()); + expressions.insert(query.expr); + findSymbolicObjects(expressions.begin(), expressions.end(), results); +} + void Query::dump() const { constraints.dump(); llvm::errs() << "Query [\n"; diff --git a/lib/Solver/SolverCmdLine.cpp b/lib/Solver/SolverCmdLine.cpp index 0c96fb12634..0f51525d535 100644 --- a/lib/Solver/SolverCmdLine.cpp +++ b/lib/Solver/SolverCmdLine.cpp @@ -123,6 +123,13 @@ cl::opt UseAssignmentValidatingSolver( cl::desc("Debug the correctness of generated assignments (default=false)"), cl::cat(SolvingCat)); +cl::opt + MaxSolversApproxTreeInc("max-solvers-approx-tree-inc", + cl::desc("Maximum size of the Z3 solver pool for " + "approximating tree incrementality." + " Set to 0 to disable (default=0)"), + cl::init(0), cl::cat(SolvingCat)); + void KCommandLine::HideOptions(llvm::cl::OptionCategory &Category) { StringMap &map = cl::getRegisteredOptions(); @@ -196,11 +203,12 @@ cl::opt MetaSMTBackend( cl::opt CoreSolverToUse( "solver-backend", cl::desc("Specifiy the core solver backend to use"), - cl::values(clEnumValN(STP_SOLVER, "stp", "STP" STP_IS_DEFAULT_STR), - clEnumValN(METASMT_SOLVER, "metasmt", - "metaSMT" METASMT_IS_DEFAULT_STR), - clEnumValN(DUMMY_SOLVER, "dummy", "Dummy solver"), - clEnumValN(Z3_SOLVER, "z3", "Z3" Z3_IS_DEFAULT_STR)), + cl::values( + clEnumValN(STP_SOLVER, "stp", "STP" STP_IS_DEFAULT_STR), + clEnumValN(METASMT_SOLVER, "metasmt", "metaSMT" METASMT_IS_DEFAULT_STR), + clEnumValN(DUMMY_SOLVER, "dummy", "Dummy solver"), + clEnumValN(Z3_SOLVER, "z3", "Z3" Z3_IS_DEFAULT_STR), + clEnumValN(Z3_TREE_SOLVER, "z3-tree", "Z3 tree-incremental solver")), cl::init(DEFAULT_CORE_SOLVER), cl::cat(SolvingCat)); cl::opt DebugCrossCheckCoreSolverWith( diff --git a/lib/Solver/SolverImpl.cpp b/lib/Solver/SolverImpl.cpp index 033f1d6d192..2768ce3ce9a 100644 --- a/lib/Solver/SolverImpl.cpp +++ b/lib/Solver/SolverImpl.cpp @@ -42,13 +42,8 @@ bool SolverImpl::computeValidity(const Query &query, } bool SolverImpl::check(const Query &query, ref &result) { - ExprHashSet expressions; - expressions.insert(query.constraints.cs().begin(), - query.constraints.cs().end()); - expressions.insert(query.expr); - std::vector objects; - findSymbolicObjects(expressions.begin(), expressions.end(), objects); + findSymbolicObjects(query, objects); std::vector> values; bool hasSolution; diff --git a/lib/Solver/Z3Solver.cpp b/lib/Solver/Z3Solver.cpp index cfcb5d6b574..a3ff7434935 100644 --- a/lib/Solver/Z3Solver.cpp +++ b/lib/Solver/Z3Solver.cpp @@ -12,7 +12,12 @@ #include "klee/Support/FileHandling.h" #include "klee/Support/OptionCategories.h" +#include #include +#include +#include +#include +#include #ifdef ENABLE_Z3 @@ -30,9 +35,6 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/raw_ostream.h" -#include -#include - namespace { // NOTE: Very useful for debugging Z3 behaviour. These files can be given to // the z3 binary to replay all Z3 API calls using its `-log` option. @@ -67,20 +69,301 @@ DISABLE_WARNING_POP namespace klee { -class Z3SolverImpl : public SolverImpl { +typedef std::unordered_set FrameIds; + +template +void extend(std::vector<_Tp, _Alloc> &ths, + const std::vector<_Tp, _Alloc> &other) { + ths.reserve(ths.size() + other.size()); + ths.insert(ths.end(), other.begin(), other.end()); +} + +template > class inc_vector; +typedef inc_vector> ConstraintFrames; + +template class inc_vector { + friend void dump(const ConstraintFrames &); + +public: + /// It is public, so that all vector operations are supported + /// Everything pushed to v is pushed to the last frame + std::vector<_Tp, _Alloc> v; + + using iterator = std::vector::const_iterator; + private: + std::vector frame_sizes; + // v.size() == sum(frame_sizes) + size of the fresh frame + + size_t freshFrameSize() const { + return v.size() - + std::accumulate(frame_sizes.begin(), frame_sizes.end(), 0); + } + + void take(size_t n, size_t &frames_count, size_t &frame_index) const { + size_t i = 0; + size_t c = n; + for (; i < frame_sizes.size(); i++) { + if (frame_sizes[i] > c) + break; + c -= frame_sizes[i]; + } + frames_count = c; + frame_index = i; + } + +public: + inc_vector() {} + inc_vector(const std::vector<_Tp> &constraints) : v(constraints) {} + + iterator begin() const { return frame_sizes.cbegin(); } + iterator end() const { return frame_sizes.cend(); } + + void pop(size_t popFrames) { + assert(freshFrameSize() == 0); + size_t toPop = + std::accumulate(frame_sizes.end() - popFrames, frame_sizes.end(), 0); + v.resize(v.size() - toPop); + frame_sizes.resize(frame_sizes.size() - popFrames); + } + + void push() { + auto freshSize = freshFrameSize(); + frame_sizes.push_back(freshSize); + assert(freshFrameSize() == 0); + } + + /// ensures that last frame is empty + void extend(const std::vector<_Tp, _Alloc> &other) { + assert(freshFrameSize() == 0); + // push(); + klee::extend(v, other); + push(); + } + + void takeAfter(size_t n, inc_vector<_Tp, _Alloc> &result) const { + size_t frames_count, frame_index; + take(n, frames_count, frame_index); + result = *this; + std::vector<_Tp, _Alloc>(result.v.begin() + n, result.v.end()) + .swap(result.v); + std::vector(result.frame_sizes.begin() + frame_index, + result.frame_sizes.end()) + .swap(result.frame_sizes); + if (frames_count) + result.frame_sizes[0] -= frames_count; + } + + void takeBefore(size_t n, size_t &toPop, size_t &takeFromOther) const { + take(n, takeFromOther, toPop); + toPop = frame_sizes.size() - toPop; + } +}; + +void dump(const ConstraintFrames &frames) { + llvm::errs() << "frame sizes:"; + for (auto size : frames.frame_sizes) { + llvm::errs() << " " << size; + } + llvm::errs() << "\n"; + llvm::errs() << "frames:\n"; + for (auto &x : frames.v) { + llvm::errs() << x->toString() << "\n"; + } +} + +template , + typename _Pred = std::equal_to<_Value>, + typename _Alloc = std::allocator<_Value>> +class inc_uset { +private: + using setT = std::unordered_set<_Value, _Hash, _Pred, _Alloc>; + using citerator = typename setT::const_iterator; + using idMap = std::unordered_map<_Value, FrameIds, _Hash, _Pred, _Alloc>; + setT set; + idMap ids; + size_t current_frame; + +public: + citerator cbegin() const { return set.cbegin(); } + citerator cend() const { return set.cend(); } + + void insert(const _Value &v) { + set.insert(v); + ids[v].insert(current_frame); + } + + void pop(size_t popFrames) { + current_frame -= popFrames; + idMap newIdMap; + for (auto &keyAndIds : ids) { + FrameIds newIds; + for (auto id : keyAndIds.second) + if (id <= current_frame) + newIds.insert(id); + if (newIds.empty()) + set.erase(keyAndIds.first); + else + newIdMap.insert(std::make_pair(keyAndIds.first, newIds)); + } + ids = newIdMap; + } + + void push() { current_frame++; } +}; + +template , + typename _Pred = std::equal_to<_Key>, + typename _Alloc = std::allocator>> +class inc_umap { +private: + std::unordered_map<_Key, _Tp, _Hash, _Pred, _Alloc> map; + using idMap = std::unordered_map<_Key, FrameIds, _Hash, _Pred, _Alloc>; + idMap ids; + size_t current_frame = 0; + +public: + void insert(const std::pair<_Key, _Tp> &pair) { + map.insert(pair); + ids[pair.first].insert(current_frame); + } + + _Tp &operator[](const _Key &key) { + ids[key].insert(current_frame); + return map[key]; + } + + size_t count(const _Key &key) const { return map.count(key); } + + const _Tp &at(_Key &key) const { return map.at(key); } + + void pop(size_t popFrames) { + current_frame -= popFrames; + idMap newIdMap; + for (auto &keyAndIds : ids) { + FrameIds newIds; + for (auto id : keyAndIds.second) + if (id <= current_frame) + newIds.insert(id); + if (newIds.empty()) + map.erase(keyAndIds.first); + else + newIdMap.insert(std::make_pair(keyAndIds.first, newIds)); + } + ids = newIdMap; + } + + void push() { current_frame++; } +}; + +class ConstraintQuery { +private: + // this should be used when only query is needed, se comment below + ref expr; + +public: + // KLEE Queries are validity queries i.e. + // ∀ X Constraints(X) → query(X) + // but Z3 works in terms of satisfiability so instead we ask the + // negation of the equivalent i.e. + // ∃ X Constraints(X) ∧ ¬ query(X) + // so this `constraints` field contains: Constraints(X) ∧ ¬ query(X) + ConstraintFrames constraints; + + explicit ConstraintQuery() {} + + ConstraintQuery(ConstraintFrames &frames) + : expr(Expr::createIsZero(frames.v.back())), constraints(frames) {} + + ConstraintQuery(const Query &q) : expr(q.expr) { + for (auto &constraint : q.constraints.cs()) { + constraints.v.push_back(constraint); + constraints.push(); + } + if (!q.expr->isFalse()) + constraints.v.push_back(Expr::createIsZero(q.expr)); + return; + } + + size_t size() const { return constraints.v.size(); } + + ref getOriginalQueryExpr() const { return expr; } + + std::vector gatherArrays() const; +}; + +std::vector ConstraintQuery::gatherArrays() const { + std::vector arrays; + findObjects(constraints.v.begin(), constraints.v.end(), arrays); + return arrays; +} + +void findSymbolicObjects(const ConstraintQuery &cf, + std::vector &results) { + ExprHashSet expressions; + expressions.insert(cf.constraints.v.begin(), cf.constraints.v.end()); + findSymbolicObjects(expressions.begin(), expressions.end(), results); +} + +class Z3SolverEnv { +public: + inc_vector objects; + inc_vector z3_ast_expr_constraints; + inc_umap, Z3ASTHandleHash, Z3ASTHandleCmp> + z3_ast_expr_to_klee_expr; + inc_umap + expr_to_track; + inc_umap usedArrayBytes; + + explicit Z3SolverEnv(){}; + explicit Z3SolverEnv(const std::vector &objects); + + void pop(size_t popSize); + void push(); +}; + +Z3SolverEnv::Z3SolverEnv(const std::vector &objects) + : objects(objects) {} + +void Z3SolverEnv::pop(size_t popSize) { + objects.pop(popSize); + z3_ast_expr_constraints.pop(popSize); + z3_ast_expr_to_klee_expr.pop(popSize); + expr_to_track.pop(popSize); + usedArrayBytes.pop(popSize); +} + +void Z3SolverEnv::push() { + objects.push(); + z3_ast_expr_constraints.push(); + z3_ast_expr_to_klee_expr.push(); + expr_to_track.push(); + usedArrayBytes.push(); +} + +class Z3SolverImpl : public SolverImpl { +protected: std::unique_ptr builder; + ::Z3_params solverParameters; + +private: + static size_t solvedConstraints; + Z3BuilderType builderType; time::Span timeout; SolverRunStatus runStatusCode; std::unique_ptr dumpedQueriesFile; - ::Z3_params solverParameters; // Parameter symbols ::Z3_symbol timeoutParamStrSymbol; ::Z3_symbol unsatCoreParamStrSymbol; - bool internalRunSolver(const Query &, - const std::vector *objects, +public: + virtual Z3_solver initNativeZ3(const ConstraintQuery &query, Z3_probe probe, + Z3_goal goal) = 0; + virtual void deinitNativeZ3(Z3_solver theSolver) = 0; + +private: + bool internalRunSolver(const ConstraintQuery &query, Z3SolverEnv &env, std::vector> *values, ValidityCore *validityCore, bool &hasSolution); bool validateZ3Model(::Z3_solver &theSolver, ::Z3_model &theModel); @@ -115,17 +398,45 @@ class Z3SolverImpl : public SolverImpl { const std::vector &objects, std::vector> &values, bool &hasSolution); - bool check(const Query &query, ref &result); + using SolverImpl::check; + bool check(const ConstraintQuery &query, Z3SolverEnv &env, + ref &result); bool computeValidityCore(const Query &query, ValidityCore &validityCore, bool &isValid); SolverRunStatus handleSolverResponse( - ::Z3_solver theSolver, ::Z3_lbool satisfiable, - const std::vector *objects, - std::vector> *values, - const std::unordered_map &usedArrayBytes, - bool &hasSolution); + ::Z3_solver theSolver, ::Z3_lbool satisfiable, const Z3SolverEnv &env, + std::vector> *values, bool &hasSolution); SolverRunStatus getOperationStatusCode(); }; +size_t Z3SolverImpl::solvedConstraints = 0; + +Z3_solver createNativeZ3(Z3_context ctx, Z3_params solverParameters, + const ConstraintQuery &query, Z3_probe probe, + Z3_goal goal) { + Z3_solver theSolver = nullptr; + std::vector arrays = query.gatherArrays(); + bool forceTactic = true; + for (const Array *array : arrays) { + if (isa(array->source)) { + forceTactic = false; + break; + } + } + + if (forceTactic && Z3_probe_apply(ctx, probe, goal)) { + theSolver = + Z3_mk_solver_for_logic(ctx, Z3_mk_string_symbol(ctx, "QF_AUFBV")); + } else { + theSolver = Z3_mk_solver(ctx); + } + Z3_solver_inc_ref(ctx, theSolver); + Z3_solver_set_params(ctx, theSolver, solverParameters); + return theSolver; +} + +void deleteNativeZ3(Z3_context ctx, Z3_solver theSolver) { + Z3_solver_dec_ref(ctx, theSolver); +} Z3SolverImpl::Z3SolverImpl(Z3BuilderType type) : builderType(type), runStatusCode(SOLVER_RUN_STATUS_FAILURE) { @@ -183,10 +494,55 @@ Z3SolverImpl::Z3SolverImpl(Z3BuilderType type) Z3SolverImpl::~Z3SolverImpl() { Z3_params_dec_ref(builder->ctx, solverParameters); + llvm::errs() << "Total solved constraints: " << solvedConstraints << "\n"; +} + +class Z3NonIncSolverImpl : public Z3SolverImpl { +private: + using Z3SolverImpl::check; + +public: + Z3NonIncSolverImpl(Z3BuilderType type) : Z3SolverImpl(type) {} + + /// implementation of Z3SolverImpl interface + Z3_solver initNativeZ3(const ConstraintQuery &query, Z3_probe probe, + Z3_goal goal); + void deinitNativeZ3(Z3_solver theSolver); + + /// implementation of the SolverImpl interface //TODO: return after + /// refactoring parent + // bool computeTruth(const Query &query, bool &isValid); + // bool computeValidity(const Query &query, Solver::Validity &result); + // bool computeValue(const Query &query, ref &result); + // bool computeInitialValues(const Query &query, + // const std::vector &objects, + // std::vector> + // &values, bool &hasSolution); + bool check(const Query &query, ref &result); + // bool computeValidityCore(const Query &query, ValidityCore &validityCore, + // bool &isValid); + // SolverRunStatus getOperationStatusCode(); + // void setCoreSolverTimeout(time::Span timeout); +}; + +Z3_solver Z3NonIncSolverImpl::initNativeZ3(const ConstraintQuery &query, + Z3_probe probe, Z3_goal goal) { + return createNativeZ3(builder->ctx, solverParameters, query, probe, goal); +} + +void Z3NonIncSolverImpl::deinitNativeZ3(Z3_solver theSolver) { + deleteNativeZ3(builder->ctx, theSolver); +} + +bool Z3NonIncSolverImpl::check(const Query &query, + ref &result) { + Z3SolverEnv env; + bool solver_result = check(query, env, result); + return solver_result; } Z3Solver::Z3Solver(Z3BuilderType type) - : Solver(std::make_unique(type)) {} + : Solver(std::make_unique(type)) {} char *Z3Solver::getConstraintLog(const Query &query) { return impl->getConstraintLog(query); @@ -274,8 +630,9 @@ char *Z3SolverImpl::getConstraintLog(const Query &query) { } bool Z3SolverImpl::computeTruth(const Query &query, bool &isValid) { + Z3SolverEnv env; bool hasSolution = false; // to remove compiler warning - bool status = internalRunSolver(query, /*objects=*/NULL, /*values=*/NULL, + bool status = internalRunSolver(query, /*env=*/env, /*values=*/NULL, /*validityCore=*/NULL, hasSolution); isValid = !hasSolution; return status; @@ -303,30 +660,25 @@ bool Z3SolverImpl::computeValue(const Query &query, ref &result) { bool Z3SolverImpl::computeInitialValues( const Query &query, const std::vector &objects, std::vector> &values, bool &hasSolution) { - return internalRunSolver(query, &objects, &values, /*validityCore=*/NULL, + Z3SolverEnv env(objects); + return internalRunSolver(query, env, &values, /*validityCore=*/NULL, hasSolution); } -bool Z3SolverImpl::check(const Query &query, ref &result) { - ExprHashSet expressions; - assert(!query.containsSymcretes()); - expressions.insert(query.constraints.cs().begin(), - query.constraints.cs().end()); - expressions.insert(query.expr); - - std::vector objects; - findSymbolicObjects(expressions.begin(), expressions.end(), objects); +bool Z3SolverImpl::check(const ConstraintQuery &query, Z3SolverEnv &env, + ref &result) { + findSymbolicObjects(query, env.objects.v); std::vector> values; ValidityCore validityCore; bool hasSolution = false; - + solvedConstraints += query.size(); bool status = - internalRunSolver(query, &objects, &values, &validityCore, hasSolution); + internalRunSolver(query, env, &values, &validityCore, hasSolution); if (status) { result = hasSolution - ? (SolverResponse *)new InvalidResponse(objects, values) + ? (SolverResponse *)new InvalidResponse(env.objects.v, values) : (SolverResponse *)new ValidResponse(validityCore); } return status; @@ -335,20 +687,19 @@ bool Z3SolverImpl::check(const Query &query, ref &result) { bool Z3SolverImpl::computeValidityCore(const Query &query, ValidityCore &validityCore, bool &isValid) { + Z3SolverEnv env; bool hasSolution = false; // to remove compiler warning - bool status = internalRunSolver(query, /*objects=*/NULL, /*values=*/NULL, + bool status = internalRunSolver(query, /*env=*/env, /*values=*/NULL, &validityCore, hasSolution); isValid = !hasSolution; return status; } bool Z3SolverImpl::internalRunSolver( - const Query &query, const std::vector *objects, + const ConstraintQuery &query, Z3SolverEnv &env, std::vector> *values, ValidityCore *validityCore, bool &hasSolution) { - assert(!query.containsSymcretes()); - if (ProduceUnsatCore && validityCore) { enableUnsatCore(); } else { @@ -373,24 +724,18 @@ bool Z3SolverImpl::internalRunSolver( runStatusCode = SOLVER_RUN_STATUS_FAILURE; ConstantArrayFinder constant_arrays_in_query; - std::vector z3_ast_expr_constraints; - std::unordered_map, Z3ASTHandleHash, Z3ASTHandleCmp> - z3_ast_expr_to_klee_expr; - - std::unordered_map - expr_to_track; std::unordered_set exprs; unsigned id = 0; std::string freshName = "freshName"; - for (auto const &constraint : query.constraints.cs()) { + for (auto const &constraint : query.constraints.v) { Z3ASTHandle z3Constraint = builder->construct(constraint); if (ProduceUnsatCore && validityCore) { Z3ASTHandle p = builder->buildFreshBoolConst( (freshName + llvm::utostr(++id)).c_str()); - z3_ast_expr_to_klee_expr.insert({p, constraint}); - z3_ast_expr_constraints.push_back(p); - expr_to_track[z3Constraint] = p; + env.z3_ast_expr_to_klee_expr.insert({p, constraint}); + env.z3_ast_expr_constraints.v.push_back(p); + env.expr_to_track[z3Constraint] = p; } Z3_goal_assert(builder->ctx, goal, z3Constraint); @@ -399,13 +744,9 @@ bool Z3SolverImpl::internalRunSolver( constant_arrays_in_query.visit(constraint); } ++stats::solverQueries; - if (objects) + if (!env.objects.v.empty()) ++stats::queryCounterexamples; - Z3ASTHandle z3QueryExpr = - Z3ASTHandle(builder->construct(query.expr), builder->ctx); - constant_arrays_in_query.visit(query.expr); - for (auto const &constant_array : constant_arrays_in_query.results) { assert(builder->constant_array_assertions.count(constant_array) == 1 && "Constant array found in query, but not handled by Z3Builder"); @@ -416,15 +757,6 @@ bool Z3SolverImpl::internalRunSolver( } } - // KLEE Queries are validity queries i.e. - // ∀ X Constraints(X) → query(X) - // but Z3 works in terms of satisfiability so instead we ask the - // negation of the equivalent i.e. - // ∃ X Constraints(X) ∧ ¬ query(X) - Z3ASTHandle z3NotQueryExpr = - Z3ASTHandle(Z3_mk_not(builder->ctx, z3QueryExpr), builder->ctx); - Z3_goal_assert(builder->ctx, goal, z3NotQueryExpr); - // Assert an generated side constraints we have to this last so that all other // constraints have been traversed so we have all the side constraints needed. for (std::vector::iterator it = builder->sideConstraints.begin(), @@ -435,38 +767,17 @@ bool Z3SolverImpl::internalRunSolver( exprs.insert(sideConstraint); } - std::vector arrays = query.gatherArrays(); - bool forceTactic = true; - for (const Array *array : arrays) { - if (isa(array->source)) { - forceTactic = false; - break; - } - } + Z3_solver theSolver = initNativeZ3(query, probe, goal); - Z3_solver theSolver; - if (forceTactic && Z3_probe_apply(builder->ctx, probe, goal)) { - theSolver = Z3_mk_solver_for_logic( - builder->ctx, Z3_mk_string_symbol(builder->ctx, "QF_AUFBV")); - } else { - theSolver = Z3_mk_solver(builder->ctx); - } - Z3_solver_inc_ref(builder->ctx, theSolver); - Z3_solver_set_params(builder->ctx, theSolver, solverParameters); - - for (std::unordered_set::iterator it = exprs.begin(), - ie = exprs.end(); - it != ie; ++it) { + for (auto it = exprs.cbegin(), ie = exprs.cend(); it != ie; ++it) { Z3ASTHandle expr = *it; - if (expr_to_track.count(expr)) { + if (env.expr_to_track.count(expr)) { Z3_solver_assert_and_track(builder->ctx, theSolver, expr, - expr_to_track[expr]); + env.expr_to_track[expr]); } else { Z3_solver_assert(builder->ctx, theSolver, expr); } } - Z3_solver_assert(builder->ctx, theSolver, z3NotQueryExpr); if (dumpedQueriesFile) { *dumpedQueriesFile << "; start Z3 query\n"; @@ -479,22 +790,19 @@ bool Z3SolverImpl::internalRunSolver( dumpedQueriesFile->flush(); } - constraints_ty allConstraints = query.constraints.cs(); - allConstraints.insert(query.expr); - std::unordered_map usedArrayBytes; - for (auto constraint : allConstraints) { + for (auto constraint : query.constraints.v) { std::vector> reads; findReads(constraint, true, reads); for (auto readExpr : reads) { const Array *readFromArray = readExpr->updates.root; assert(readFromArray); - usedArrayBytes[readFromArray].insert(readExpr->index); + env.usedArrayBytes[readFromArray].insert(readExpr->index); } } ::Z3_lbool satisfiable = Z3_solver_check(builder->ctx, theSolver); - runStatusCode = handleSolverResponse(theSolver, satisfiable, objects, values, - usedArrayBytes, hasSolution); + runStatusCode = + handleSolverResponse(theSolver, satisfiable, env, values, hasSolution); if (ProduceUnsatCore && validityCore && satisfiable == Z3_L_FALSE) { constraints_ty unsatCore; Z3_ast_vector z3_unsat_core = @@ -511,15 +819,15 @@ bool Z3SolverImpl::internalRunSolver( z3_ast_expr_unsat_core.insert(constraint); } - for (auto &z3_constraint : z3_ast_expr_constraints) { + for (const auto &z3_constraint : env.z3_ast_expr_constraints.v) { if (z3_ast_expr_unsat_core.find(z3_constraint) != z3_ast_expr_unsat_core.end()) { - ref constraint = z3_ast_expr_to_klee_expr[z3_constraint]; + ref constraint = env.z3_ast_expr_to_klee_expr[z3_constraint]; unsatCore.insert(constraint); } } assert(validityCore && "validityCore cannot be nullptr"); - *validityCore = ValidityCore(unsatCore, query.expr); + *validityCore = ValidityCore(unsatCore, query.getOriginalQueryExpr()); Z3_ast_vector assertions = Z3_solver_get_assertions(builder->ctx, theSolver); @@ -536,7 +844,7 @@ bool Z3SolverImpl::internalRunSolver( Z3_goal_dec_ref(builder->ctx, goal); Z3_probe_dec_ref(builder->ctx, probe); - Z3_solver_dec_ref(builder->ctx, theSolver); + deinitNativeZ3(theSolver); // Clear the builder's cache to prevent memory usage exploding. // By using ``autoClearConstructCache=false`` and clearning now @@ -561,28 +869,22 @@ bool Z3SolverImpl::internalRunSolver( } SolverImpl::SolverRunStatus Z3SolverImpl::handleSolverResponse( - ::Z3_solver theSolver, ::Z3_lbool satisfiable, - const std::vector *objects, - std::vector> *values, - const std::unordered_map &usedArrayBytes, - bool &hasSolution) { + ::Z3_solver theSolver, ::Z3_lbool satisfiable, const Z3SolverEnv &env, + std::vector> *values, bool &hasSolution) { switch (satisfiable) { case Z3_L_TRUE: { hasSolution = true; - if (!objects) { + if (env.objects.v.empty()) { // No assignment is needed - assert(values == NULL); + assert(!values); return SolverImpl::SOLVER_RUN_STATUS_SUCCESS_SOLVABLE; } assert(values && "values cannot be nullptr"); ::Z3_model theModel = Z3_solver_get_model(builder->ctx, theSolver); assert(theModel && "Failed to retrieve model"); Z3_model_inc_ref(builder->ctx, theModel); - values->reserve(objects->size()); - for (std::vector::const_iterator it = objects->begin(), - ie = objects->end(); - it != ie; ++it) { - const Array *array = *it; + values->reserve(env.objects.v.size()); + for (auto array : env.objects.v) { SparseStorage data; ::Z3_ast arraySizeExpr; @@ -599,9 +901,9 @@ SolverImpl::SolverRunStatus Z3SolverImpl::handleSolverResponse( assert(success && "Failed to get size"); data.resize(arraySize); - if (usedArrayBytes.count(array)) { + if (env.usedArrayBytes.count(array)) { std::unordered_set offsetValues; - for (ref offsetExpr : usedArrayBytes.at(array)) { + for (const ref &offsetExpr : env.usedArrayBytes.at(array)) { ::Z3_ast arrayElementOffsetExpr; Z3_model_eval(builder->ctx, theModel, builder->construct(offsetExpr), Z3_TRUE, &arrayElementOffsetExpr); @@ -737,5 +1039,278 @@ bool Z3SolverImpl::validateZ3Model(::Z3_solver &theSolver, SolverImpl::SolverRunStatus Z3SolverImpl::getOperationStatusCode() { return runStatusCode; } + +struct ConstraintDistance { + size_t toPopSize = 0; + ConstraintQuery toPush; + + explicit ConstraintDistance() {} + ConstraintDistance(const ConstraintQuery &q) : toPush(q) {} + explicit ConstraintDistance(size_t toPopSize, const ConstraintQuery &q) + : toPopSize(toPopSize), toPush(q) {} + + size_t getDistance() const { return toPopSize + toPush.size(); } + + bool isOnlyPush() const { return toPopSize == 0; } + + void dump() const { + llvm::errs() << "ConstraintDistance: pop: " << toPopSize << "; push:\n"; + klee::dump(toPush.constraints); + } +}; + +class Z3IncNativeSolver { +private: + Z3_solver nativeSolver = nullptr; + Z3_context ctx; + Z3_params solverParameters; + /// underlying solver frames + /// saved only for calculating distances from next queries + ConstraintFrames frames; + + void pop(size_t popSize); + void push(); + +public: + Z3SolverEnv env; + + Z3IncNativeSolver(Z3_context ctx, Z3_params solverParameters) + : ctx(ctx), solverParameters(solverParameters) {} + ~Z3IncNativeSolver(); + + void distance(const ConstraintQuery &query, ConstraintDistance &delta) const; + + void popPush(ConstraintDistance &delta); + + Z3_solver getOrInit(const ConstraintQuery &query, Z3_probe probe, + Z3_goal goal); + + void dump() const { ::klee::dump(frames); } +}; + +void Z3IncNativeSolver::pop(size_t popSize) { + if (nativeSolver == nullptr) + return; + Z3_solver_pop(ctx, nativeSolver, popSize); +} + +void Z3IncNativeSolver::push() { + if (nativeSolver == nullptr) + return; + Z3_solver_push(ctx, nativeSolver); +} + +void Z3IncNativeSolver::popPush(ConstraintDistance &delta) { + env.push(); + env.pop(delta.toPopSize); + pop(delta.toPopSize); + push(); + frames.pop(delta.toPopSize); + frames.extend(delta.toPush.constraints.v); + assert(env.objects.end() - env.objects.begin() == + frames.end() - frames.begin()); +} + +Z3_solver Z3IncNativeSolver::getOrInit(const ConstraintQuery &query, + Z3_probe probe, Z3_goal goal) { + if (nativeSolver == nullptr) { + nativeSolver = createNativeZ3(ctx, solverParameters, query, probe, goal); + push(); + } + return nativeSolver; +} + +Z3IncNativeSolver::~Z3IncNativeSolver() { + if (nativeSolver != nullptr) + deleteNativeZ3(ctx, nativeSolver); +} + +void Z3IncNativeSolver::distance(const ConstraintQuery &query, + ConstraintDistance &delta) const { + auto sit = frames.v.begin(); + auto site = frames.v.end(); + auto qit = query.constraints.v.begin(); + auto qite = query.constraints.v.end(); + auto it = frames.begin(); + auto ite = frames.end(); + size_t intersect = 0; + for (; it != ite && sit != site && qit != qite && *sit == *qit; it++) { + size_t frame_size = *it; + for (size_t i = 0; + i < frame_size && sit != site && qit != qite && *sit == *qit; + i++, sit++, qit++, intersect++) { + } + } + for (; sit != site && qit != qite && *sit == *qit; + sit++, qit++, intersect++) { + } + size_t toPop, extraTakeFromOther; + ConstraintFrames d; + if (sit == site) { // solver frames ended + toPop = 0; + extraTakeFromOther = 0; + } else { + frames.takeBefore(intersect, toPop, extraTakeFromOther); + } + query.constraints.takeAfter(intersect - extraTakeFromOther, d); + delta = ConstraintDistance(toPop, d); +} + +class Z3TreeSolverImpl : public Z3SolverImpl { +private: + unsigned maxSolvers; + Z3IncNativeSolver *currentSolver = nullptr; + + std::deque recentlyUsed; + std::vector recycledSolvers; + + void reuseOrCreateZ3(const ConstraintQuery &query, ConstraintDistance &delta); + void findSuitableSolver(const ConstraintQuery &query, + ConstraintDistance &delta); + + using Z3SolverImpl::check; + bool check(ConstraintDistance &delta, ref &result); + +public: + Z3TreeSolverImpl(Z3BuilderType type, unsigned maxSolvers) + : Z3SolverImpl(type), maxSolvers(maxSolvers){}; + ~Z3TreeSolverImpl(); + + /// implementation of Z3SolverImpl interface + Z3_solver initNativeZ3(const ConstraintQuery &query, Z3_probe probe, + Z3_goal goal); + void deinitNativeZ3(Z3_solver theSolver); + + void notifyStateTermination(std::uint32_t id); + + /// implementation of the SolverImpl interface + bool computeTruth(const Query &query, bool &isValid); + bool computeValidity(const Query &query, PartialValidity &result); + bool computeInitialValues(const Query &query, + const std::vector &objects, + std::vector> &values, + bool &hasSolution); + bool check(const Query &query, ref &result); + bool computeValidityCore(const Query &query, ValidityCore &validityCore, + bool &isValid); + SolverRunStatus getOperationStatusCode(); +}; + +Z3TreeSolverImpl::~Z3TreeSolverImpl() { + for (auto solver : recentlyUsed) + delete solver; + for (auto solver : recycledSolvers) + delete solver; +} + +Z3_solver Z3TreeSolverImpl::initNativeZ3(const ConstraintQuery &query, + Z3_probe probe, Z3_goal goal) { + return currentSolver->getOrInit(query, probe, goal); +} + +void Z3TreeSolverImpl::deinitNativeZ3(Z3_solver theSolver) { + recentlyUsed.push_front(currentSolver); + currentSolver = nullptr; +} + +bool Z3TreeSolverImpl::check(ConstraintDistance &delta, + ref &result) { + currentSolver->popPush(delta); + return check(delta.toPush, currentSolver->env, result); +} + +void Z3TreeSolverImpl::reuseOrCreateZ3(const ConstraintQuery &query, + ConstraintDistance &delta) { + auto min_delta = ConstraintDistance(query); + auto min_distance = min_delta.getDistance(); + Z3IncNativeSolver *min_solver = nullptr; + for (auto solver : recycledSolvers) { + ConstraintDistance d; + solver->distance(query, d); + auto distance = d.getDistance(); + if (distance < min_distance) { + min_delta = d; + min_distance = distance; + min_solver = solver; + } + } + currentSolver = min_solver + ? min_solver + : new Z3IncNativeSolver(builder->ctx, solverParameters); + delta = min_delta; +} + +void Z3TreeSolverImpl::findSuitableSolver( + const ConstraintQuery &query, + ConstraintDistance &delta) { + for (auto it = recentlyUsed.begin(), + ite = recentlyUsed.end(); + it != ite; it++) { + currentSolver = *it; + ConstraintDistance d; + currentSolver->distance(query, d); + if (d.isOnlyPush()) { + recentlyUsed.erase(it); + delta = d; + return; + } + } + if (recentlyUsed.size() < maxSolvers) { + reuseOrCreateZ3(query, delta); + return; + } + currentSolver = recentlyUsed.back(); + recentlyUsed.pop_back(); + currentSolver->distance(query, delta); +} + +bool Z3TreeSolverImpl::computeTruth(const Query &query, bool &isValid) { + assert(false); + return false; // TODO: not implemented +} + +bool Z3TreeSolverImpl::computeValidity(const Query &query, + PartialValidity &result) { + assert(false); + return false; // TODO: not implemented +} + +bool Z3TreeSolverImpl::computeInitialValues( + const Query &query, const std::vector &objects, + std::vector> &values, bool &hasSolution) { + llvm::errs() << "Z3TreeSolverImpl::computeInitialValues:\n"; + query.dump(); + assert(false); + return false; // TODO: not implemented +} + +bool Z3TreeSolverImpl::check(const Query &q, ref &result) { + ConstraintDistance delta; + ConstraintQuery query(q); + findSuitableSolver(query, delta); + auto ok = check(delta, result); + return ok; +} + +bool Z3TreeSolverImpl::computeValidityCore(const Query &query, + ValidityCore &validityCore, + bool &isValid) { + assert(false); + return false; // TODO: not implemented +} + +Z3TreeSolverImpl::SolverRunStatus Z3TreeSolverImpl::getOperationStatusCode() { + assert(false); + return SOLVER_RUN_STATUS_TIMEOUT; // TODO: not implemented +} + +void Z3TreeSolverImpl::notifyStateTermination(std::uint32_t id) { + assert(false); + return; // TODO: not implemented +} + +Z3TreeSolver::Z3TreeSolver(Z3BuilderType type, unsigned maxSolvers) + : Solver(std::make_unique(type, maxSolvers)) {} + } // namespace klee #endif // ENABLE_Z3 diff --git a/lib/Solver/Z3Solver.h b/lib/Solver/Z3Solver.h index 0189dec08f1..0e6649a6d10 100644 --- a/lib/Solver/Z3Solver.h +++ b/lib/Solver/Z3Solver.h @@ -33,6 +33,12 @@ class Z3Solver : public Solver { /// is off. virtual void setCoreSolverTimeout(time::Span timeout); }; + +class Z3TreeSolver : public Solver { +public: + Z3TreeSolver(Z3BuilderType type, unsigned maxSolvers); +}; + } // namespace klee #endif /* KLEE_Z3SOLVER_H */