Skip to content

Commit

Permalink
ignoring zero diagonals in diagonal_ct_vector_matmul (#405)
Browse files Browse the repository at this point in the history
* ignoring zero diagonals in diagonal_ct_vector_matmul

* applied linter

* added comment

* applied clang-format

* fixed test for numpy version 1.23

* ignore default make flags if MAKEFLAGS is set
  • Loading branch information
mrader1248 authored Aug 7, 2022
1 parent ec21905 commit 539e0b9
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 30 deletions.
8 changes: 5 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,18 @@ def build_extension(self, ext):
cfg = "Debug" if self.debug else "Release"
build_args = ["--config", cfg]

env = os.environ.copy()

if platform.system() == "Windows":
cmake_args += [f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}"]
if sys.maxsize > 2 ** 32:
if sys.maxsize > 2**32:
cmake_args += ["-A", "x64"]
build_args += ["--", "/m", "/p:TrackFileAccess=false"]
else:
cmake_args += ["-DCMAKE_BUILD_TYPE=" + cfg]
build_args += ["--", "-j"]
if "MAKEFLAGS" not in env:
build_args += ["--", "-j"]

env = os.environ.copy()
env["CXXFLAGS"] = '{} -DVERSION_INFO=\\"{}\\"'.format(
env.get("CXXFLAGS", ""), self.distribution.get_version()
)
Expand Down
48 changes: 28 additions & 20 deletions tenseal/cpp/tensors/encrypted_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,28 +186,36 @@ class EncryptedVector : public EncryptedTensor<plain_t, encrypted_t> {
auto diag = matrix.get_diagonal(
-local_i,
this->tenseal_context()->template slot_count<encoder_t>());
replicate_vector(
diag,
this->tenseal_context()->template slot_count<encoder_t>());

rotate(diag.begin(), diag.begin() + diag.size() - local_i,
diag.end());

this->tenseal_context()->template encode<encoder_t>(diag,
pt_diag);

if (this->_ciphertexts[0].parms_id() != pt_diag.parms_id()) {
this->set_to_same_mod(pt_diag, _ciphertexts[0]);
// don't add zero diagonals to (a) improve performance and (b)
// avoid transparent ciphertext issues
bool is_diag_nonzero = std::any_of(
diag.begin(), diag.end(), [](plain_t x) { return x != 0; });
if (is_diag_nonzero) {
replicate_vector(diag,
this->tenseal_context()
->template slot_count<encoder_t>());

rotate(diag.begin(), diag.begin() + diag.size() - local_i,
diag.end());

this->tenseal_context()->template encode<encoder_t>(
diag, pt_diag);

if (this->_ciphertexts[0].parms_id() !=
pt_diag.parms_id()) {
this->set_to_same_mod(pt_diag, _ciphertexts[0]);
}
this->tenseal_context()->evaluator->multiply_plain(
this->_ciphertexts[0], pt_diag, ct);

this->tenseal_context()->evaluator->rotate_vector_inplace(
ct, local_i, *this->tenseal_context()->galois_keys());

// accumulate thread results
this->tenseal_context()->evaluator->add_inplace(
thread_result, ct);
}
this->tenseal_context()->evaluator->multiply_plain(
this->_ciphertexts[0], pt_diag, ct);

this->tenseal_context()->evaluator->rotate_vector_inplace(
ct, local_i, *this->tenseal_context()->galois_keys());

// accumulate thread results
this->tenseal_context()->evaluator->add_inplace(thread_result,
ct);
}
return thread_result;
};
Expand Down
34 changes: 27 additions & 7 deletions tests/python/tenseal/tensors/test_ckks_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,19 +121,39 @@ def test_reshape_batching(context, data, new_shape):
@pytest.mark.parametrize(
"data, slices, new_shape",
[
([0, 1, 2, 3, 4, 5], [slice(1, 4, None)], [3]),
([0, 1, 2, 3, 4, 5], [slice(1, None, None)], [5]),
([0, 1, 2, 3, 4, 5], [slice(None, 4, None)], [4]),
([[0, 1, 2], [0, 1, 2], [0, 1, 2]], [slice(1, 3, None), slice(0, 2, None)], [2, 2]),
([[0, 1, 2], [0, 1, 2], [0, 1, 2]], [slice(1, None, None), slice(0, 2, None)], [2, 2]),
([0, 1, 2, 3, 4, 5], (slice(1, 4, None),), [3]),
([0, 1, 2, 3, 4, 5], (slice(1, None, None),), [5]),
([0, 1, 2, 3, 4, 5], (slice(None, 4, None),), [4]),
(
[[0, 1, 2], [0, 1, 2], [0, 1, 2]],
[slice(1, None, None), slice(None, None, None)],
(
slice(1, 3, None),
slice(0, 2, None),
),
[2, 2],
),
(
[[0, 1, 2], [0, 1, 2], [0, 1, 2]],
(
slice(1, None, None),
slice(0, 2, None),
),
[2, 2],
),
(
[[0, 1, 2], [0, 1, 2], [0, 1, 2]],
(
slice(1, None, None),
slice(None, None, None),
),
[2, 3],
),
(
[[0, 1, 2], [0, 1, 2], [0, 1, 2]],
[slice(None, None, None), slice(None, None, None)],
(
slice(None, None, None),
slice(None, None, None),
),
[3, 3],
),
([[0, 1, 2], [0, 1, 2], [0, 1, 2]], 1, [1, 3]),
Expand Down

0 comments on commit 539e0b9

Please sign in to comment.