Skip to content

Commit

Permalink
[feat] Z3 Tree incremental solver
Browse files Browse the repository at this point in the history
  • Loading branch information
Columpio committed Aug 8, 2023
1 parent 479d04f commit a0c8bc0
Show file tree
Hide file tree
Showing 18 changed files with 804 additions and 149 deletions.
35 changes: 35 additions & 0 deletions docs/tree_incrementality.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
class TreeIncrementalSolver:
def __init__(self, Solver, max_solvers: int):
self.Solver = Solver
self.max_solvers = max_solvers
self.recently_used = Queue()
self.recycled_solvers = []

def reuseOrCreateZ3(self, query):
solver_from_scratch_cost = len(query)
min_cost, min_solver = solver_from_scratch_cost, None
for solver in self.recycled_solvers:
delta = distance(solver.query, query)
if delta < min_cost:
min_cost, min_solver = delta, solver
if min_solver is None:
return self.Solver()
return min_solver

def find_suitable_solver(self, query):
for solver in self.recently_used:
if solver.query is subsetOf(query):
self.recently_used.remove(solver)
return solver
if len(self.recently_used) < self.max_solvers:
return self.reuseOrCreateZ3(query)
return self.recently_used.pop_back()

def check_sat(self, query):
solver = self.find_suitable_solver(query)
push(solver, query)
self.recently_used.push_front(solver)
solver.check_sat()

def recycle_solver(self, id):
raise NotImplementedError()
9 changes: 9 additions & 0 deletions include/klee/Solver/Solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,10 @@ class Solver {

virtual char *getConstraintLog(const Query &query);
virtual void setCoreSolverTimeout(time::Span timeout);

/// @brief Notify the solver that the state with specified id has been
/// terminated
void notifyStateTermination(std::uint32_t id);
};

/* *** */
Expand Down Expand Up @@ -264,6 +268,11 @@ std::unique_ptr<Solver> createCoreSolver(CoreSolverType cst);
std::unique_ptr<Solver>
createConcretizingSolver(std::unique_ptr<Solver> s,
AddressGenerator *addressGenerator);

/// Return a list of all unique symbolic objects referenced by the
/// given Query.
void findSymbolicObjects(const Query &query,
std::vector<const Array *> &results);
} // namespace klee

#endif /* KLEE_SOLVER_H */
3 changes: 3 additions & 0 deletions include/klee/Solver/SolverCmdLine.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ extern llvm::cl::opt<bool> CoreSolverOptimizeDivides;

extern llvm::cl::opt<bool> UseAssignmentValidatingSolver;

extern llvm::cl::opt<unsigned> MaxSolversApproxTreeInc;

/// The different query logging solvers that can be switched on/off
enum QueryLoggingSolverType {
ALL_KQUERY, ///< Log all queries in .kquery (KQuery) format
Expand All @@ -65,6 +67,7 @@ enum CoreSolverType {
METASMT_SOLVER,
DUMMY_SOLVER,
Z3_SOLVER,
Z3_TREE_SOLVER,
NO_SOLVER
};

Expand Down
4 changes: 3 additions & 1 deletion include/klee/Solver/SolverImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class ExecutionState;
class Expr;
struct Query;

/// SolverImpl - Abstract base clase for solver implementations.
/// SolverImpl - Abstract base class for solver implementations.
class SolverImpl {
public:
SolverImpl() = default;
Expand Down Expand Up @@ -119,6 +119,8 @@ class SolverImpl {
}

virtual void setCoreSolverTimeout(time::Span timeout){};

virtual void notifyStateTermination(std::uint32_t id){};
};

} // namespace klee
Expand Down
3 changes: 3 additions & 0 deletions include/klee/Solver/SolverUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ enum class Validity { True = 1, False = -1, Unknown = 0 };
struct SolverQueryMetaData {
/// @brief Costs for all queries issued for this state
time::Span queryCost;

/// @brief Caller state id
std::uint32_t id = 0;
};

