Skip to content

Commit

Permalink
Use multithreading for BFV and CKKS sum inplace (#393)
Browse files Browse the repository at this point in the history
Co-authored-by: Bogdan Cebere <[email protected]>
  • Loading branch information
vdasu and bcebere authored Apr 29, 2022
1 parent 14698ba commit ff1793f
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 12 deletions.
20 changes: 14 additions & 6 deletions tenseal/cpp/tensors/bfvvector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,12 +228,20 @@ shared_ptr<BFVVector> BFVVector::dot_plain_inplace(

shared_ptr<BFVVector> BFVVector::sum_inplace(size_t /*axis=0*/) {
vector<Ciphertext> interm_sum;
// TODO use multithreading for sum
for (size_t idx = 0; idx < this->_ciphertexts.size(); ++idx) {
Ciphertext out = this->_ciphertexts[idx];
sum_vector(this->tenseal_context(), out, this->_sizes[idx]);
interm_sum.push_back(out);
}
size_t size = this->_ciphertexts.size();
interm_sum.resize(size);

task_t worker_func = [&](size_t start, size_t end) -> bool {
for (size_t idx = start; idx < end; ++idx) {
Ciphertext out = this->_ciphertexts[idx];
sum_vector(this->tenseal_context(), out, this->_sizes[idx]);
interm_sum[idx] = out;
}
return true;
};

this->dispatch_jobs(worker_func, size);

Ciphertext result;
tenseal_context()->evaluator->add_many(interm_sum, result);

Expand Down
20 changes: 14 additions & 6 deletions tenseal/cpp/tensors/ckksvector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,12 +243,20 @@ shared_ptr<CKKSVector> CKKSVector::dot_plain_inplace(const plain_t& to_mul) {

shared_ptr<CKKSVector> CKKSVector::sum_inplace(size_t /*axis = 0*/) {
vector<Ciphertext> interm_sum;
// TODO use multithreading for the sum
for (size_t idx = 0; idx < this->_ciphertexts.size(); ++idx) {
Ciphertext out = this->_ciphertexts[idx];
sum_vector(this->tenseal_context(), out, this->_sizes[idx]);
interm_sum.push_back(out);
}
size_t size = this->_ciphertexts.size();
interm_sum.resize(size);

task_t worker_func = [&](size_t start, size_t end) -> bool {
for (size_t idx = start; idx < end; ++idx) {
Ciphertext out = this->_ciphertexts[idx];
sum_vector(this->tenseal_context(), out, this->_sizes[idx]);
interm_sum[idx] = out;
}
return true;
};

this->dispatch_jobs(worker_func, size);

Ciphertext result;
tenseal_context()->evaluator->add_many(interm_sum, result);

Expand Down
34 changes: 34 additions & 0 deletions tenseal/cpp/tensors/encrypted_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ namespace tenseal {
using namespace seal;
using namespace std;

using task_t = std::function<bool(size_t, size_t)>;

/**
* EncryptedVector<plain_t> interface - Specializes EncryptedTensor interface
*for vectors.
Expand Down Expand Up @@ -246,6 +248,38 @@ class EncryptedVector : public EncryptedTensor<plain_t, encrypted_t> {
protected:
std::vector<size_t> _sizes;
std::vector<Ciphertext> _ciphertexts;

void dispatch_jobs(task_t& worker_func, size_t total_tasks) {
size_t n_jobs =
std::min(total_tasks, this->tenseal_context()->dispatcher_size());

if (n_jobs == 1) {
worker_func(0, total_tasks);
return;
}

size_t batch_size = (total_tasks + n_jobs - 1) / n_jobs;
vector<future<bool>> futures;
for (size_t i = 0; i < n_jobs; i++) {
futures.push_back(
this->tenseal_context()->dispatcher()->enqueue_task(
worker_func, i * batch_size,
std::min((i + 1) * batch_size, total_tasks)));
}

std::optional<std::string> fail;
for (size_t i = 0; i < futures.size(); i++) {
try {
futures[i].get();
} catch (std::exception& e) {
fail = e.what();
}
}

if (fail) {
throw invalid_argument(fail.value());
}
}
};

} // namespace tenseal
Expand Down
19 changes: 19 additions & 0 deletions tests/cpp/tensors/ckksvector_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,25 @@ TEST_P(CKKSVectorTest, TestCKKSMulNoRelin) {
ASSERT_TRUE(are_close(decr.data(), {4, 8, 12}));
}

TEST_P(CKKSVectorTest, TestCKKSSum) {
auto enc_type = get<1>(GetParam());

auto ctx = TenSEALContext::Create(scheme_type::ckks, 8192, -1,
{60, 40, 40, 60}, enc_type);
ASSERT_TRUE(ctx != nullptr);

ctx->generate_galois_keys();
ctx->global_scale(std::pow(2, 40));

auto l = CKKSVector::Create(
ctx, std::vector<double>({1, 2, 3, 4, 5, 6, 7, 8, 9}));

l->sum_inplace();

auto decr = l->decrypt();
ASSERT_TRUE(are_close(decr.data(), {45}));
}

TEST_P(CKKSVectorTest, TestCKKSReplicateFirstSlot) {
auto should_serialize_first = get<0>(GetParam());
auto enc_type = get<1>(GetParam());
Expand Down
Empty file.

0 comments on commit ff1793f

Please sign in to comment.