From 3ac7f672bac4eb252daa40ab9cd29f1fca08de5c Mon Sep 17 00:00:00 2001 From: Aleksandr Misonizhnik Date: Mon, 23 Oct 2023 15:01:01 +0400 Subject: [PATCH 1/5] [feat] Use `FastCexSolver` in restricted cases that are are fairly easy to solve --- include/klee/ADT/DisjointSetUnion.h | 21 ++++ include/klee/Solver/IncompleteSolver.h | 6 +- lib/ADT/SparseStorage.cpp | 2 +- lib/Expr/IndependentConstraintSetUnion.cpp | 18 +--- lib/Solver/FastCexSolver.cpp | 84 ++++++++++++--- lib/Solver/IncompleteSolver.cpp | 117 ++++++++++++--------- lib/Solver/SolverCmdLine.cpp | 4 +- test/Feature/DanglingConcreteReadExpr.c | 5 +- test/Solver/FastCexSolver.kquery | 2 +- 9 files changed, 172 insertions(+), 87 deletions(-) diff --git a/include/klee/ADT/DisjointSetUnion.h b/include/klee/ADT/DisjointSetUnion.h index b479ca0765..746fdc38b1 100644 --- a/include/klee/ADT/DisjointSetUnion.h +++ b/include/klee/ADT/DisjointSetUnion.h @@ -150,6 +150,27 @@ class DisjointSetUnion { } } + void getAllDependentSets(ValueType value, + std::vector> &result) const { + ref compare = new SetType(value); + for (auto &r : roots) { + ref ics = disjointSets.at(r); + if (SetType::intersects(ics, compare)) { + result.push_back(ics); + } + } + } + void getAllIndependentSets(ValueType value, + std::vector> &result) const { + ref compare = new SetType(value); + for (auto &r : roots) { + ref ics = disjointSets.at(r); + if (!SetType::intersects(ics, compare)) { + result.push_back(ics); + } + } + } + DisjointSetUnion() {} DisjointSetUnion(const internalStorage_ty &is) { diff --git a/include/klee/Solver/IncompleteSolver.h b/include/klee/Solver/IncompleteSolver.h index 65dac30c4b..777e6722c4 100644 --- a/include/klee/Solver/IncompleteSolver.h +++ b/include/klee/Solver/IncompleteSolver.h @@ -58,14 +58,18 @@ class IncompleteSolver { /// StagedSolver - Adapter class for staging an incomplete solver with /// a complete secondary solver, to form an (optimized) complete /// solver. + +typedef std::function QueryPredicate; + class StagedSolverImpl : public SolverImpl { private: std::unique_ptr primary; std::unique_ptr secondary; + QueryPredicate predicate; public: StagedSolverImpl(std::unique_ptr primary, - std::unique_ptr secondary); + std::unique_ptr secondary, QueryPredicate predicate); bool computeTruth(const Query &, bool &isValid); bool computeValidity(const Query &, PartialValidity &result); diff --git a/lib/ADT/SparseStorage.cpp b/lib/ADT/SparseStorage.cpp index dc24e7d993..935cb6922a 100644 --- a/lib/ADT/SparseStorage.cpp +++ b/lib/ADT/SparseStorage.cpp @@ -35,7 +35,7 @@ void SparseStorage::print(llvm::raw_ostream &os, } os << "] default: "; } - os << defaultValue; + os << ((unsigned)defaultValue); } template <> diff --git a/lib/Expr/IndependentConstraintSetUnion.cpp b/lib/Expr/IndependentConstraintSetUnion.cpp index 01358e2623..f4ecf876de 100644 --- a/lib/Expr/IndependentConstraintSetUnion.cpp +++ b/lib/Expr/IndependentConstraintSetUnion.cpp @@ -95,27 +95,13 @@ void IndependentConstraintSetUnion::reEvaluateConcretization( void IndependentConstraintSetUnion::getAllIndependentConstraintSets( ref e, std::vector> &result) const { - ref compare = - new IndependentConstraintSet(new ExprEitherSymcrete::left(e)); - for (auto &r : roots) { - ref ics = disjointSets.at(r); - if (!IndependentConstraintSet::intersects(ics, compare)) { - result.push_back(ics); - } - } + getAllIndependentSets(new ExprEitherSymcrete::left(e), result); } void IndependentConstraintSetUnion::getAllDependentConstraintSets( ref e, std::vector> &result) const { - ref compare = - new IndependentConstraintSet(new ExprEitherSymcrete::left(e)); - for (auto &r : roots) { - ref ics = disjointSets.at(r); - if (IndependentConstraintSet::intersects(ics, compare)) { - result.push_back(ics); - } - } + getAllDependentSets(new ExprEitherSymcrete::left(e), result); } void IndependentConstraintSetUnion::addExpr(ref e) { diff --git a/lib/Solver/FastCexSolver.cpp b/lib/Solver/FastCexSolver.cpp index f4c1b377aa..a878bc34bf 100644 --- a/lib/Solver/FastCexSolver.cpp +++ b/lib/Solver/FastCexSolver.cpp @@ -18,6 +18,7 @@ #include "klee/Solver/IncompleteSolver.h" #include "klee/Support/Debug.h" #include "klee/Support/ErrorHandling.h" +#include "klee/Support/OptionCategories.h" #include "klee/Support/CompilerWarning.h" DISABLE_WARNING_PUSH @@ -33,6 +34,20 @@ DISABLE_WARNING_POP #include using namespace klee; +using namespace llvm; + +namespace { +enum class FastCexSolverType { EQUALITY, ALL }; + +cl::opt FastCexFor( + "fast-cex-for", + cl::desc( + "Specifiy a query predicate to filter queries for FastCexSolver using"), + cl::values(clEnumValN(FastCexSolverType::EQUALITY, "equality", + "Query with only equality expressions"), + clEnumValN(FastCexSolverType::ALL, "all", "All queries")), + cl::init(FastCexSolverType::EQUALITY), cl::cat(SolvingCat)); +} // namespace // Hacker's Delight, pgs 58-63 static uint64_t minOR(uint64_t a, uint64_t b, uint64_t c, uint64_t d) { @@ -403,10 +418,11 @@ class CexPossibleEvaluator : public ExprEvaluator { ref getInitialValue(const Array &array, unsigned index) { // If the index is out of range, we cannot assign it a value, since that // value cannot be part of the assignment. - ref constantArraySize = dyn_cast(array.size); + ref constantArraySize = + dyn_cast(visit(array.size)); if (!constantArraySize) { - klee_error( - "FIXME: Arrays of symbolic sizes are unsupported in FastCex\n"); + klee_error("FIXME: CexPossibleEvaluator: Arrays of symbolic sizes are " + "unsupported in FastCex\n"); std::abort(); } @@ -433,11 +449,11 @@ class CexExactEvaluator : public ExprEvaluator { ref getInitialValue(const Array &array, unsigned index) { // If the index is out of range, we cannot assign it a value, since that // value cannot be part of the assignment. - ref constantArraySize = dyn_cast(array.size); + ref constantArraySize = + dyn_cast(visit(array.size)); if (!constantArraySize) { - klee_error( - "FIXME: Arrays of symbolic sizes are unsupported in FastCex\n"); - std::abort(); + return ReadExpr::create(UpdateList(&array, 0), + ConstantExpr::alloc(index, array.getDomain())); } if (index >= constantArraySize->getZExtValue()) { @@ -485,10 +501,11 @@ class CexData { CexObjectData &getObjectData(const Array *A) { CexObjectData *&Entry = objects[A]; - ref constantArraySize = dyn_cast(A->size); + ref constantArraySize = + dyn_cast(evaluatePossible(A->size)); if (!constantArraySize) { - klee_error( - "FIXME: Arrays of symbolic sizes are unsupported in FastCex\n"); + klee_error("FIXME: CexData: Arrays of symbolic sizes are unsupported in " + "FastCex\n"); std::abort(); } @@ -529,7 +546,7 @@ class CexData { // to see if this is an initial read or not. if (ConstantExpr *CE = dyn_cast(re->index)) { if (ref constantArraySize = - dyn_cast(array->size)) { + dyn_cast(evaluatePossible(array->size))) { uint64_t index = CE->getZExtValue(); if (index < constantArraySize->getZExtValue()) { @@ -1171,6 +1188,7 @@ bool FastCexSolver::computeInitialValues( const Query &query, const std::vector &objects, std::vector> &values, bool &hasSolution) { CexData cd; + query.dump(); bool isValid; bool success = propagateValues(query, cd, true, isValid); @@ -1187,7 +1205,7 @@ bool FastCexSolver::computeInitialValues( for (unsigned i = 0; i != objects.size(); ++i) { const Array *array = objects[i]; assert(array); - SparseStorage data; + SparseStorage data(0); ref arrayConstantSize = dyn_cast(cd.evaluatePossible(array->size)); assert(arrayConstantSize && @@ -1212,7 +1230,45 @@ bool FastCexSolver::computeInitialValues( return true; } +class OnlyEqualityWithConstantQueryPredicate { +public: + explicit OnlyEqualityWithConstantQueryPredicate() {} + + bool operator()(const Query &query) const { + for (auto constraint : query.constraints.cs()) { + if (const EqExpr *ee = dyn_cast(constraint)) { + if (!isa(ee->left)) { + return false; + } + } else { + return false; + } + } + if (ref ee = dyn_cast(query.negateExpr().expr)) { + if (!isa(ee->left)) { + return false; + } + } else { + return false; + } + return true; + } +}; + +class TrueQueryPredicate { +public: + explicit TrueQueryPredicate() {} + + bool operator()(const Query &query) const { return true; } +}; + std::unique_ptr klee::createFastCexSolver(std::unique_ptr s) { - return std::make_unique(std::make_unique( - std::make_unique(), std::move(s))); + if (FastCexFor == FastCexSolverType::EQUALITY) { + return std::make_unique(std::make_unique( + std::make_unique(), std::move(s), + OnlyEqualityWithConstantQueryPredicate())); + } else { + return std::make_unique(std::make_unique( + std::make_unique(), std::move(s), TrueQueryPredicate())); + } } diff --git a/lib/Solver/IncompleteSolver.cpp b/lib/Solver/IncompleteSolver.cpp index 85ad5a8d6d..10436c4f63 100644 --- a/lib/Solver/IncompleteSolver.cpp +++ b/lib/Solver/IncompleteSolver.cpp @@ -49,15 +49,19 @@ PartialValidity IncompleteSolver::computeValidity(const Query &query) { /***/ StagedSolverImpl::StagedSolverImpl(std::unique_ptr primary, - std::unique_ptr secondary) - : primary(std::move(primary)), secondary(std::move(secondary)) {} + std::unique_ptr secondary, + QueryPredicate predicate) + : primary(std::move(primary)), secondary(std::move(secondary)), + predicate(predicate) {} bool StagedSolverImpl::computeTruth(const Query &query, bool &isValid) { - PartialValidity trueResult = primary->computeTruth(query); + if (predicate(query)) { + PartialValidity trueResult = primary->computeTruth(query); - if (trueResult != PValidity::None) { - isValid = (trueResult == PValidity::MustBeTrue); - return true; + if (trueResult != PValidity::None) { + isValid = (trueResult == PValidity::MustBeTrue); + return true; + } } return secondary->impl->computeTruth(query, isValid); @@ -65,44 +69,48 @@ bool StagedSolverImpl::computeTruth(const Query &query, bool &isValid) { bool StagedSolverImpl::computeValidity(const Query &query, PartialValidity &result) { - bool tmp; - - switch (primary->computeValidity(query)) { - case PValidity::MustBeTrue: - result = PValidity::MustBeTrue; - break; - case PValidity::MustBeFalse: - result = PValidity::MustBeFalse; - break; - case PValidity::TrueOrFalse: - result = PValidity::TrueOrFalse; - break; - case PValidity::MayBeTrue: - if (secondary->impl->computeTruth(query, tmp)) { - - result = tmp ? PValidity::MustBeTrue : PValidity::TrueOrFalse; - } else { - result = PValidity::MayBeTrue; - } - break; - case PValidity::MayBeFalse: - if (secondary->impl->computeTruth(query.negateExpr(), tmp)) { - result = tmp ? PValidity::MustBeFalse : PValidity::TrueOrFalse; - } else { - result = PValidity::MayBeFalse; + if (predicate(query)) { + bool tmp; + + switch (primary->computeValidity(query)) { + case PValidity::MustBeTrue: + result = PValidity::MustBeTrue; + break; + case PValidity::MustBeFalse: + result = PValidity::MustBeFalse; + break; + case PValidity::TrueOrFalse: + result = PValidity::TrueOrFalse; + break; + case PValidity::MayBeTrue: + if (secondary->impl->computeTruth(query, tmp)) { + + result = tmp ? PValidity::MustBeTrue : PValidity::TrueOrFalse; + } else { + result = PValidity::MayBeTrue; + } + break; + case PValidity::MayBeFalse: + if (secondary->impl->computeTruth(query.negateExpr(), tmp)) { + result = tmp ? PValidity::MustBeFalse : PValidity::TrueOrFalse; + } else { + result = PValidity::MayBeFalse; + } + break; + default: + if (!secondary->impl->computeValidity(query, result)) + return false; + break; } - break; - default: - if (!secondary->impl->computeValidity(query, result)) - return false; - break; + } else { + return secondary->impl->computeValidity(query, result); } return true; } bool StagedSolverImpl::computeValue(const Query &query, ref &result) { - if (primary->computeValue(query, result)) + if (predicate(query) && primary->computeValue(query, result)) return true; return secondary->impl->computeValue(query, result); @@ -111,7 +119,8 @@ bool StagedSolverImpl::computeValue(const Query &query, ref &result) { bool StagedSolverImpl::computeInitialValues( const Query &query, const std::vector &objects, std::vector> &values, bool &hasSolution) { - if (primary->computeInitialValues(query, objects, values, hasSolution)) + if (predicate(query) && + primary->computeInitialValues(query, objects, values, hasSolution)) return true; return secondary->impl->computeInitialValues(query, objects, values, @@ -119,17 +128,19 @@ bool StagedSolverImpl::computeInitialValues( } bool StagedSolverImpl::check(const Query &query, ref &result) { - std::vector objects; - findSymbolicObjects(query, objects); - std::vector> values; - - bool hasSolution; - - bool primaryResult = - primary->computeInitialValues(query, objects, values, hasSolution); - if (primaryResult && hasSolution) { - result = new InvalidResponse(objects, values); - return true; + if (predicate(query)) { + std::vector objects; + findSymbolicObjects(query, objects); + std::vector> values; + + bool hasSolution; + + bool primaryResult = + primary->computeInitialValues(query, objects, values, hasSolution); + if (primaryResult && hasSolution) { + result = new InvalidResponse(objects, values); + return true; + } } return secondary->impl->check(query, result); @@ -138,6 +149,14 @@ bool StagedSolverImpl::check(const Query &query, ref &result) { bool StagedSolverImpl::computeValidityCore(const Query &query, ValidityCore &validityCore, bool &isValid) { + if (predicate(query)) { + PartialValidity trueResult = primary->computeTruth(query); + + if (trueResult == PValidity::MayBeFalse) { + isValid = false; + return true; + } + } return secondary->impl->computeValidityCore(query, validityCore, isValid); } diff --git a/lib/Solver/SolverCmdLine.cpp b/lib/Solver/SolverCmdLine.cpp index 0f51525d53..4e0c94dffe 100644 --- a/lib/Solver/SolverCmdLine.cpp +++ b/lib/Solver/SolverCmdLine.cpp @@ -43,8 +43,8 @@ cl::OptionCategory SolvingCat("Constraint solving options", "These options impact constraint solving."); cl::opt UseFastCexSolver( - "use-fast-cex-solver", cl::init(false), - cl::desc("Enable an experimental range-based solver (default=false)"), + "use-fast-cex-solver", cl::init(true), + cl::desc("Enable an experimental range-based solver (default=true)"), cl::cat(SolvingCat)); cl::opt diff --git a/test/Feature/DanglingConcreteReadExpr.c b/test/Feature/DanglingConcreteReadExpr.c index 1f8a5a347a..588072125c 100644 --- a/test/Feature/DanglingConcreteReadExpr.c +++ b/test/Feature/DanglingConcreteReadExpr.c @@ -1,7 +1,7 @@ // RUN: %clang %s -emit-llvm %O0opt -c -o %t1.bc // RUN: rm -rf %t.klee-out // RUN: %klee --optimize=false --output-dir=%t.klee-out %t1.bc -// RUN: grep "total queries = 1" %t.klee-out/info +// RUN: grep "total queries = 0" %t.klee-out/info #include @@ -12,8 +12,7 @@ int main() { y = x; - // should be exactly one query (prove x is 10) - // eventually should be 0 when we have fast solver + // should be exactly 0 query, finally we have enough optimizations if (x == 10) { assert(y == 10); } diff --git a/test/Solver/FastCexSolver.kquery b/test/Solver/FastCexSolver.kquery index 271609859b..8e82c84468 100644 --- a/test/Solver/FastCexSolver.kquery +++ b/test/Solver/FastCexSolver.kquery @@ -1,4 +1,4 @@ -# RUN: %kleaver --use-fast-cex-solver --solver-backend=dummy %s > %t +# RUN: %kleaver --use-fast-cex-solver --fast-cex-for=all --solver-backend=dummy %s > %t # RUN: not grep FAIL %t makeSymbolic0 : (array (w64 4) (makeSymbolic arr1 0)) From f2f80424cf614f58aaa76554cc3df746b9a1579a Mon Sep 17 00:00:00 2001 From: Aleksandr Misonizhnik Date: Tue, 24 Oct 2023 18:34:36 +0400 Subject: [PATCH 2/5] [fix] Generate test only for successful solution found --- lib/Core/Executor.cpp | 1 - tools/klee/main.cpp | 147 +++++++++++++++++++++--------------------- 2 files changed, 73 insertions(+), 75 deletions(-) diff --git a/lib/Core/Executor.cpp b/lib/Core/Executor.cpp index dd34b18d5d..0fbbcfe1f8 100644 --- a/lib/Core/Executor.cpp +++ b/lib/Core/Executor.cpp @@ -7180,7 +7180,6 @@ bool Executor::getSymbolicSolution(const ExecutionState &state, KTest &res) { } bool success = solver->getInitialValues(extendedConstraints.cs(), objects, values, state.queryMetaData); - Assignment assignment(objects, values); solver->setTimeout(time::Span()); if (!success) { klee_warning("unable to compute initial values (invalid constraints?)!"); diff --git a/tools/klee/main.cpp b/tools/klee/main.cpp index 0e509a9c00..71e7f8623a 100644 --- a/tools/klee/main.cpp +++ b/tools/klee/main.cpp @@ -34,6 +34,7 @@ DISABLE_WARNING_DEPRECATED_DECLARATIONS #include "llvm/Bitcode/BitcodeReader.h" #include "llvm/IR/Constants.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstIterator.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Type.h" @@ -62,7 +63,6 @@ DISABLE_WARNING_POP #include #include #include -#include #include using json = nlohmann::json; @@ -601,9 +601,8 @@ void KleeHandler::processTestCase(const ExecutionState &state, const auto start_time = time::getWallTime(); bool atLeastOneGenerated = false; - if (WriteKTests) { - - if (success) { + if (success) { + if (WriteKTests) { for (unsigned i = 0; i < ktest.uninitCoeff + 1; ++i) { if (!kTest_toFile( &ktest, @@ -620,49 +619,36 @@ void KleeHandler::processTestCase(const ExecutionState &state, } } - if (message) { - auto f = openTestFile(suffix, id); - if (f) - *f << message; - } - - if (m_pathWriter) { - std::vector concreteBranches; - m_pathWriter->readStream(m_interpreter->getPathStreamID(state), - concreteBranches); - auto f = openTestFile("path", id); - if (f) { - for (const auto &branch : concreteBranches) { - *f << branch << '\n'; - } + if (WriteXMLTests) { + for (unsigned i = 0; i < ktest.uninitCoeff + 1; ++i) { + writeTestCaseXML(message != nullptr, ktest, id, i); + atLeastOneGenerated = true; } } - } - if (message || WriteKQueries) { - std::string constraints; - m_interpreter->getConstraintLog(state, constraints, Interpreter::KQUERY); - auto f = openTestFile("kquery", id); - if (f) - *f << constraints; + for (unsigned i = 0; i < ktest.numObjects; i++) { + delete[] ktest.objects[i].bytes; + delete[] ktest.objects[i].pointers; + } + delete[] ktest.objects; } - if (WriteCVCs) { - // FIXME: If using Z3 as the core solver the emitted file is actually - // SMT-LIBv2 not CVC which is a bit confusing - std::string constraints; - m_interpreter->getConstraintLog(state, constraints, Interpreter::STP); - auto f = openTestFile("cvc", id); + if (message) { + auto f = openTestFile(suffix, id); if (f) - *f << constraints; + *f << message; } - if (WriteSMT2s) { - std::string constraints; - m_interpreter->getConstraintLog(state, constraints, Interpreter::SMTLIB2); - auto f = openTestFile("smt2", id); - if (f) - *f << constraints; + if (m_pathWriter) { + std::vector concreteBranches; + m_pathWriter->readStream(m_interpreter->getPathStreamID(state), + concreteBranches); + auto f = openTestFile("path", id); + if (f) { + for (const auto &branch : concreteBranches) { + *f << branch << '\n'; + } + } } if (m_symPathWriter) { @@ -677,48 +663,14 @@ void KleeHandler::processTestCase(const ExecutionState &state, } } - if (WriteKPaths) { - std::string blockPath; - m_interpreter->getBlockPath(state, blockPath); - auto f = openTestFile("kpath", id); - if (f) - *f << blockPath; - } - - if (WriteCov) { - std::map> cov; - m_interpreter->getCoveredLines(state, cov); - auto f = openTestFile("cov", id); - if (f) { - for (const auto &entry : cov) { - for (const auto &line : entry.second) { - *f << entry.first << ':' << line << '\n'; - } - } - } - } - - if (WriteXMLTests) { - for (unsigned i = 0; i < ktest.uninitCoeff + 1; ++i) { - writeTestCaseXML(message != nullptr, ktest, id, i); - atLeastOneGenerated = true; - } - } - if (atLeastOneGenerated) { ++m_numGeneratedTests; } - for (unsigned i = 0; i < ktest.numObjects; i++) { - delete[] ktest.objects[i].bytes; - delete[] ktest.objects[i].pointers; - } - delete[] ktest.objects; - if (m_numGeneratedTests == MaxTests) m_interpreter->setHaltExecution(HaltExecution::MaxTests); - if (!WriteXMLTests && WriteTestInfo) { + if (WriteTestInfo) { time::Span elapsed_time(time::getWallTime() - start_time); auto f = openTestFile("info", id); if (f) @@ -726,6 +678,53 @@ void KleeHandler::processTestCase(const ExecutionState &state, } } // if (!WriteNone) + if (WriteKQueries) { + std::string constraints; + m_interpreter->getConstraintLog(state, constraints, Interpreter::KQUERY); + auto f = openTestFile("kquery", id); + if (f) + *f << constraints; + } + + if (WriteCVCs) { + // FIXME: If using Z3 as the core solver the emitted file is actually + // SMT-LIBv2 not CVC which is a bit confusing + std::string constraints; + m_interpreter->getConstraintLog(state, constraints, Interpreter::STP); + auto f = openTestFile("cvc", id); + if (f) + *f << constraints; + } + + if (WriteSMT2s) { + std::string constraints; + m_interpreter->getConstraintLog(state, constraints, Interpreter::SMTLIB2); + auto f = openTestFile("smt2", id); + if (f) + *f << constraints; + } + + if (WriteKPaths) { + std::string blockPath; + m_interpreter->getBlockPath(state, blockPath); + auto f = openTestFile("kpath", id); + if (f) + *f << blockPath; + } + + if (WriteCov) { + std::map> cov; + m_interpreter->getCoveredLines(state, cov); + auto f = openTestFile("cov", id); + if (f) { + for (const auto &entry : cov) { + for (const auto &line : entry.second) { + *f << entry.first << ':' << line << '\n'; + } + } + } + } + if (isError && OptExitOnError) { m_interpreter->prepareForEarlyExit(); klee_error("EXITING ON ERROR:\n%s\n", message); From d00e993956be7d304dd5e6f6355caf8c140dd461 Mon Sep 17 00:00:00 2001 From: Aleksandr Misonizhnik Date: Wed, 25 Oct 2023 00:29:29 +0400 Subject: [PATCH 3/5] [fix] --- lib/Module/KModule.cpp | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/lib/Module/KModule.cpp b/lib/Module/KModule.cpp index 343974b472..dd43b00adc 100644 --- a/lib/Module/KModule.cpp +++ b/lib/Module/KModule.cpp @@ -282,16 +282,6 @@ void KModule::optimiseAndPrepare( pm.run(*module); } - if (opts.Optimize) - Optimize(module.get(), preservedFunctions); - - // Add internal functions which are not used to check if instructions - // have been already visited - if (opts.CheckDivZero) - addInternalFunction("klee_div_zero_check"); - if (opts.CheckOvershift) - addInternalFunction("klee_overshift_check"); - // Use KLEE's internal float classification functions if requested. if (opts.WithFPRuntime) { if (UseKleeFloatInternals) { @@ -304,6 +294,16 @@ void KModule::optimiseAndPrepare( } } + if (opts.Optimize) + Optimize(module.get(), preservedFunctions); + + // Add internal functions which are not used to check if instructions + // have been already visited + if (opts.CheckDivZero) + addInternalFunction("klee_div_zero_check"); + if (opts.CheckOvershift) + addInternalFunction("klee_overshift_check"); + // Needs to happen after linking (since ctors/dtors can be modified) // and optimization (since global optimization can rewrite lists). injectStaticConstructorsAndDestructors(module.get(), opts.EntryPoint); From 86628e8e5f94e99030585624db6f1bc34c0b5fba Mon Sep 17 00:00:00 2001 From: Sergey Morozov Date: Mon, 13 Mar 2023 20:52:25 +0300 Subject: [PATCH 4/5] [feat] FastCex --- include/klee/Expr/ExprRangeEvaluator.h | 63 +-- lib/Solver/FastCexSolver.cpp | 593 ++++++++++++++----------- 2 files changed, 372 insertions(+), 284 deletions(-) diff --git a/include/klee/Expr/ExprRangeEvaluator.h b/include/klee/Expr/ExprRangeEvaluator.h index 9148d5feb1..5fd2c6b72c 100644 --- a/include/klee/Expr/ExprRangeEvaluator.h +++ b/include/klee/Expr/ExprRangeEvaluator.h @@ -13,6 +13,8 @@ #include "klee/ADT/Bits.h" #include "klee/Expr/Expr.h" +#include "llvm/ADT/APInt.h" + namespace klee { /* @@ -90,7 +92,7 @@ T ExprRangeEvaluator::evalRead(const UpdateList &ul, T index) { template T ExprRangeEvaluator::evaluate(const ref &e) { switch (e->getKind()) { case Expr::Constant: - return T(cast(e)); + return T(cast(e)->getAPValue()); case Expr::NotOptimized: break; @@ -109,9 +111,9 @@ template T ExprRangeEvaluator::evaluate(const ref &e) { const SelectExpr *se = cast(e); T cond = evaluate(se->cond); - if (cond.mustEqual(1)) { + if (cond.mustEqual(llvm::APInt(se->cond->getWidth(), 1))) { return evaluate(se->trueExpr); - } else if (cond.mustEqual(0)) { + } else if (cond.mustEqual(llvm::APInt(se->cond->getWidth(), 0))) { return evaluate(se->falseExpr); } else { return evaluate(se->trueExpr).set_union(evaluate(se->falseExpr)); @@ -120,11 +122,9 @@ template T ExprRangeEvaluator::evaluate(const ref &e) { // XXX these should be unrolled to ensure nice inline case Expr::Concat: { - const Expr *ep = e.get(); - T res(0); - for (unsigned i = 0; i < ep->getNumKids(); i++) - res = res.concat(evaluate(ep->getKid(i)), 8); - return res; + ref ce = cast(e); + return evaluate(ce->getLeft()) + .concat(evaluate(ce->getRight()), ce->getRight()->getWidth()); } // Arithmetic @@ -206,9 +206,9 @@ template T ExprRangeEvaluator::evaluate(const ref &e) { T right = evaluate(be->right); if (left.mustEqual(right)) { - return T(1); + return T(llvm::APInt(Expr::Bool, 1)); } else if (!left.mayEqual(right)) { - return T(0); + return T(llvm::APInt(Expr::Bool, 0)); } break; } @@ -218,10 +218,10 @@ template T ExprRangeEvaluator::evaluate(const ref &e) { T left = evaluate(be->left); T right = evaluate(be->right); - if (left.max() < right.min()) { - return T(1); - } else if (left.min() >= right.max()) { - return T(0); + if (left.max().ult(right.min())) { + return T(llvm::APInt(Expr::Bool, 1)); + } else if (left.min().uge(right.max())) { + return T(llvm::APInt(Expr::Bool, 0)); } break; } @@ -230,10 +230,10 @@ template T ExprRangeEvaluator::evaluate(const ref &e) { T left = evaluate(be->left); T right = evaluate(be->right); - if (left.max() <= right.min()) { - return T(1); - } else if (left.min() > right.max()) { - return T(0); + if (left.max().ule(right.min())) { + return T(llvm::APInt(Expr::Bool, 1)); + } else if (left.min().ugt(right.max())) { + return T(llvm::APInt(Expr::Bool, 0)); } break; } @@ -243,10 +243,10 @@ template T ExprRangeEvaluator::evaluate(const ref &e) { T right = evaluate(be->right); unsigned bits = be->left->getWidth(); - if (left.maxSigned(bits) < right.minSigned(bits)) { - return T(1); - } else if (left.minSigned(bits) >= right.maxSigned(bits)) { - return T(0); + if (left.maxSigned(bits).ult(right.minSigned(bits))) { + return T(llvm::APInt(Expr::Bool, 1)); + } else if (left.minSigned(bits).uge(right.maxSigned(bits))) { + return T(llvm::APInt(Expr::Bool, 0)); } break; } @@ -256,13 +256,21 @@ template T ExprRangeEvaluator::evaluate(const ref &e) { T right = evaluate(be->right); unsigned bits = be->left->getWidth(); - if (left.maxSigned(bits) <= right.minSigned(bits)) { - return T(1); - } else if (left.minSigned(bits) > right.maxSigned(bits)) { - return T(0); + if (left.maxSigned(bits).ule(right.minSigned(bits))) { + return T(llvm::APInt(Expr::Bool, 1)); + } else if (left.minSigned(bits).ugt(right.maxSigned(bits))) { + return T(llvm::APInt(Expr::Bool, 0)); } break; } + case Expr::ZExt: { + ref ce = cast(e); + return evaluate(e->getKid(0)).zextOrTrunc(ce->getWidth()); + } + case Expr::SExt: { + ref ce = cast(e); + return evaluate(e->getKid(0)).sextOrTrunc(ce->getWidth()); + } case Expr::Ne: case Expr::Ugt: @@ -275,7 +283,8 @@ template T ExprRangeEvaluator::evaluate(const ref &e) { break; } - return T(0, bits64::maxValueOfNBits(e->getWidth())); + return T(llvm::APInt(e->getWidth(), 0), + llvm::APInt::getAllOnesValue(e->getWidth())); } } // namespace klee diff --git a/lib/Solver/FastCexSolver.cpp b/lib/Solver/FastCexSolver.cpp index a878bc34bf..83fcbd814b 100644 --- a/lib/Solver/FastCexSolver.cpp +++ b/lib/Solver/FastCexSolver.cpp @@ -10,6 +10,7 @@ #define DEBUG_TYPE "cex-solver" #include "klee/Solver/Solver.h" +#include "klee/ADT/SparseStorage.h" #include "klee/Expr/Constraints.h" #include "klee/Expr/Expr.h" #include "klee/Expr/ExprEvaluator.h" @@ -50,85 +51,104 @@ cl::opt FastCexFor( } // namespace // Hacker's Delight, pgs 58-63 -static uint64_t minOR(uint64_t a, uint64_t b, uint64_t c, uint64_t d) { - uint64_t temp, m = ((uint64_t)1) << 63; - while (m) { - if (~a & c & m) { - temp = (a | m) & -m; - if (temp <= b) { +static llvm::APInt minOR(llvm::APInt a, llvm::APInt b, llvm::APInt c, + llvm::APInt d) { + assert(a.getBitWidth() == c.getBitWidth()); + + llvm::APInt m = + llvm::APInt::getOneBitSet(a.getBitWidth(), a.getBitWidth() - 1); + while (m.getBoolValue()) { + if ((a.reverseBits() & c & m).getBoolValue()) { + llvm::APInt temp = (a | m) & -m; + if (temp.ule(b)) { a = temp; break; } - } else if (a & ~c & m) { - temp = (c | m) & -m; - if (temp <= d) { + } else if ((a & c.reverseBits() & m).getBoolValue()) { + llvm::APInt temp = (c | m) & -m; + if (temp.ule(d)) { c = temp; break; } } - m >>= 1; + m = m.lshr(1); } return a | c; } -static uint64_t maxOR(uint64_t a, uint64_t b, uint64_t c, uint64_t d) { - uint64_t temp, m = ((uint64_t)1) << 63; +static llvm::APInt maxOR(llvm::APInt a, llvm::APInt b, llvm::APInt c, + llvm::APInt d) { + assert(a.getBitWidth() == c.getBitWidth()); + + llvm::APInt m = + llvm::APInt::getOneBitSet(a.getBitWidth(), a.getBitWidth() - 1); - while (m) { - if (b & d & m) { - temp = (b - m) | (m - 1); - if (temp >= a) { + while (m.getBoolValue()) { + if ((b & d & m).getBoolValue()) { + llvm::APInt temp = (b - m) | (m - 1); + if (temp.uge(a)) { b = temp; break; } temp = (d - m) | (m - 1); - if (temp >= c) { + if (temp.uge(c)) { d = temp; break; } } - m >>= 1; + m = m.lshr(1); } return b | d; } -static uint64_t minAND(uint64_t a, uint64_t b, uint64_t c, uint64_t d) { - uint64_t temp, m = ((uint64_t)1) << 63; - while (m) { - if (~a & ~c & m) { - temp = (a | m) & -m; - if (temp <= b) { + +static llvm::APInt minAND(llvm::APInt a, llvm::APInt b, llvm::APInt c, + llvm::APInt d) { + assert(a.getBitWidth() == c.getBitWidth()); + + llvm::APInt m = + llvm::APInt::getOneBitSet(a.getBitWidth(), a.getBitWidth() - 1); + + while (m.getBoolValue()) { + if ((~a & ~c & m).getBoolValue()) { + llvm::APInt temp = (a | m) & -m; + if (temp.ule(b)) { a = temp; break; } temp = (c | m) & -m; - if (temp <= d) { + if (temp.ule(d)) { c = temp; break; } } - m >>= 1; + m = m.lshr(1); } return a & c; } -static uint64_t maxAND(uint64_t a, uint64_t b, uint64_t c, uint64_t d) { - uint64_t temp, m = ((uint64_t)1) << 63; - while (m) { - if (b & ~d & m) { - temp = (b & ~m) | (m - 1); - if (temp >= a) { +static llvm::APInt maxAND(llvm::APInt a, llvm::APInt b, llvm::APInt c, + llvm::APInt d) { + assert(a.getBitWidth() == c.getBitWidth()); + + llvm::APInt m = + llvm::APInt::getOneBitSet(a.getBitWidth(), a.getBitWidth() - 1); + + while (m.getBoolValue()) { + if ((b & ~d & m).getBoolValue()) { + llvm::APInt temp = (b & ~m) | (m - 1); + if (temp.uge(a)) { b = temp; break; } - } else if (~b & d & m) { - temp = (d & ~m) | (m - 1); - if (temp >= c) { + } else if ((~b & d & m).getBoolValue()) { + llvm::APInt temp = (d & ~m) | (m - 1); + if (temp.uge(c)) { d = temp; break; } } - m >>= 1; + m = m.lshr(1); } return b & d; @@ -138,18 +158,20 @@ static uint64_t maxAND(uint64_t a, uint64_t b, uint64_t c, uint64_t d) { class ValueRange { private: - std::uint64_t m_min = 1, m_max = 0; + llvm::APInt m_min, m_max; + unsigned width; public: - ValueRange() noexcept = default; - ValueRange(const ref &ce) { - // FIXME: Support large widths. - m_min = m_max = ce->getLimitedValue(); - } - explicit ValueRange(std::uint64_t value) noexcept - : m_min(value), m_max(value) {} - ValueRange(std::uint64_t _min, std::uint64_t _max) noexcept - : m_min(_min), m_max(_max) {} + ValueRange() noexcept : width(m_min.getBitWidth()) {} + ValueRange(const ref &ce) : width(ce->getWidth()) { + m_min = m_max = ce->getAPValue(); + } + explicit ValueRange(const llvm::APInt &value) noexcept + : m_min(value), m_max(value), width(value.getBitWidth()) {} + ValueRange(const llvm::APInt &_min, const llvm::APInt &_max) noexcept + : m_min(_min), m_max(_max), width(m_min.getBitWidth()) { + assert(m_min.getBitWidth() == m_max.getBitWidth()); + } ValueRange(const ValueRange &other) noexcept = default; ValueRange &operator=(const ValueRange &other) noexcept = default; ValueRange(ValueRange &&other) noexcept = default; @@ -163,8 +185,10 @@ class ValueRange { } } - bool isEmpty() const noexcept { return m_min > m_max; } - bool contains(std::uint64_t value) const { + unsigned bitWidth() const { return width; } + + bool isEmpty() const noexcept { return m_min.ugt(m_max); } + bool contains(const llvm::APInt &value) const { return this->intersects(ValueRange(value)); } bool intersects(const ValueRange &b) const { @@ -172,24 +196,28 @@ class ValueRange { } bool isFullRange(unsigned bits) const noexcept { - return m_min == 0 && m_max == bits64::maxValueOfNBits(bits); + return m_min == 0 && m_max == llvm::APInt::getAllOnesValue(bits); } ValueRange set_intersection(const ValueRange &b) const { - return ValueRange(std::max(m_min, b.m_min), std::min(m_max, b.m_max)); + return ValueRange(llvm::APIntOps::umax(m_min, b.m_min), + llvm::APIntOps::umin(m_max, b.m_max)); } ValueRange set_union(const ValueRange &b) const { - return ValueRange(std::min(m_min, b.m_min), std::max(m_max, b.m_max)); + return ValueRange(llvm::APIntOps::umin(m_min, b.m_min), + llvm::APIntOps::umax(m_max, b.m_max)); } ValueRange set_difference(const ValueRange &b) const { - if (b.isEmpty() || b.m_min > m_max || b.m_max < m_min) { // no intersection + if (b.isEmpty() || b.m_min.ugt(m_max) || + b.m_max.ult(m_min)) { // no intersection return *this; - } else if (b.m_min <= m_min && b.m_max >= m_max) { // empty - return ValueRange(1, 0); - } else if (b.m_min <= m_min) { // one range out + } else if (b.m_min.ule(m_min) && b.m_max.uge(m_max)) { // empty + return ValueRange(llvm::APInt::getOneBitSet(width, 0), + llvm::APInt::getNullValue(width)); + } else if (b.m_min.ule(m_min)) { // one range out // cannot overflow because b.m_max < m_max return ValueRange(b.m_max + 1, m_max); - } else if (b.m_max >= m_max) { + } else if (b.m_max.uge(m_max)) { // cannot overflow because b.min > m_min return ValueRange(m_min, b.m_min - 1); } else { @@ -207,7 +235,7 @@ class ValueRange { maxAND(m_min, m_max, b.m_min, b.m_max)); } } - ValueRange binaryAnd(std::uint64_t b) const { + ValueRange binaryAnd(const llvm::APInt &b) const { return binaryAnd(ValueRange(b)); } ValueRange binaryOr(ValueRange b) const { @@ -220,15 +248,20 @@ class ValueRange { maxOR(m_min, m_max, b.m_min, b.m_max)); } } - ValueRange binaryOr(std::uint64_t b) const { return binaryOr(ValueRange(b)); } + ValueRange binaryOr(const llvm::APInt &b) const { + return binaryOr(ValueRange(b)); + } ValueRange binaryXor(ValueRange b) const { if (isFixed() && b.isFixed()) { return ValueRange(m_min ^ b.m_min); } else { - std::uint64_t t = m_max | b.m_max; - while (!bits64::isPowerOfTwo(t)) - t = bits64::withoutRightmostBit(t); - return ValueRange(0, (t << 1) - 1); + llvm::APInt t = m_max | b.m_max; + if (!t.isPowerOf2()) { + t = llvm::APInt::getOneBitSet(t.getBitWidth(), + t.getBitWidth() - t.countLeadingZeros()); + } + return ValueRange(llvm::APInt::getNullValue(t.getBitWidth()), + (t << 1) - 1); } } @@ -236,37 +269,70 @@ class ValueRange { return ValueRange(m_min << bits, m_max << bits); } ValueRange binaryShiftRight(unsigned bits) const { - return ValueRange(m_min >> bits, m_max >> bits); + return ValueRange(m_min.lshr(bits), m_max.lshr(bits)); } - ValueRange concat(const ValueRange &b, unsigned bits) const { - return binaryShiftLeft(bits).binaryOr(b); + ValueRange concat(ValueRange b, unsigned bits) const { + ValueRange newRange = + ValueRange(m_min.zext(bitWidth() + bits), m_max.zext(bitWidth() + bits)) + .binaryShiftLeft(bits); + b.m_min = b.m_min.zext(bitWidth() + bits); + b.m_max = b.m_max.zext(bitWidth() + bits); + return newRange.binaryOr(b); } ValueRange extract(std::uint64_t lowBit, std::uint64_t maxBit) const { - return binaryShiftRight(lowBit).binaryAnd( - bits64::maxValueOfNBits(maxBit - lowBit)); + assert(!isEmpty()); + ValueRange newRange = + binaryShiftRight(width - maxBit) + .binaryAnd(llvm::APInt::getLowBitsSet(width, maxBit - lowBit)); + newRange.width = maxBit - lowBit; + newRange.m_min = newRange.m_min.trunc(newRange.width); + newRange.m_max = newRange.m_max.trunc(newRange.width); + assert(!newRange.isEmpty()); + return newRange; } ValueRange add(const ValueRange &b, unsigned width) const { - return ValueRange(0, bits64::maxValueOfNBits(width)); + return ValueRange(llvm::APInt::getNullValue(width), + llvm::APInt::getAllOnesValue(width)); } ValueRange sub(const ValueRange &b, unsigned width) const { - return ValueRange(0, bits64::maxValueOfNBits(width)); + return ValueRange(llvm::APInt::getNullValue(width), + llvm::APInt::getAllOnesValue(width)); } ValueRange mul(const ValueRange &b, unsigned width) const { - return ValueRange(0, bits64::maxValueOfNBits(width)); + return ValueRange(llvm::APInt::getNullValue(width), + llvm::APInt::getAllOnesValue(width)); } ValueRange udiv(const ValueRange &b, unsigned width) const { - return ValueRange(0, bits64::maxValueOfNBits(width)); + return ValueRange(llvm::APInt::getNullValue(width), + llvm::APInt::getAllOnesValue(width)); } ValueRange sdiv(const ValueRange &b, unsigned width) const { - return ValueRange(0, bits64::maxValueOfNBits(width)); + return ValueRange(llvm::APInt::getNullValue(width), + llvm::APInt::getAllOnesValue(width)); } ValueRange urem(const ValueRange &b, unsigned width) const { - return ValueRange(0, bits64::maxValueOfNBits(width)); + return ValueRange(llvm::APInt::getNullValue(width), + llvm::APInt::getAllOnesValue(width)); } ValueRange srem(const ValueRange &b, unsigned width) const { - return ValueRange(0, bits64::maxValueOfNBits(width)); + return ValueRange(llvm::APInt::getNullValue(width), + llvm::APInt::getAllOnesValue(width)); + } + ValueRange zextOrTrunc(unsigned newWidth) const { + ValueRange zextOrTruncRange; + zextOrTruncRange.m_min = m_min.zextOrTrunc(newWidth); + zextOrTruncRange.m_max = m_max.zextOrTrunc(newWidth); + zextOrTruncRange.width = newWidth; + return zextOrTruncRange; + } + ValueRange sextOrTrunc(unsigned newWidth) const { + ValueRange sextOrTruncRange; + sextOrTruncRange.m_min = m_min.sextOrTrunc(newWidth); + sextOrTruncRange.m_max = m_max.sextOrTrunc(newWidth); + sextOrTruncRange.width = newWidth; + return sextOrTruncRange; } // use min() to get value if true (XXX should we add a method to @@ -278,11 +344,11 @@ class ValueRange { } bool operator!=(const ValueRange &b) const noexcept { return !(*this == b); } - bool mustEqual(const std::uint64_t b) const noexcept { + bool mustEqual(const llvm::APInt &b) const noexcept { return m_min == m_max && m_min == b; } - bool mayEqual(const std::uint64_t b) const noexcept { - return m_min <= b && m_max >= b; + bool mayEqual(const llvm::APInt &b) const noexcept { + return m_min.ule(b) && m_max.uge(b); } bool mustEqual(const ValueRange &b) const noexcept { @@ -290,48 +356,48 @@ class ValueRange { } bool mayEqual(const ValueRange &b) const { return this->intersects(b); } - std::uint64_t min() const noexcept { + llvm::APInt min() const noexcept { assert(!isEmpty() && "cannot get minimum of empty range"); return m_min; } - std::uint64_t max() const noexcept { + llvm::APInt max() const noexcept { assert(!isEmpty() && "cannot get maximum of empty range"); return m_max; } - std::int64_t minSigned(unsigned bits) const { - assert(bits >= 2 && bits <= 64); - assert((m_min >> bits) == 0 && (m_max >> bits) == 0 && - "range is outside given number of bits"); - + llvm::APInt minSigned(unsigned bits) const { // if max allows sign bit to be set then it can be smallest value, // otherwise since the range is not empty, min cannot have a sign // bit - std::uint64_t smallest = (static_cast(1) << (bits - 1)); - if (m_max >= smallest) { - return llvm::APInt::getSignedMinValue(bits).getSExtValue(); + llvm::APInt smallest = llvm::APInt::getSignedMinValue(bits); + + if (m_max.uge(smallest)) { + return m_max.sext(bits); } else { return m_min; } } - std::int64_t maxSigned(unsigned bits) const { - assert(bits >= 2 && bits <= 64); - assert((m_min >> bits) == 0 && (m_max >> bits) == 0 && - "range is outside given number of bits"); - - std::uint64_t smallest = (static_cast(1) << (bits - 1)); + // Works like a sext instrution: if bits is less then + // current width, then truncate expression; otherwise + // extend it to bits. + llvm::APInt maxSigned(unsigned bits) const { + llvm::APInt smallest = llvm::APInt::getSignedMinValue(bits); // if max and min have sign bit then max is max, otherwise if only // max has sign bit then max is largest signed integer, otherwise // max is max - if (m_min < smallest && m_max >= smallest) { + if (m_min.ult(smallest) && m_max.uge(smallest)) { return smallest - 1; } else { - return llvm::APInt(bits, m_max, true).getSExtValue(); + // width are not equal here; if this width is shorter, then + // we will return sign extended max, otherwise we need to find + // signed max value of first n bits + + return m_max.sext(bits); } } }; @@ -350,46 +416,45 @@ class CexObjectData { /// /// The possible values is an inexact approximation for the set of values for /// each array location. - std::vector possibleContents; + SparseStorage possibleContents; /// exactContents - An array of exact values for the object. /// /// The exact values are a conservative approximation for the set of values /// for each array location. - std::vector exactContents; + SparseStorage exactContents; CexObjectData(const CexObjectData &); // DO NOT IMPLEMENT void operator=(const CexObjectData &); // DO NOT IMPLEMENT public: - CexObjectData(uint64_t size) : possibleContents(size), exactContents(size) { - for (uint64_t i = 0; i != size; ++i) { - possibleContents[i] = ValueRange(0, 255); - exactContents[i] = ValueRange(0, 255); - } - } + CexObjectData(uint64_t size) + : possibleContents(ValueRange(llvm::APInt::getNullValue(CHAR_BIT), + llvm::APInt::getAllOnesValue(CHAR_BIT))), + exactContents(ValueRange(llvm::APInt::getNullValue(CHAR_BIT), + llvm::APInt::getAllOnesValue(CHAR_BIT))) {} const CexValueData getPossibleValues(size_t index) const { - return possibleContents[index]; + return possibleContents.load(index); } - void setPossibleValues(size_t index, CexValueData values) { - possibleContents[index] = values; + void setPossibleValues(size_t index, const CexValueData &values) { + possibleContents.store(index, values); } - void setPossibleValue(size_t index, unsigned char value) { - possibleContents[index] = CexValueData(value); + void setPossibleValue(size_t index, const llvm::APInt &value) { + possibleContents.store(index, CexValueData(value)); } const CexValueData getExactValues(size_t index) const { - return exactContents[index]; + return exactContents.load(index); } void setExactValues(size_t index, CexValueData values) { - exactContents[index] = values; + exactContents.store(index, values); } /// getPossibleValue - Return some possible value. - unsigned char getPossibleValue(size_t index) const { - const CexValueData &cvd = possibleContents[index]; - return cvd.min() + (cvd.max() - cvd.min()) / 2; + llvm::APInt getPossibleValue(size_t index) const { + CexValueData cvd = possibleContents.load(index); + return cvd.min() + (cvd.max() - cvd.min()).lshr(1); } }; @@ -404,12 +469,14 @@ class CexRangeEvaluator : public ExprRangeEvaluator { if (array.isConstantArray() && index.isFixed()) { if (ref constantSource = dyn_cast(array.source)) { - if (auto value = constantSource->constantValues.load(index.min())) { - return ValueRange(value->getZExtValue(8)); + if (auto value = constantSource->constantValues.load( + index.min().getZExtValue())) { + return ValueRange(value->getAPValue()); } } } - return ValueRange(0, 255); + return ValueRange(llvm::APInt::getNullValue(CHAR_BIT), + llvm::APInt::getAllOnesValue(CHAR_BIT)); } }; @@ -426,7 +493,7 @@ class CexPossibleEvaluator : public ExprEvaluator { std::abort(); } - if (index >= constantArraySize->getZExtValue()) { + if (!constantArraySize || index >= constantArraySize->getZExtValue()) { return ReadExpr::create(UpdateList(&array, 0), ConstantExpr::alloc(index, array.getDomain())); } @@ -434,7 +501,9 @@ class CexPossibleEvaluator : public ExprEvaluator { std::map::iterator it = objects.find(&array); return ConstantExpr::alloc( - (it == objects.end() ? 127 : it->second->getPossibleValue(index)), + (it == objects.end() + ? 127 + : it->second->getPossibleValue(index).getZExtValue()), array.getRange()); } @@ -456,7 +525,7 @@ class CexExactEvaluator : public ExprEvaluator { ConstantExpr::alloc(index, array.getDomain())); } - if (index >= constantArraySize->getZExtValue()) { + if (!constantArraySize || index >= constantArraySize->getZExtValue()) { return ReadExpr::create(UpdateList(&array, 0), ConstantExpr::alloc(index, array.getDomain())); } @@ -472,7 +541,7 @@ class CexExactEvaluator : public ExprEvaluator { return ReadExpr::create(UpdateList(&array, 0), ConstantExpr::alloc(index, array.getDomain())); - return ConstantExpr::alloc(cvd.min(), array.getRange()); + return ConstantExpr::create(cvd.min().getZExtValue(), array.getRange()); } public: @@ -515,22 +584,26 @@ class CexData { return *Entry; } - void propagatePossibleValue(ref e, uint64_t value) { - propagatePossibleValues(e, CexValueData(value, value)); + void propagatePossibleValue(ref e, const llvm::APInt &value) { + propagatePossibleValues(e, CexValueData(value)); } - void propagateExactValue(ref e, uint64_t value) { - propagateExactValues(e, CexValueData(value, value)); + void propogateExactValue(ref e, const llvm::APInt &value) { + propagateExactValues(e, CexValueData(value)); } void propagatePossibleValues(ref e, CexValueData range) { - KLEE_DEBUG(llvm::errs() << "propagate: " << range << " for\n" << e << "\n"); + assert(range.bitWidth() == e->getWidth()); switch (e->getKind()) { - case Expr::Constant: + case Expr::Constant: { + ref CE = cast(e); + assert(range.intersects(ValueRange(CE->getAPValue())) && + "Constant is out of range for propagation."); // rather a pity if the constant isn't in the range, but how can // we use this? break; + } // Special @@ -575,56 +648,22 @@ class CexData { SelectExpr *se = cast(e); ValueRange cond = evalRangeForExpr(se->cond); if (cond.isFixed()) { - if (cond.min()) { + if (cond.min().getBoolValue()) { propagatePossibleValues(se->trueExpr, range); } else { propagatePossibleValues(se->falseExpr, range); } - } else { - // XXX imprecise... we have a choice here. One method is to - // simply force both sides into the specified range (since the - // condition is indetermined). This may lose in two ways, the - // first is that the condition chosen may limit further - // restrict the range in each of the children, however this is - // less of a problem as the range will be a superset of legal - // values. The other is if the condition ends up being forced - // by some other constraints, then we needlessly forced one - // side into the given range. - // - // The other method would be to force the condition to one - // side and force that side into the given range. This loses - // when we force the condition to an unsatisfiable value - // (either because the condition cannot be that, or the - // resulting range given that condition is not in the required - // range). - // - // Currently we just force both into the range. A hybrid would - // be to evaluate the ranges for each of the children... if - // one of the ranges happens to already be a subset of the - // required range then it may be preferable to force the - // condition to that side. - propagatePossibleValues(se->trueExpr, range); - propagatePossibleValues(se->falseExpr, range); } break; } - // XXX imprecise... the problem here is that extracting bits - // loses information about what bits are connected across the - // bytes. if a value can be 1 or 256 then either the top or - // lower byte is 0, but just extraction loses this information - // and will allow neither,one,or both to be 1. - // - // we can protect against this in a limited fashion by writing - // the extraction a byte at a time, then checking the evaluated - // value, isolating for that range, and continuing. case Expr::Concat: { ConcatExpr *ce = cast(e); - Expr::Width LSBWidth = ce->getKid(1)->getWidth(); - Expr::Width MSBWidth = ce->getKid(1)->getWidth(); - propagatePossibleValues(ce->getKid(0), + Expr::Width LSBWidth = ce->getLeft()->getWidth(); + Expr::Width MSBWidth = ce->getRight()->getWidth(); + propagatePossibleValues(ce->getLeft(), range.extract(0, LSBWidth)); + propagatePossibleValues(ce->getRight(), range.extract(LSBWidth, LSBWidth + MSBWidth)); - propagatePossibleValues(ce->getKid(1), range.extract(0, LSBWidth)); break; } @@ -643,8 +682,16 @@ class CexData { case Expr::ZExt: { CastExpr *ce = cast(e); unsigned inBits = ce->src->getWidth(); - ValueRange input = range.set_intersection( - ValueRange(0, bits64::maxValueOfNBits(inBits))); + unsigned outBits = ce->getWidth(); + + // Intersect with range of same bitness and truncate + // result to inBits (as llvm::APInt can not be compared + // if they have different width). + ValueRange input = range + .set_intersection(ValueRange( + llvm::APInt::getNullValue(outBits), + llvm::APInt::getLowBitsSet(outBits, inBits))) + .zextOrTrunc(inBits); propagatePossibleValues(ce->src, input); break; } @@ -655,10 +702,15 @@ class CexData { CastExpr *ce = cast(e); unsigned inBits = ce->src->getWidth(); unsigned outBits = ce->width; - ValueRange output = range.set_difference(ValueRange( - 1 << (inBits - 1), (bits64::maxValueOfNBits(outBits) - - bits64::maxValueOfNBits(inBits - 1) - 1))); - ValueRange input = output.binaryAnd(bits64::maxValueOfNBits(inBits)); + + ValueRange input = + range + .set_difference(ValueRange( + llvm::APInt::getOneBitSet(outBits, inBits - 1), + (llvm::APInt::getAllOnesValue(outBits) - + llvm::APInt::getLowBitsSet(outBits, inBits - 1) - 1))) + .zextOrTrunc(inBits); + propagatePossibleValues(ce->src, input); break; } @@ -668,20 +720,16 @@ class CexData { case Expr::Add: { BinaryExpr *be = cast(e); if (ConstantExpr *CE = dyn_cast(be->left)) { - // FIXME: Don't depend on width. - if (CE->getWidth() <= 64) { - // FIXME: Why do we ever propagate empty ranges? It doesn't make - // sense. - if (range.isEmpty()) - break; - - // C_0 + X \in [MIN, MAX) ==> X \in [MIN - C_0, MAX - C_0) - Expr::Width W = CE->getWidth(); - CexValueData nrange( - ConstantExpr::alloc(range.min(), W)->Sub(CE)->getZExtValue(), - ConstantExpr::alloc(range.max(), W)->Sub(CE)->getZExtValue()); - if (!nrange.isEmpty()) - propagatePossibleValues(be->right, nrange); + // FIXME: Why do we ever propogate empty ranges? It doesn't make + // sense. + if (range.isEmpty()) + break; + + // C_0 + X \in [MIN, MAX) ==> X \in [MIN - C_0, MAX - C_0) + CexValueData nrange(range.min() - CE->getAPValue(), + range.max() - CE->getAPValue()); + if (!nrange.isEmpty()) { + propagatePossibleValues(be->right, nrange); } } break; @@ -695,23 +743,31 @@ class CexData { ValueRange right = evalRangeForExpr(be->right); if (!range.min()) { - if (left.mustEqual(0) || right.mustEqual(0)) { + if (left.mustEqual(llvm::APInt::getNullValue(be->getWidth())) || + right.mustEqual(llvm::APInt::getNullValue(be->getWidth()))) { // all is well } else { // XXX heuristic, which order - propagatePossibleValue(be->left, 0); + propagatePossibleValue( + be->left, llvm::APInt::getNullValue(be->left->getWidth())); left = evalRangeForExpr(be->left); // see if that worked - if (!left.mustEqual(1)) - propagatePossibleValue(be->right, 0); + if (!left.mustEqual(llvm::APInt(be->left->getWidth(), 1))) { + propagatePossibleValue( + be->right, llvm::APInt::getNullValue(be->left->getWidth())); + } } } else { - if (!left.mustEqual(1)) - propagatePossibleValue(be->left, 1); - if (!right.mustEqual(1)) - propagatePossibleValue(be->right, 1); + llvm::APInt leftAPIntOne = llvm::APInt(be->left->getWidth(), 1); + if (!left.mustEqual(leftAPIntOne)) { + propagatePossibleValue(be->left, leftAPIntOne); + } + llvm::APInt rightAPIntOne = llvm::APInt(be->right->getWidth(), 1); + if (!right.mustEqual(rightAPIntOne)) { + propagatePossibleValue(be->right, rightAPIntOne); + } } } } else { @@ -727,25 +783,29 @@ class CexData { ValueRange left = evalRangeForExpr(be->left); ValueRange right = evalRangeForExpr(be->right); - if (range.min()) { - if (left.mustEqual(1) || right.mustEqual(1)) { + llvm::APInt zeroAPInt = + llvm::APInt::getNullValue(be->left->getWidth()); + llvm::APInt oneAPInt = llvm::APInt(be->left->getWidth(), 1); + + if (range.min().getBoolValue()) { + if (left.mustEqual(oneAPInt) || right.mustEqual(oneAPInt)) { // all is well } else { // XXX heuristic, which order? // force left to value we need - propagatePossibleValue(be->left, 1); + propagatePossibleValue(be->left, oneAPInt); left = evalRangeForExpr(be->left); // see if that worked - if (!left.mustEqual(1)) - propagatePossibleValue(be->right, 1); + if (!left.mustEqual(oneAPInt)) + propagatePossibleValue(be->right, oneAPInt); } } else { - if (!left.mustEqual(0)) - propagatePossibleValue(be->left, 0); - if (!right.mustEqual(0)) - propagatePossibleValue(be->right, 0); + if (!left.mustEqual(zeroAPInt)) + propagatePossibleValue(be->left, zeroAPInt); + if (!right.mustEqual(zeroAPInt)) + propagatePossibleValue(be->right, zeroAPInt); } } } else { @@ -763,25 +823,26 @@ class CexData { BinaryExpr *be = cast(e); if (range.isFixed()) { if (ConstantExpr *CE = dyn_cast(be->left)) { - // FIXME: Handle large widths? - if (CE->getWidth() <= 64) { - uint64_t value = CE->getZExtValue(); - if (range.min()) { - propagatePossibleValue(be->right, value); + llvm::APInt value = CE->getAPValue(); + if (range.min().getBoolValue()) { + propagatePossibleValue(be->right, value); + } else { + CexValueData range; + if (value == 0) { + range = + CexValueData(llvm::APInt(CE->getWidth(), 1), + llvm::APInt::getAllOnesValue(CE->getWidth())); } else { - CexValueData range; - if (value == 0) { - range = - CexValueData(1, bits64::maxValueOfNBits(CE->getWidth())); - } else { - // FIXME: heuristic / lossy, could be better to pick larger - // range? - range = CexValueData(0, value - 1); - } - propagatePossibleValues(be->right, range); + // FIXME: heuristic / lossy, could be better to pick larger + // range? + + // FIXME: choose both + // range = CexValueData(llvm::APInt::getNullValue(CE->getWidth()), + // value - 1); + // range = CexValueData( + // value + 1, llvm::APInt::getAllOnesValue(CE->getWidth())); } - } else { - // XXX what now + propagatePossibleValues(be->right, range); } } } @@ -790,7 +851,8 @@ class CexData { case Expr::Not: { if (e->getWidth() == Expr::Bool && range.isFixed()) { - propagatePossibleValue(e->getKid(0), !range.min()); + propagatePossibleValue( + e->getKid(0), llvm::APInt(e->getKid(0)->getWidth(), !range.min())); } break; } @@ -804,20 +866,27 @@ class CexData { ValueRange left = evalRangeForExpr(be->left); ValueRange right = evalRangeForExpr(be->right); - uint64_t maxValue = bits64::maxValueOfNBits(be->right->getWidth()); + llvm::APInt maxValue = + llvm::APInt::getAllOnesValue(be->right->getWidth()); // XXX should deal with overflow (can lead to empty range) if (left.isFixed()) { - if (range.min()) { + if (!range.min().isNullValue()) { propagatePossibleValues(be->right, CexValueData(left.min() + 1, maxValue)); } else { - propagatePossibleValues(be->right, CexValueData(0, left.min())); + propagatePossibleValues( + be->right, + CexValueData(llvm::APInt::getNullValue(be->right->getWidth()), + left.min())); } } else if (right.isFixed()) { - if (range.min()) { - propagatePossibleValues(be->left, CexValueData(0, right.min() - 1)); + if (!range.min().isNullValue()) { + propagatePossibleValues( + be->left, + CexValueData(llvm::APInt::getNullValue(be->right->getWidth()), + right.min() - 1)); } else { propagatePossibleValues(be->left, CexValueData(right.min(), maxValue)); @@ -839,23 +908,31 @@ class CexData { // XXX should deal with overflow (can lead to empty range) - uint64_t maxValue = bits64::maxValueOfNBits(be->right->getWidth()); + llvm::APInt maxValue = + llvm::APInt::getAllOnesValue(be->right->getWidth()); if (left.isFixed()) { - if (range.min()) { + if (range.min().getBoolValue()) { propagatePossibleValues(be->right, CexValueData(left.min(), maxValue)); } else { - propagatePossibleValues(be->right, CexValueData(0, left.min() - 1)); + propagatePossibleValues( + be->right, + CexValueData(llvm::APInt::getNullValue(be->right->getWidth()), + left.min() - 1)); } } else if (right.isFixed()) { - if (range.min()) { - propagatePossibleValues(be->left, CexValueData(0, right.min())); + if (range.min().getBoolValue()) { + propagatePossibleValues( + be->left, + CexValueData(llvm::APInt::getNullValue(be->right->getWidth()), + right.min())); } else { propagatePossibleValues(be->left, CexValueData(right.min() + 1, maxValue)); } } else { // XXX ??? + // TODO: we can try to order it! } } break; @@ -912,19 +989,20 @@ class CexData { if (!isa(array->size)) { assert(0 && "Unimplemented"); } - propagateExactValues(constantSource->constantValues.load(index.min()), - range); + propagateExactValues( + constantSource->constantValues.load(index.min().getZExtValue()), + range); } else { - CexValueData cvd = cod.getExactValues(index.min()); - if (range.min() > cvd.min()) { - assert(range.min() <= cvd.max()); + CexValueData cvd = cod.getExactValues(index.min().getZExtValue()); + if (range.min().ugt(cvd.min())) { + assert(range.min().ule(cvd.max())); cvd = CexValueData(range.min(), cvd.max()); } - if (range.max() < cvd.max()) { - assert(range.max() >= cvd.min()); + if (range.max().ult(cvd.max())) { + assert(range.max().uge(cvd.min())); cvd = CexValueData(cvd.min(), range.max()); } - cod.setExactValues(index.min(), cvd); + cod.setExactValues(index.min().getZExtValue(), cvd); } } break; @@ -972,17 +1050,16 @@ class CexData { BinaryExpr *be = cast(e); if (range.isFixed()) { if (ConstantExpr *CE = dyn_cast(be->left)) { - // FIXME: Handle large widths? - if (CE->getWidth() <= 64) { - uint64_t value = CE->getZExtValue(); - if (range.min()) { - // If the equality is true, then propagate the value. - propagateExactValue(be->right, value); - } else { - // If the equality is false and the comparison is of booleans, - // then we can infer the value to propagate. - if (be->right->getWidth() == Expr::Bool) - propagateExactValue(be->right, !value); + if (range.min().getBoolValue()) { + // If the equality is true, then propogate the value. + propogateExactValue(be->right, CE->getAPValue()); + } else { + // If the equality is false and the comparison is of booleans, + // then we can infer the value to propogate. + if (be->right->getWidth() == Expr::Bool) { + propogateExactValue( + be->right, + llvm::APInt(Expr::Bool, !CE->getAPValue().getBoolValue())); } } } @@ -993,7 +1070,9 @@ class CexData { // If a boolean not, and the result is known, propagate it case Expr::Not: { if (e->getWidth() == Expr::Bool && range.isFixed()) { - propagateExactValue(e->getKid(0), !range.min()); + llvm::APInt propValue = + llvm::APInt(e->getWidth(), !range.min().getBoolValue()); + propogateExactValue(e->getKid(0), propValue); } break; } @@ -1103,16 +1182,17 @@ FastCexSolver::~FastCexSolver() {} static bool propagateValues(const Query &query, CexData &cd, bool checkExpr, bool &isValid) { for (const auto &constraint : query.constraints.cs()) { - cd.propagatePossibleValue(constraint, 1); - cd.propagateExactValue(constraint, 1); + cd.propagatePossibleValue(constraint, + llvm::APInt(constraint->getWidth(), 1)); + cd.propogateExactValue(constraint, llvm::APInt(constraint->getWidth(), 1)); } if (checkExpr) { - cd.propagatePossibleValue(query.expr, 0); - cd.propagateExactValue(query.expr, 0); + cd.propagatePossibleValue( + query.expr, llvm::APInt::getNullValue(query.expr->getWidth())); + cd.propogateExactValue(query.expr, + llvm::APInt::getNullValue(query.expr->getWidth())); } - KLEE_DEBUG(cd.dump()); - // Check the result. bool hasSatisfyingAssignment = true; if (checkExpr) { @@ -1188,7 +1268,6 @@ bool FastCexSolver::computeInitialValues( const Query &query, const std::vector &objects, std::vector> &values, bool &hasSolution) { CexData cd; - query.dump(); bool isValid; bool success = propagateValues(query, cd, true, isValid); From 5e533ee6f73b47fd5789276665f3d5f9c0f80f95 Mon Sep 17 00:00:00 2001 From: Aleksandr Misonizhnik Date: Sat, 28 Oct 2023 17:57:54 +0400 Subject: [PATCH 5/5] [fix] --- lib/Solver/FastCexSolver.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/Solver/FastCexSolver.cpp b/lib/Solver/FastCexSolver.cpp index 83fcbd814b..219b02cebf 100644 --- a/lib/Solver/FastCexSolver.cpp +++ b/lib/Solver/FastCexSolver.cpp @@ -837,8 +837,8 @@ class CexData { // range? // FIXME: choose both - // range = CexValueData(llvm::APInt::getNullValue(CE->getWidth()), - // value - 1); + range = CexValueData(llvm::APInt::getNullValue(CE->getWidth()), + value - 1); // range = CexValueData( // value + 1, llvm::APInt::getAllOnesValue(CE->getWidth())); }