From e58b196ed41733e6089347572823237d8c4eb6de Mon Sep 17 00:00:00 2001 From: Lakshya747 Date: Wed, 30 Oct 2024 20:13:48 +0530 Subject: [PATCH] Update karatsuba_algorithm_for_fast_multiplication.cpp --- ...suba_algorithm_for_fast_multiplication.cpp | 179 +++++------------- 1 file changed, 51 insertions(+), 128 deletions(-) diff --git a/divide_and_conquer/karatsuba_algorithm_for_fast_multiplication.cpp b/divide_and_conquer/karatsuba_algorithm_for_fast_multiplication.cpp index 1d9bd86cdda..47042a3ebc4 100644 --- a/divide_and_conquer/karatsuba_algorithm_for_fast_multiplication.cpp +++ b/divide_and_conquer/karatsuba_algorithm_for_fast_multiplication.cpp @@ -3,175 +3,97 @@ * @brief Implementation of the [Karatsuba algorithm for fast * multiplication](https://en.wikipedia.org/wiki/Karatsuba_algorithm) * @details - * Given two strings in binary notation we want to multiply them and return the - * value. Simple approach is to multiply bits one by one which will give the time - * complexity of around O(n^2). To make it more efficient we will be using - * Karatsuba algorithm to find the product which will solve the problem - * O(nlogn) of time. - * @author [Swastika Gupta](https://github.com/Swastyy) - * @author [Ameer Carlo Lubang](https://github.com/poypoyan) + * Given two strings in binary notation, this implementation multiplies them + * using the Karatsuba algorithm, achieving O(n^log2(3)) time complexity. + * @authors [Swastika Gupta](https://github.com/Swastyy), [Ameer Carlo Lubang](https://github.com/poypoyan) */ -#include /// for assert -#include /// for string -#include /// for IO operations -#include /// for std::vector +#include // for assert +#include // for string +#include // for IO operations -/** - * @namespace divide_and_conquer - * @brief Divide and Conquer algorithms - */ namespace divide_and_conquer { -/** - * @namespace karatsuba_algorithm - * @brief Functions for the [Karatsuba algorithm for fast - * multiplication](https://en.wikipedia.org/wiki/Karatsuba_algorithm) implementation - */ namespace karatsuba_algorithm { + /** - * @brief Binary addition + * @brief Binary addition of two binary strings. * @param first, the input string 1 * @param second, the input string 2 - * @returns the sum binary string + * @returns the sum as a binary string */ -std::string add_strings(std::string first, std::string second) { - std::string result; // to store the resulting sum bits - - // make the string lengths equal - int64_t len1 = first.size(); - int64_t len2 = second.size(); - std::string zero = "0"; - if (len1 < len2) { - for (int64_t i = 0; i < len2 - len1; i++) { - zero += first; - first = zero; - zero = "0"; // Prevents CI from failing - } - } else if (len1 > len2) { - for (int64_t i = 0; i < len1 - len2; i++) { - zero += second; - second = zero; - zero = "0"; // Prevents CI from failing - } +std::string add_strings(const std::string &first, const std::string &second) { + std::string result; + int carry = 0, sum = 0; + size_t len1 = first.size(), len2 = second.size(); + size_t max_len = std::max(len1, len2); + + // Add bits from right to left + for (size_t i = 0; i < max_len || carry; ++i) { + if (i < len1) carry += first[len1 - 1 - i] - '0'; + if (i < len2) carry += second[len2 - 1 - i] - '0'; + + result.push_back((carry % 2) + '0'); // Current bit + carry /= 2; // Carry for next bit } - int64_t length = std::max(len1, len2); - int64_t carry = 0; - for (int64_t i = length - 1; i >= 0; i--) { - int64_t firstBit = first.at(i) - '0'; - int64_t secondBit = second.at(i) - '0'; - - int64_t sum = (char(firstBit ^ secondBit ^ carry)) + '0'; // sum of 3 bits - result.insert(result.begin(), sum); - - carry = char((firstBit & secondBit) | (secondBit & carry) | - (firstBit & carry)); // sum of 3 bits - } - - if (carry) { - result.insert(result.begin(), '1'); // adding 1 incase of overflow - } + // The result is in reverse order + std::reverse(result.begin(), result.end()); return result; } /** - * @brief Wrapper function for substr that considers leading zeros. - * @param str, the binary input string. - * @param x1, the substr parameter integer 1 - * @param x2, the substr parameter integer 2 - * @param n, is the length of the "whole" string: leading zeros + str - * @returns the "safe" substring for the algorithm *without* leading zeros - * @returns "0" if substring spans to leading zeros only - */ -std::string safe_substr(const std::string &str, int64_t x1, int64_t x2, int64_t n) { - int64_t len = str.size(); - - if (len >= n) { - return str.substr(x1, x2); - } - - int64_t y1 = x1 - (n - len); // index in str of first char of substring of "whole" string - int64_t y2 = (x1 + x2 - 1) - (n - len); // index in str of last char of substring of "whole" string - - if (y2 < 0) { - return "0"; - } else if (y1 < 0) { - return str.substr(0, y2 + 1); - } else { - return str.substr(y1, x2); - } -} - -/** - * @brief The main function implements Karatsuba's algorithm for fast - * multiplication + * @brief The main function implements Karatsuba's algorithm for fast multiplication. * @param str1 the input string 1 * @param str2 the input string 2 * @returns the product number value */ -int64_t karatsuba_algorithm(std::string str1, std::string str2) { - int64_t len1 = str1.size(); - int64_t len2 = str2.size(); - int64_t n = std::max(len1, len2); - - if (n == 0) { - return 0; +int64_t karatsuba_algorithm(const std::string &str1, const std::string &str2) { + if (str1 == "0" || str2 == "0") return 0; // Handle multiplication by zero + if (str1 == "1") return std::stoll(str2); // Handle multiplication by one + if (str2 == "1") return std::stoll(str1); + + size_t len1 = str1.size(), len2 = str2.size(); + size_t n = std::max(len1, len2); + + // Make lengths even + if (n % 2 != 0) { + str1 = '0' + str1; + str2 = '0' + str2; + n++; } - if (n == 1) { - return (str1[0] - '0') * (str2[0] - '0'); - } - - int64_t fh = n / 2; // first half of string - int64_t sh = n - fh; // second half of string - - std::string Xl = divide_and_conquer::karatsuba_algorithm::safe_substr(str1, 0, fh, n); // first half of first string - std::string Xr = divide_and_conquer::karatsuba_algorithm::safe_substr(str1, fh, sh, n); // second half of first string - std::string Yl = divide_and_conquer::karatsuba_algorithm::safe_substr(str2, 0, fh, n); // first half of second string - std::string Yr = divide_and_conquer::karatsuba_algorithm::safe_substr(str2, fh, sh, n); // second half of second string + size_t fh = n / 2; // Half the size + std::string Xl = str1.substr(0, len1 - fh); + std::string Xr = str1.substr(len1 - fh); + std::string Yl = str2.substr(0, len2 - fh); + std::string Yr = str2.substr(len2 - fh); - // calculating the three products of inputs of size n/2 recursively + // Recursive calls int64_t product1 = karatsuba_algorithm(Xl, Yl); int64_t product2 = karatsuba_algorithm(Xr, Yr); - int64_t product3 = karatsuba_algorithm( - divide_and_conquer::karatsuba_algorithm::add_strings(Xl, Xr), - divide_and_conquer::karatsuba_algorithm::add_strings(Yl, Yr)); + int64_t product3 = karatsuba_algorithm(add_strings(Xl, Xr), add_strings(Yl, Yr)); - return product1 * (1 << (2 * sh)) + - (product3 - product1 - product2) * (1 << sh) + - product2; // combining the three products to get the final result. + return product1 * (1LL << (2 * fh)) + (product3 - product1 - product2) * (1LL << fh) + product2; } -} // namespace karatsuba_algorithm -} // namespace divide_and_conquer /** * @brief Self-test implementations - * @returns void */ static void test() { // 1st test std::string s11 = "1"; // 1 std::string s12 = "1010"; // 10 - std::cout << "1st test... "; - assert(divide_and_conquer::karatsuba_algorithm::karatsuba_algorithm( - s11, s12) == 10); - std::cout << "passed" << std::endl; + assert(karatsuba_algorithm(s11, s12) == 10); // 2nd test std::string s21 = "11"; // 3 std::string s22 = "1010"; // 10 - std::cout << "2nd test... "; - assert(divide_and_conquer::karatsuba_algorithm::karatsuba_algorithm( - s21, s22) == 30); - std::cout << "passed" << std::endl; + assert(karatsuba_algorithm(s21, s22) == 30); // 3rd test std::string s31 = "110"; // 6 std::string s32 = "1010"; // 10 - std::cout << "3rd test... "; - assert(divide_and_conquer::karatsuba_algorithm::karatsuba_algorithm( - s31, s32) == 60); - std::cout << "passed" << std::endl; + assert(karatsuba_algorithm(s31, s32) == 60); } /** @@ -179,6 +101,7 @@ static void test() { * @returns 0 on exit */ int main() { - test(); // run self-test implementations + test(); // Run self-test implementations + std::cout << "All tests passed!" << std::endl; return 0; }