struct Query {
Expand Down
4 changes: 3 additions & 1 deletion lib/Core/ExecutionState.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,9 @@ ExecutionState::ExecutionState(const ExecutionState &state)
returnValue(state.returnValue), gepExprBases(state.gepExprBases),
prevTargets_(state.prevTargets_), targets_(state.targets_),
prevHistory_(state.prevHistory_), history_(state.history_),
isTargeted_(state.isTargeted_) {}
isTargeted_(state.isTargeted_) {
queryMetaData.id = state.id;
}

ExecutionState *ExecutionState::branch() {
depth++;
Expand Down
5 changes: 4 additions & 1 deletion lib/Core/ExecutionState.h
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,10 @@ class ExecutionState {
bool visited(KBlock *block) const;

std::uint32_t getID() const { return id; };
void setID() { id = nextID++; };
void setID() {
id = nextID++;
queryMetaData.id = id;
};
llvm::BasicBlock *getInitPCBlock() const;
llvm::BasicBlock *getPrevPCBlock() const;
llvm::BasicBlock *getPCBlock() const;
Expand Down
4 changes: 4 additions & 0 deletions lib/Core/TimingSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -351,3 +351,7 @@ TimingSolver::getRange(const ConstraintSet &constraints, ref<Expr> expr,
metaData.queryCost += timer.delta();
return result;
}

void TimingSolver::notifyStateTermination(std::uint32_t id) {
solver->notifyStateTermination(id);
}
4 changes: 4 additions & 0 deletions lib/Core/TimingSolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ class TimingSolver {
return solver->getConstraintLog(query);
}

/// @brief Notify the solver that the state with specified id has been
/// terminated
void notifyStateTermination(std::uint32_t id);

bool evaluate(const ConstraintSet &, ref<Expr>, PartialValidity &result,
SolverQueryMetaData &metaData,
bool produceValidityCore = false);
Expand Down
8 changes: 1 addition & 7 deletions lib/Solver/AssignmentValidatingSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,8 @@ bool AssignmentValidatingSolver::check(const Query &query,
return true;
}

ExprHashSet expressions;
assert(!query.containsSymcretes());
expressions.insert(query.constraints.cs().begin(),
query.constraints.cs().end());
expressions.insert(query.expr);

std::vector<const Array *> objects;
findSymbolicObjects(expressions.begin(), expressions.end(), objects);
findSymbolicObjects(query, objects);
std::vector<SparseStorage<unsigned char>> values;

assert(isa<InvalidResponse>(result));
Expand Down
23 changes: 9 additions & 14 deletions lib/Solver/CexCachingSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,17 +286,9 @@ bool CexCachingSolver::computeValidity(const Query &query,
PartialValidity &result) {
TimerStatIncrementer t(stats::cexCacheTime);
ref<SolverResponse> a;
if (!getResponse(query.withFalse(), a))
ref<Expr> q;
if (!computeValue(query, q))
return false;
assert(isa<InvalidResponse>(a) && "computeValidity() must have assignment");

ref<Expr> q = cast<InvalidResponse>(a)->evaluate(query.expr);

if (!isa<ConstantExpr>(q) && solver->impl->computeValue(query, q))
return false;

assert(isa<ConstantExpr>(q) &&
"assignment evaluation did not result in constant");

if (cast<ConstantExpr>(q)->isTrue()) {
if (!getResponse(query, a))
Expand Down Expand Up @@ -343,10 +335,13 @@ bool CexCachingSolver::computeValue(const Query &query, ref<Expr> &result) {
TimerStatIncrementer t(stats::cexCacheTime);

ref<SolverResponse> a;
if (!getResponse(query.withFalse(), a))
return false;
assert(isa<InvalidResponse>(a) && "computeValue() must have assignment");
result = cast<InvalidResponse>(a)->evaluate(query.expr);
result = ConstantExpr::create(1, query.expr->getWidth());
if (!query.constraints.cs().empty()) {
if (!getResponse(query.withFalse(), a))
return false;
assert(isa<InvalidResponse>(a) && "computeValue() must have assignment");
result = cast<InvalidResponse>(a)->evaluate(query.expr);
}

if (!isa<ConstantExpr>(result) && solver->impl->computeValue(query, result))
return false;
Expand Down
13 changes: 11 additions & 2 deletions lib/Solver/CoreSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,14 @@ DISABLE_WARNING_POP

namespace klee {

std::unique_ptr<Solver> createZ3Solver(bool isTreeSolver, Z3BuilderType type) {
if (isTreeSolver)
return std::make_unique<Z3TreeSolver>(type, MaxSolversApproxTreeInc);
return std::make_unique<Z3Solver>(type);
}

std::unique_ptr<Solver> createCoreSolver(CoreSolverType cst) {
bool isTreeSolver = false;
switch (cst) {
case STP_SOLVER:
#ifdef ENABLE_STP
Expand All @@ -54,15 +61,17 @@ std::unique_ptr<Solver> createCoreSolver(CoreSolverType cst) {
#endif
case DUMMY_SOLVER:
return createDummySolver();
case Z3_TREE_SOLVER:
isTreeSolver = true;
case Z3_SOLVER:
#ifdef ENABLE_Z3
klee_message("Using Z3 solver backend");
#ifdef ENABLE_FP
klee_message("Using Z3 bitvector builder");
return std::make_unique<Z3Solver>(KLEE_BITVECTOR);
return createZ3Solver(isTreeSolver, KLEE_BITVECTOR);
#else
klee_message("Using Z3 core builder");
return std::make_unique<Z3Solver>(KLEE_CORE);
return createZ3Solver(isTreeSolver, KLEE_CORE);
#endif
#else
klee_message("Not compiled with Z3 support");
Expand Down
7 changes: 1 addition & 6 deletions lib/Solver/IncompleteSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,8 @@ bool StagedSolverImpl::computeInitialValues(
}

bool StagedSolverImpl::check(const Query &query, ref<SolverResponse> &result) {
ExprHashSet expressions;
expressions.insert(query.constraints.cs().begin(),
query.constraints.cs().end());
expressions.insert(query.expr);

std::vector<const Array *> objects;
findSymbolicObjects(expressions.begin(), expressions.end(), objects);
findSymbolicObjects(query, objects);
std::vector<SparseStorage<unsigned char>> values;

bool hasSolution;
Expand Down
13 changes: 13 additions & 0 deletions lib/Solver/Solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,10 @@ bool Solver::check(const Query &query, ref<SolverResponse> &queryResult) {
return impl->check(query, queryResult);
}

void Solver::notifyStateTermination(std::uint32_t id) {
impl->notifyStateTermination(id);
}

static std::pair<ref<ConstantExpr>, ref<ConstantExpr>> getDefaultRange() {
return std::make_pair(ConstantExpr::create(0, 64),
ConstantExpr::create(0, 64));
Expand Down Expand Up @@ -327,6 +331,15 @@ bool Query::containsSizeSymcretes() const {
return false;
}

void klee::findSymbolicObjects(const Query &query,
std::vector<const Array *> &results) {
ExprHashSet expressions;
expressions.insert(query.constraints.cs().begin(),
query.constraints.cs().end());
expressions.insert(query.expr);
findSymbolicObjects(expressions.begin(), expressions.end(), results);
}

void Query::dump() const {
constraints.dump();
llvm::errs() << "Query [\n";
Expand Down
18 changes: 13 additions & 5 deletions lib/Solver/SolverCmdLine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,13 @@ cl::opt<bool> UseAssignmentValidatingSolver(
cl::desc("Debug the correctness of generated assignments (default=false)"),
cl::cat(SolvingCat));

cl::opt<unsigned>
MaxSolversApproxTreeInc("max-solvers-approx-tree-inc",
cl::desc("Maximum size of the Z3 solver pool for "
"approximating tree incrementality."
" Set to 0 to disable (default=0)"),
cl::init(0), cl::cat(SolvingCat));

void KCommandLine::HideOptions(llvm::cl::OptionCategory &Category) {
StringMap<cl::Option *> &map = cl::getRegisteredOptions();

Expand Down Expand Up @@ -196,11 +203,12 @@ cl::opt<klee::MetaSMTBackendType> MetaSMTBackend(

cl::opt<CoreSolverType> CoreSolverToUse(
"solver-backend", cl::desc("Specifiy the core solver backend to use"),
cl::values(clEnumValN(STP_SOLVER, "stp", "STP" STP_IS_DEFAULT_STR),
clEnumValN(METASMT_SOLVER, "metasmt",
"metaSMT" METASMT_IS_DEFAULT_STR),
clEnumValN(DUMMY_SOLVER, "dummy", "Dummy solver"),
clEnumValN(Z3_SOLVER, "z3", "Z3" Z3_IS_DEFAULT_STR)),
cl::values(
clEnumValN(STP_SOLVER, "stp", "STP" STP_IS_DEFAULT_STR),
clEnumValN(METASMT_SOLVER, "metasmt", "metaSMT" METASMT_IS_DEFAULT_STR),
clEnumValN(DUMMY_SOLVER, "dummy", "Dummy solver"),
clEnumValN(Z3_SOLVER, "z3", "Z3" Z3_IS_DEFAULT_STR),
clEnumValN(Z3_TREE_SOLVER, "z3-tree", "Z3 tree-incremental solver")),
cl::init(DEFAULT_CORE_SOLVER), cl::cat(SolvingCat));

cl::opt<CoreSolverType> DebugCrossCheckCoreSolverWith(
Expand Down
7 changes: 1 addition & 6 deletions lib/Solver/SolverImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,8 @@ bool SolverImpl::computeValidity(const Query &query,
}

bool SolverImpl::check(const Query &query, ref<SolverResponse> &result) {
ExprHashSet expressions;
expressions.insert(query.constraints.cs().begin(),
query.constraints.cs().end());
expressions.insert(query.expr);

std::vector<const Array *> objects;
findSymbolicObjects(expressions.begin(), expressions.end(), objects);
findSymbolicObjects(query, objects);
std::vector<SparseStorage<unsigned char>> values;

bool hasSolution;
Expand Down
Loading

0 comments on commit a0c8bc0

Please sign in to comment.