Skip to content

Commit

Permalink
Move common functionality to basic.py and DigitDecompExtend to decomp…
Browse files Browse the repository at this point in the history
….py (#47)

* Refactor KeyMul and move it to basic.py
* Refactor and move DigitDecompExtend to its own file, decomp.py.
---------

Co-authored-by: Flavio Bergamaschi <[email protected]>
  • Loading branch information
christopherngutierrez and faberga authored Sep 20, 2024
1 parent cec8bb9 commit ae3c257
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 103 deletions.
60 changes: 59 additions & 1 deletion kerngen/pisa_generators/basic.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

# Copyright (C) 2024 Intel Corporation

"""Module containing conversions or operations from isa to p-isa."""

import itertools as it
from dataclasses import dataclass
from typing import ClassVar, Iterable
from typing import ClassVar, Iterable, Tuple
from string import ascii_letters

import high_parser.pisa_operations as pisa_op
from high_parser.pisa_operations import PIsaOp
Expand Down Expand Up @@ -238,3 +242,57 @@ def to_pisa(self) -> list[PIsaOp]:
pisa_op.Copy(self.context.label, *expand_io)
for expand_io, _ in expand_ios(self.context, self.output, self.input0)
]


@dataclass
class KeyMul(HighOp):
"""Class representing a key multiplication operation"""

context: KernelContext
output: Polys
input0: Polys
input1: KeyPolys
input0_fixed_part: int

def to_pisa(self) -> list[PIsaOp]:
"""Return the p-isa code to perform a key multiplication"""

def get_pisa_op(num):
yield 0, pisa_op.Mul
yield from ((op, pisa_op.Mac) for op in range(1, num))

ls: list[pisa_op] = []
for digit, op in get_pisa_op(self.input1.digits):
input0_tmp = Polys.from_polys(self.input0)
input0_tmp.name += "_" + ascii_letters[digit]
ls.extend(
op(
self.context.label,
self.output(part, q, unit),
input0_tmp(self.input0_fixed_part, q, unit),
self.input1(digit, part, q, unit),
q,
)
for part, q, unit in it.product(
range(self.input1.start_parts, self.input1.parts),
range(self.input0.start_rns, self.input0.rns),
range(self.context.units),
)
)
return ls


def extract_last_part_polys(input0: Polys, rns: int) -> Tuple[Polys, Polys, Polys]:
"""Split and extract the last part of input0 with a change of rns"""
input_last_part = Polys.from_polys(input0, mode="last_part")
input_last_part.name = input0.name

last_coeff = Polys.from_polys(input_last_part)
last_coeff.name = "coeffs"
last_coeff.rns = rns

upto_last_coeffs = Polys.from_polys(last_coeff)
upto_last_coeffs.parts = 1
upto_last_coeffs.start_parts = 0

return input_last_part, last_coeff, upto_last_coeffs
59 changes: 59 additions & 0 deletions kerngen/pisa_generators/decomp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Module containing digit decomposition/base extend"""

from string import ascii_letters
import itertools as it

from dataclasses import dataclass
import high_parser.pisa_operations as pisa_op
from high_parser.pisa_operations import PIsaOp
from high_parser import KernelContext, HighOp, Immediate, Polys

from .basic import Muli, mixed_to_pisa_ops
from .ntt import INTT, NTT


@dataclass
class DigitDecompExtend(HighOp):
"""Class representing Digit decomposition and base extension"""

context: KernelContext
output: Polys
input0: Polys

def to_pisa(self) -> list[PIsaOp]:
"""Return the p-isa code performing Digit decomposition followed by
base extension"""

rns_poly = Polys.from_polys(self.input0)
rns_poly.name = "ct"

one = Immediate(name="one")
r2 = Immediate(name="R2", rns=self.context.key_rns)

ls: list[pisa_op] = []
for input_rns_index in range(self.input0.rns):
ls.extend(
pisa_op.Muli(
self.context.label,
self.output(part, pq, unit),
rns_poly(part, input_rns_index, unit),
r2(part, pq, unit),
pq,
)
for part, pq, unit in it.product(
range(self.input0.start_parts, self.input0.parts),
range(self.context.key_rns),
range(self.context.units),
)
)
output_tmp = Polys.from_polys(self.output)
output_tmp.name += "_" + ascii_letters[input_rns_index]
ls.extend(NTT(self.context, output_tmp, self.output).to_pisa())

return mixed_to_pisa_ops(
INTT(self.context, rns_poly, self.input0),
Muli(self.context, rns_poly, rns_poly, one),
ls,
)
110 changes: 8 additions & 102 deletions kerngen/pisa_generators/relin.py
Original file line number Diff line number Diff line change
@@ -1,101 +1,13 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""Module containing relin, keymul, etc."""

