Skip to content

Commit

Permalink
feat(search): Multishard cutoffs
Browse files Browse the repository at this point in the history
  • Loading branch information
dranikpg committed Sep 24, 2023
1 parent d8b99dc commit 24b2025
Show file tree
Hide file tree
Showing 6 changed files with 298 additions and 138 deletions.
5 changes: 5 additions & 0 deletions src/core/search/ast_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ struct AstNode : public NodeVariants {
// Read-only access to the underlying variant of node alternatives.
const NodeVariants& Variant() const& {
  return *this;
}

// Aggregations reduce and re-order result sets.
// Currently only KNN nodes count as aggregations.
bool IsAggregation() const {
  return std::holds_alternative<AstKnnNode>(Variant());
}
};

using AstExpr = AstNode;
Expand Down
8 changes: 5 additions & 3 deletions src/server/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,11 +206,13 @@ template <typename RandGen> std::string GetRandomHex(RandGen& gen, size_t len) {
// truthy value;
template <typename T> struct AggregateValue {
// Store `val` as the aggregate value if it is truthy and no truthy value
// was stored before. Returns whether `val` itself was truthy.
// Thread-safe: the check-and-set of current_ happens under mu_.
bool operator=(T val) {
  // Falsy values never overwrite anything; skip taking the lock.
  if (!bool(val))
    return false;

  std::lock_guard l{mu_};
  if (!bool(current_))
    current_ = val;
  return true;
}

T operator*() {
Expand Down
96 changes: 75 additions & 21 deletions src/server/search/doc_index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,25 @@ const absl::flat_hash_map<string_view, search::SchemaField::FieldType> kSchemaTy
{"NUMERIC"sv, search::SchemaField::NUMERIC},
{"VECTOR"sv, search::SchemaField::VECTOR}};

// Estimate how many of the `hits` documents found on this shard should be
// serialized eagerly, given that `requested` documents were asked for across
// `shards` shards in total. Aggregations (and low hit counts) fall back to
// serializing up to the full requested count.
size_t GetProbabilisticBound(size_t shards, size_t hits, size_t requested, bool is_aggregation) {
  // Integer floor(log2(x)); 0 for x == 0.
  auto intlog2 = [](size_t value) {
    size_t log = 0;
    while (value >>= 1)
      ++log;
    return log;
  };

  // Heuristic lower estimate of a single shard's contribution to the hits.
  size_t avg_shard_min = hits * intlog2(hits) / (12 + shards / 10);
  avg_shard_min -= std::min(avg_shard_min, std::min(hits, size_t(5)));

  // VLOG(0) << "PROB BOUND " << hits << " " << shards << " " << requested <<
  // " => " << avg_shard_min << " diffb " << requested / shards + 1 << " & " << requested;

  // If the shards together are expected to produce enough results, each shard
  // only needs to serialize its proportional share (plus one for rounding).
  if (!is_aggregation && avg_shard_min * shards >= requested)
    return requested / shards + 1;

  // Otherwise serialize everything up to the requested count.
  return std::min(hits, requested);
}

} // namespace

optional<search::SchemaField::FieldType> ParseSearchFieldType(string_view name) {
Expand Down Expand Up @@ -149,10 +168,11 @@ bool DocIndex::Matches(string_view key, unsigned obj_code) const {
}

// Construct a shard-local index over the shared index definition.
// The write epoch starts at zero and is bumped on every mutation.
ShardDocIndex::ShardDocIndex(shared_ptr<DocIndex> index)
    : base_{std::move(index)}, write_epoch_{0}, indices_{{}}, key_index_{} {
}

void ShardDocIndex::Rebuild(const OpArgs& op_args) {
write_epoch_++;
key_index_ = DocKeyIndex{};
indices_ = search::FieldIndices{base_->schema};

Expand All @@ -161,11 +181,13 @@ void ShardDocIndex::Rebuild(const OpArgs& op_args) {
}

// Index the document stored under `key`. Bumps the write epoch so that
// search results produced before this write can be detected as stale.
void ShardDocIndex::AddDoc(string_view key, const DbContext& db_cntx, const PrimeValue& pv) {
  write_epoch_++;
  auto accessor = GetAccessor(db_cntx, pv);
  indices_.Add(key_index_.Add(key), accessor.get());
}

// Remove `key`'s document from the index. Bumps the write epoch so that
// search results produced before this write can be detected as stale.
void ShardDocIndex::RemoveDoc(string_view key, const DbContext& db_cntx, const PrimeValue& pv) {
  write_epoch_++;
  auto accessor = GetAccessor(db_cntx, pv);
  DocId id = key_index_.Remove(key);
  indices_.Remove(id, accessor.get());
}
Expand All @@ -175,38 +197,70 @@ bool ShardDocIndex::Matches(string_view key, unsigned obj_code) const {
return base_->Matches(key, obj_code);
}

SearchResult ShardDocIndex::Search(const OpArgs& op_args, const SearchParams& params,
search::SearchAlgorithm* search_algo) const {
auto& db_slice = op_args.shard->db_slice();
io::Result<SearchResult, facade::ErrorReply> ShardDocIndex::Search(
const OpArgs& op_args, const SearchParams& params, search::SearchAlgorithm* search_algo) const {
auto search_results = search_algo->Search(&indices_);

if (!search_results.error.empty())
return SearchResult{facade::ErrorReply{std::move(search_results.error)}};
return nonstd::make_unexpected(facade::ErrorReply(std::move(search_results.error)));

size_t serialize_count = min(search_results.ids.size(), params.limit_offset + params.limit_total);
vector<SerializedSearchDoc> out;
out.reserve(serialize_count);
size_t requested_count = params.limit_offset + params.limit_total;
size_t serialize_count = min(requested_count, search_results.ids.size());

size_t expired_count = 0;
for (size_t i = 0; i < search_results.ids.size() && out.size() < serialize_count; i++) {
auto key = key_index_.Get(search_results.ids[i]);
auto it = db_slice.Find(op_args.db_cntx, key, base_->GetObjCode());
size_t cuttoff_bound = serialize_count;
if (params.enable_cutoff && !params.IdsOnly())
serialize_count = GetProbabilisticBound(params.num_shards, search_results.ids.size(),
requested_count, search_algo->HasKnn().has_value());

vector<DocResult> out(serialize_count);
auto shard_id = EngineShard::tlocal()->shard_id();
for (size_t i = 0; i < out.size(); i++) {
out[i].value = DocResult::DocReference{shard_id, search_results.ids[i], i < cuttoff_bound};
out[i].score = search_results.knn_distances.empty() ? 0 : search_results.knn_distances[i];
}

Serialize(op_args, params, absl::MakeSpan(out));

return SearchResult{write_epoch_, search_results.ids.size(), std::move(out),
std::move(search_results.profile)};
}

// Serialize the remaining document references of a result produced by an
// earlier Search() call. Returns true if the stored result was still valid.
// If the index was written to in the meantime, the search is re-run from
// scratch and false is returned.
bool ShardDocIndex::Refill(const OpArgs& op_args, const SearchParams& params,
                           search::SearchAlgorithm* search_algo, SearchResult* result) const {
  const bool index_unchanged = (result->write_epoch == write_epoch_);
  if (!index_unchanged) {
    // The index changed, so the stored doc ids may be stale: repeat the
    // search. Cutoffs must be disabled for this pass.
    DCHECK(!params.enable_cutoff);
    auto fresh_result = Search(op_args, params, search_algo);
    CHECK(fresh_result.has_value());
    *result = std::move(fresh_result.value());
    return false;
  }

  Serialize(op_args, params, absl::MakeSpan(result->docs));
  return true;
}

// Resolve requested document references in `docs` into serialized values.
// Stops at the first non-requested reference (references are produced in
// Search() with all requested entries first). Expired documents are skipped
// and keep their reference value.
void ShardDocIndex::Serialize(const OpArgs& op_args, const SearchParams& params,
                              absl::Span<DocResult> docs) const {
  auto& db_slice = op_args.shard->db_slice();

  for (auto& doc : docs) {
    // Skip entries that were already serialized.
    if (!holds_alternative<DocResult::DocReference>(doc.value))
      continue;

    auto ref = get<DocResult::DocReference>(doc.value);
    if (!ref.requested)
      return;  // all following references are not requested either

    auto key = key_index_.Get(ref.doc_id);
    auto it = db_slice.Find(op_args.db_cntx, key, base_->GetObjCode());
    if (!it || !IsValid(*it))  // Item must have expired
      continue;

    auto accessor = GetAccessor(op_args.db_cntx, (*it)->second);
    auto doc_data = params.return_fields ? accessor->Serialize(base_->schema, *params.return_fields)
                                         : accessor->Serialize(base_->schema);

    doc.value = DocResult::SerializedValue{string{key}, std::move(doc_data)};
  }
}

DocIndexInfo ShardDocIndex::GetInfo() const {
Expand Down
55 changes: 38 additions & 17 deletions src/server/search/doc_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,28 +23,36 @@ using SearchDocData = absl::flat_hash_map<std::string /*field*/, std::string /*v
std::optional<search::SchemaField::FieldType> ParseSearchFieldType(std::string_view name);
std::string_view SearchFieldTypeToString(search::SchemaField::FieldType);

// A single search hit in a shard's result set. Starts out as a lightweight
// reference into the shard index (DocReference) and is replaced by the
// serialized field values (SerializedValue) once the document is loaded.
struct DocResult {
  // Fully loaded document: its key and serialized field values.
  struct SerializedValue {
    std::string key;
    SearchDocData values;
  };

  // Handle to a not-yet-serialized document on a specific shard.
  struct DocReference {
    ShardId shard_id;
    search::DocId doc_id;
    bool requested;  // true if this doc fell below the serialization cutoff
  };

  std::variant<SerializedValue, DocReference> value;
  float score;  // knn distance when a knn query ran, 0 otherwise
};

// Results of a search on a single shard.
struct SearchResult {
  size_t write_epoch = 0;  // Write epoch of the index at the time the result was created

  size_t total_hits = 0;        // Total number of hits in the shard
  std::vector<DocResult> docs;  // Documents for the first hits

  // After combining results from multiple shards and accumulating more documents than initially
  // requested, only a subset of all documents will be sent back to the client,
  // so it doesn't make sense to serialize strictly all documents in every shard ahead.
  // Instead, only documents up to a probabilistic bound are serialized; the
  // leftover ids and scores are stored in the cutoff tail for use in the "unlikely" scenario.
  // size_t num_cutoff = 0;

  std::optional<search::AlgorithmProfile> profile;

  std::optional<facade::ErrorReply> error;
};

struct SearchParams {
Expand All @@ -56,6 +64,10 @@ struct SearchParams {
size_t limit_offset;
size_t limit_total;

// Total number of shards, used in probabilistic queries
size_t num_shards;
bool enable_cutoff;

// Set but empty means no fields should be returned
std::optional<FieldReturnList> return_fields;
search::QueryParams query_params;
Expand Down Expand Up @@ -112,8 +124,12 @@ class ShardDocIndex {
ShardDocIndex(std::shared_ptr<DocIndex> index);

// Perform search on all indexed documents and return results.
SearchResult Search(const OpArgs& op_args, const SearchParams& params,
search::SearchAlgorithm* search_algo) const;
io::Result<SearchResult, facade::ErrorReply> Search(const OpArgs& op_args,
const SearchParams& params,
search::SearchAlgorithm* search_algo) const;

bool Refill(const OpArgs& op_args, const SearchParams& params,
search::SearchAlgorithm* search_algo, SearchResult* result) const;

// Clears internal data. Traverses all matching documents and assigns ids.
void Rebuild(const OpArgs& op_args);
Expand All @@ -126,8 +142,13 @@ class ShardDocIndex {

DocIndexInfo GetInfo() const;

private:
void Serialize(const OpArgs& op_args, const SearchParams& params,
absl::Span<DocResult> docs) const;

private:
std::shared_ptr<const DocIndex> base_;
size_t write_epoch_;
search::FieldIndices indices_;
DocKeyIndex key_index_;
};
Expand Down
Loading

0 comments on commit 24b2025

Please sign in to comment.