-
Notifications
You must be signed in to change notification settings - Fork 1.6k
/
minimum-sum-of-squared-difference.cpp
41 lines (40 loc) · 1.65 KB
/
minimum-sum-of-squared-difference.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
// Time:  O(nlogn + nlogr), r = max(|nums1[i] - nums2[i]|)
// Space: O(n)
// binary search
class Solution {
public:
    // Returns the minimum possible sum of (nums1[i] - nums2[i])^2 after at most
    // k1 unit (+1/-1) modifications to nums1 and at most k2 to nums2.
    //
    // Either kind of modification changes |nums1[i] - nums2[i]| by at most one,
    // so k1 + k2 acts as a single budget for shrinking absolute differences.
    // Lowering the largest differences first is optimal. The method:
    //   1. binary-searches the smallest cap `left` such that clipping every
    //      difference to `left` costs at most the budget;
    //   2. spends the leftover budget lowering some capped differences by one.
    // Assumes the true answer fits in a long long.
    long long minSumSquareDiff(vector<int>& nums1, vector<int>& nums2, int k1, int k2) {
        if (nums1.empty()) {
            return 0;  // nothing to square; also guards diffs[0] below
        }
        vector<int64_t> diffs(nums1.size());
        for (size_t i = 0; i < diffs.size(); ++i) {
            // Subtract in 64 bits: an int subtraction can overflow for extreme inputs.
            diffs[i] = abs(static_cast<int64_t>(nums1[i]) - nums2[i]);
        }
        sort(rbegin(diffs), rend(diffs));  // descending: largest differences first

        // A budget beyond sum(diffs) is useless: every difference would already be 0.
        int64_t k = min(static_cast<int64_t>(k1) + k2,
                        accumulate(cbegin(diffs), cend(diffs), static_cast<int64_t>(0)));

        // Smallest cap `left` in [0, max diff] whose clipping cost fits in k.
        // excess() is non-increasing in the cap, so the predicate is monotone.
        int64_t left = 0, right = diffs[0];
        while (left <= right) {
            const int64_t mid = left + (right - left) / 2;
            if (excess(diffs, mid) <= k) {
                right = mid - 1;
            } else {
                left = mid + 1;
            }
        }
        k -= excess(diffs, left);

        // If left > 0, clipping to left - 1 was unaffordable. So the leftover k is
        // smaller than the number of diffs >= left, and those diffs form the prefix
        // of the descending array. Lowering the first k of them (already clipped to
        // left >= 1) by one more therefore never goes negative. If left == 0, the
        // leftover k is 0.
        for (size_t i = 0; i < diffs.size(); ++i) {
            diffs[i] = min(diffs[i], left) - static_cast<int64_t>(static_cast<int64_t>(i) < k);
        }
        return accumulate(cbegin(diffs), cend(diffs), 0ll,
                          [](long long total, int64_t d) {
                              return total + d * d;
                          });
    }

private:
    // Number of unit decrements needed to bring every difference down to at most `cap`.
    static int64_t excess(const vector<int64_t>& diffs, int64_t cap) {
        return accumulate(cbegin(diffs), cend(diffs), static_cast<int64_t>(0),
                          [cap](int64_t total, int64_t d) {
                              return total + max(d - cap, static_cast<int64_t>(0));
                          });
    }
};