"""Module containing relin."""
from dataclasses import dataclass
from itertools import product
from string import ascii_letters

import high_parser.pisa_operations as pisa_op
from high_parser.pisa_operations import PIsaOp, Comment
from high_parser import KernelContext, HighOp, Immediate, KeyPolys, Polys

from .basic import Add, Muli, mixed_to_pisa_ops
from high_parser import KernelContext, HighOp, KeyPolys, Polys
from .basic import Add, KeyMul, mixed_to_pisa_ops, extract_last_part_polys
from .mod import Mod
from .ntt import INTT, NTT


@dataclass
class KeyMul(HighOp):
"""Class representing a key multiplication operation"""

context: KernelContext
output: Polys
input0: Polys
input1: KeyPolys

def to_pisa(self) -> list[PIsaOp]:
"""Return the p-isa code to perform a key multiplication"""

def get_pisa_op(num):
yield 0, pisa_op.Mul
yield from ((op, pisa_op.Mac) for op in range(1, num))

ls: list[pisa_op] = []
for digit, op in get_pisa_op(self.input1.digits):
input0_tmp = Polys.from_polys(self.input0)
input0_tmp.name += "_" + ascii_letters[digit]
ls.extend(
op(
self.context.label,
self.output(part, q, unit),
input0_tmp(2, q, unit),
self.input1(digit, part, q, unit),
q,
)
for part, q, unit in product(
range(self.input1.start_parts, self.input1.parts),
range(self.input0.start_rns, self.input0.rns),
range(self.context.units),
)
)
return ls


@dataclass
class DigitDecompExtend(HighOp):
"""Class representing Digit decomposition and base extension"""

context: KernelContext
output: Polys
input0: Polys

def to_pisa(self) -> list[PIsaOp]:
"""Return the p-isa code performing Digit decomposition followed by
base extension"""

rns_poly = Polys.from_polys(self.input0)
rns_poly.name = "ct"

one = Immediate(name="one")
r2 = Immediate(name="R2", rns=self.context.key_rns)

ls: list[pisa_op] = []
for input_rns_index in range(self.input0.rns):
ls.extend(
pisa_op.Muli(
self.context.label,
self.output(part, pq, unit),
rns_poly(part, input_rns_index, unit),
r2(part, pq, unit),
pq,
)
for part, pq, unit in product(
range(self.input0.start_parts, self.input0.parts),
range(self.context.key_rns),
range(self.context.units),
)
)
output_tmp = Polys.from_polys(self.output)
output_tmp.name += "_" + ascii_letters[input_rns_index]
ls.extend(NTT(self.context, output_tmp, self.output).to_pisa())

return mixed_to_pisa_ops(
INTT(self.context, rns_poly, self.input0),
Muli(self.context, rns_poly, rns_poly, one),
ls,
)
from .decomp import DigitDecompExtend


@dataclass
Expand All @@ -120,15 +32,9 @@ def to_pisa(self) -> list[PIsaOp]:
mul_by_rlk = Polys("c2_rlk", parts=2, rns=self.context.key_rns)
mul_by_rlk_modded_down = Polys.from_polys(mul_by_rlk)
mul_by_rlk_modded_down.rns = self.input0.rns
input_last_part = Polys.from_polys(self.input0, mode="last_part")
input_last_part.name = self.input0.name

last_coeff = Polys.from_polys(input_last_part)
last_coeff.name = "coeffs"
last_coeff.rns = self.context.key_rns
upto_last_coeffs = Polys.from_polys(last_coeff)
upto_last_coeffs.parts = 1
upto_last_coeffs.start_parts = 0
input_last_part, last_coeff, upto_last_coeffs = extract_last_part_polys(
self.input0, self.context.key_rns
)

add_original = Polys.from_polys(mul_by_rlk_modded_down)
add_original.name = self.input0.name
Expand All @@ -138,7 +44,7 @@ def to_pisa(self) -> list[PIsaOp]:
Comment("Digit decomposition and extend base from Q to PQ"),
DigitDecompExtend(self.context, last_coeff, input_last_part),
Comment("Multiply by relin key"),
KeyMul(self.context, mul_by_rlk, upto_last_coeffs, relin_key),
KeyMul(self.context, mul_by_rlk, upto_last_coeffs, relin_key, 2),
Comment("Mod switch down to Q"),
Mod(self.context, mul_by_rlk_modded_down, mul_by_rlk),
Comment("Add to original poly"),
Expand Down

0 comments on commit ae3c257

Please sign in to comment.